510 lines
21 KiB
Python
510 lines
21 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
RAG System with Local Embeddings
|
||
Создает и использует RAG на основе markdown файлов с локальными эмбеддингами
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import hashlib
|
||
from pathlib import Path
|
||
from typing import List, Dict, Any
|
||
import requests
|
||
import argparse
|
||
from datetime import datetime
|
||
import re
|
||
|
||
try:
|
||
import numpy as np
|
||
import chromadb
|
||
from chromadb.config import Settings
|
||
from sentence_transformers import SentenceTransformer
|
||
except ImportError:
|
||
print("Устанавливаем необходимые зависимости...")
|
||
os.system("pip install chromadb numpy sentence-transformers")
|
||
import numpy as np
|
||
import chromadb
|
||
from chromadb.config import Settings
|
||
from sentence_transformers import SentenceTransformer
|
||
|
||
|
||
class LocalRAGSystem:
|
||
def __init__(self,
|
||
md_folder: str = "output_md",
|
||
db_path: str = "ready_rag",
|
||
ollama_url: str = "http://localhost:11434",
|
||
embed_model: str = "nomic-embed-text",
|
||
chat_model: str = "phi:2.7b"):
|
||
|
||
self.md_folder = Path(md_folder)
|
||
self.db_path = Path(db_path)
|
||
self.ollama_url = ollama_url
|
||
self.embed_model = embed_model
|
||
self.chat_model = chat_model
|
||
|
||
# Инициализируем модель для эмбеддингов
|
||
print(f"Загрузка модели эмбеддингов: {embed_model}...")
|
||
self.embedding_model = SentenceTransformer(embed_model)
|
||
print(f"Модель {embed_model} загружена")
|
||
|
||
# Создаем папку для базы данных
|
||
self.db_path.mkdir(exist_ok=True)
|
||
|
||
# Инициализируем ChromaDB (удаляем старую коллекцию при необходимости)
|
||
self.chroma_client = chromadb.PersistentClient(path=str(self.db_path))
|
||
|
||
# Получаем размерность текущей модели эмбеддингов
|
||
embedding_dimension = self.embedding_model.get_sentence_embedding_dimension()
|
||
|
||
# Пытаемся получить коллекцию
|
||
try:
|
||
self.collection = self.chroma_client.get_collection(
|
||
name="md_documents"
|
||
)
|
||
# Проверяем совпадение размерности
|
||
if self.collection.metadata.get("embedding_dimension") != str(embedding_dimension):
|
||
print("Размерность эмбеддингов изменилась, пересоздаем коллекцию...")
|
||
self.chroma_client.delete_collection(name="md_documents")
|
||
self.collection = self.chroma_client.create_collection(
|
||
name="md_documents",
|
||
metadata={
|
||
"description": "RAG collection for markdown documents",
|
||
"embedding_dimension": str(embedding_dimension)
|
||
}
|
||
)
|
||
except:
|
||
# Коллекция не существует, создаем новую
|
||
self.collection = self.chroma_client.create_collection(
|
||
name="md_documents",
|
||
metadata={
|
||
"description": "RAG collection for markdown documents",
|
||
"embedding_dimension": str(embedding_dimension)
|
||
}
|
||
)
|
||
|
||
print(f"RAG система инициализирована:")
|
||
print(f"- Папка с MD файлами: {self.md_folder}")
|
||
print(f"- База данных: {self.db_path}")
|
||
print(f"- Ollama URL: {self.ollama_url}")
|
||
print(f"- Модель эмбеддингов: {self.embed_model}")
|
||
print(f"- Модель чата: {self.chat_model}")
|
||
|
||
def check_ollama_connection(self) -> bool:
|
||
"""Проверяем подключение к Ollama"""
|
||
try:
|
||
response = requests.get(f"{self.ollama_url}/api/tags")
|
||
return response.status_code == 200
|
||
except:
|
||
return False
|
||
|
||
def get_ollama_models(self) -> List[str]:
|
||
"""Получаем список доступных моделей в Ollama"""
|
||
try:
|
||
response = requests.get(f"{self.ollama_url}/api/tags")
|
||
if response.status_code == 200:
|
||
models = response.json().get('models', [])
|
||
return [model['name'] for model in models]
|
||
return []
|
||
except:
|
||
return []
|
||
|
||
def find_model(self, model_name: str, available_models: List[str]) -> str:
|
||
"""Найти модель по имени, учитывая суффикс :latest"""
|
||
# Сначала ищем точное совпадение
|
||
if model_name in available_models:
|
||
return model_name
|
||
|
||
# Затем ищем с суффиксом :latest
|
||
if f"{model_name}:latest" in available_models:
|
||
return f"{model_name}:latest"
|
||
|
||
# Если модель содержит :latest, пробуем без него
|
||
if model_name.endswith(":latest"):
|
||
base_name = model_name[:-7] # убираем ":latest"
|
||
if base_name in available_models:
|
||
return base_name
|
||
|
||
return None
|
||
|
||
def chunk_text(self, text: str, chunk_size: int = 1500, overlap: int = 100) -> List[str]:
|
||
"""Разбиваем текст на чанки с перекрытием"""
|
||
chunks = []
|
||
start = 0
|
||
|
||
while start < len(text):
|
||
end = start + chunk_size
|
||
chunk = text[start:end]
|
||
|
||
# Попытаемся разбить по предложениям
|
||
if end < len(text):
|
||
last_period = chunk.rfind('.')
|
||
last_newline = chunk.rfind('\n')
|
||
split_point = max(last_period, last_newline)
|
||
|
||
if split_point > start + chunk_size // 2:
|
||
chunk = text[start:split_point + 1]
|
||
end = split_point + 1
|
||
|
||
chunks.append(chunk.strip())
|
||
start = max(start + chunk_size - overlap, end - overlap)
|
||
|
||
if start >= len(text):
|
||
break
|
||
|
||
return [chunk for chunk in chunks if len(chunk.strip()) > 50]
|
||
|
||
def extract_metadata(self, text: str, filename: str) -> Dict[str, Any]:
|
||
"""Извлекаем метаданные из markdown файла"""
|
||
metadata = {
|
||
'filename': filename,
|
||
'length': len(text),
|
||
'created_at': datetime.now().isoformat()
|
||
}
|
||
|
||
# Ищем заголовки
|
||
headers = re.findall(r'^#{1,6}\s+(.+)$', text, re.MULTILINE)
|
||
if headers:
|
||
metadata['title'] = headers[0]
|
||
# Конвертируем список заголовков в строку (ChromaDB не принимает списки)
|
||
metadata['headers_text'] = '; '.join(headers[:5])
|
||
metadata['headers_count'] = len(headers)
|
||
|
||
# Ищем специальные секции
|
||
if '# Краткое описание' in text:
|
||
desc_match = re.search(r'# Краткое описание\n(.*?)(?=\n#|\n$)', text, re.DOTALL)
|
||
if desc_match:
|
||
metadata['description'] = desc_match.group(1).strip()[:2000]
|
||
|
||
# Ищем требования
|
||
if '# Требования' in text:
|
||
metadata['has_requirements'] = True
|
||
|
||
# Ищем нормативные документы
|
||
if '# Нормативная документация' in text:
|
||
metadata['has_regulations'] = True
|
||
|
||
return metadata
|
||
|
||
def get_embedding(self, text: str) -> List[float]:
|
||
"""Генерируем эмбеддинг локально с помощью SentenceTransformer"""
|
||
try:
|
||
# Генерируем эмбеддинг
|
||
embedding = self.embedding_model.encode(text, show_progress_bar=False)
|
||
return embedding.tolist()
|
||
except Exception as e:
|
||
print(f"Ошибка при генерации эмбеддинга: {e}")
|
||
return None
|
||
|
||
def process_markdown_files(self) -> int:
|
||
"""Обрабатываем все markdown файлы и добавляем их в векторную базу"""
|
||
if not self.md_folder.exists():
|
||
print(f"Папка {self.md_folder} не найдена!")
|
||
return 0
|
||
|
||
md_files = list(self.md_folder.glob("*.md"))
|
||
if not md_files:
|
||
print(f"Markdown файлы не найдены в {self.md_folder}")
|
||
return 0
|
||
|
||
print(f"Найдено {len(md_files)} markdown файлов")
|
||
|
||
# Проверяем подключение к Ollama (только для чат-модели)
|
||
if not self.check_ollama_connection():
|
||
print(f"Не удается подключиться к Ollama по адресу {self.ollama_url}")
|
||
print("Убедитесь, что Ollama запущена и доступна для генерации ответов")
|
||
# Продолжаем работу, так как эмбеддинги локальные
|
||
print("Эмбеддинги будут генерироваться локально")
|
||
|
||
total_files = len(md_files)
|
||
processed_count = 0
|
||
total_chunks = 0
|
||
|
||
for idx, md_file in enumerate(md_files, 1):
|
||
print(f"\n[{idx}/{total_files}] Обрабатываем файл: {md_file.name}")
|
||
|
||
try:
|
||
with open(md_file, 'r', encoding='utf-8') as f:
|
||
content = f.read()
|
||
|
||
# Создаем чанки
|
||
chunks = self.chunk_text(content)
|
||
print(f" Создано чанков: {len(chunks)}")
|
||
|
||
# Извлекаем метаданные
|
||
base_metadata = self.extract_metadata(content, md_file.name)
|
||
|
||
# Обрабатываем каждый чанк
|
||
for i, chunk in enumerate(chunks):
|
||
# Создаем уникальный ID для чанка
|
||
chunk_id = hashlib.md5(f"{md_file.name}_{i}_{chunk[:100]}".encode()).hexdigest()
|
||
|
||
# Получаем эмбеддинг
|
||
embedding = self.get_embedding(chunk)
|
||
if embedding is None:
|
||
print(f" Пропускаем чанк {i} - не удалось получить эмбеддинг")
|
||
continue
|
||
|
||
# Подготавливаем метаданные для чанка
|
||
chunk_metadata = {}
|
||
# Копируем только допустимые типы метаданных
|
||
for key, value in base_metadata.items():
|
||
if isinstance(value, (str, int, float, bool, type(None))):
|
||
chunk_metadata[key] = value
|
||
elif isinstance(value, list):
|
||
# Конвертируем списки в строки
|
||
chunk_metadata[key] = '; '.join(map(str, value))
|
||
else:
|
||
# Конвертируем другие типы в строки
|
||
chunk_metadata[key] = str(value)
|
||
|
||
chunk_metadata.update({
|
||
'chunk_id': i,
|
||
'chunk_size': len(chunk),
|
||
'source_file': str(md_file)
|
||
})
|
||
|
||
# Добавляем в коллекцию
|
||
self.collection.add(
|
||
embeddings=[embedding],
|
||
documents=[chunk],
|
||
metadatas=[chunk_metadata],
|
||
ids=[chunk_id]
|
||
)
|
||
|
||
total_chunks += 1
|
||
|
||
processed_count += 1
|
||
print(f" Успешно обработан")
|
||
|
||
except Exception as e:
|
||
print(f" Ошибка при обработке: {e}")
|
||
continue
|
||
|
||
print(f"\nОбработка завершена:")
|
||
print(f"- Обработано файлов: {processed_count}")
|
||
print(f"- Создано чанков: {total_chunks}")
|
||
|
||
return processed_count
|
||
|
||
def search(self, query: str, n_results: int = 10) -> List[Dict]:
|
||
"""Поиск релевантных документов"""
|
||
if self.collection.count() == 0:
|
||
return []
|
||
|
||
# Получаем эмбеддинг для запроса
|
||
query_embedding = self.get_embedding(query)
|
||
if query_embedding is None:
|
||
print("Не удалось получить эмбеддинг для запроса")
|
||
return []
|
||
|
||
# Ищем похожие документы
|
||
results = self.collection.query(
|
||
query_embeddings=[query_embedding],
|
||
n_results=n_results
|
||
)
|
||
|
||
# Форматируем результаты
|
||
formatted_results = []
|
||
for i in range(len(results['documents'][0])):
|
||
formatted_results.append({
|
||
'document': results['documents'][0][i],
|
||
'metadata': results['metadatas'][0][i],
|
||
'distance': results['distances'][0][i] if 'distances' in results else None
|
||
})
|
||
|
||
return formatted_results
|
||
|
||
def generate_response(self, query: str, context_docs: List[Dict]) -> str:
|
||
"""Генерируем ответ используя контекст и Ollama"""
|
||
# Формируем контекст из найденных документов
|
||
context = ""
|
||
for i, doc in enumerate(context_docs, 1):
|
||
context += f"\n--- Документ {i} (файл: {doc['metadata'].get('filename', 'unknown')}) ---\n"
|
||
context += doc['document'][:2000] + ("..." if len(doc['document']) > 2000 else "")
|
||
context += "\n"
|
||
|
||
print(f"\nКонтекст: {context}")
|
||
|
||
# Формируем промпт
|
||
prompt = f"""На основе предоставленного контекста ответь на вопрос на русском языке. Если ответа нет в контексте, скажи об этом.
|
||
|
||
Контекст:
|
||
{context}
|
||
|
||
Вопрос: {query}
|
||
|
||
Ответ:"""
|
||
|
||
try:
|
||
response = requests.post(
|
||
f"{self.ollama_url}/api/generate",
|
||
json={
|
||
"model": self.chat_model,
|
||
"prompt": prompt,
|
||
"stream": False
|
||
},
|
||
timeout=600
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
return response.json()["response"]
|
||
else:
|
||
return f"Ошибка генерации ответа: {response.status_code}"
|
||
|
||
except Exception as e:
|
||
return f"Ошибка при обращении к Ollama: {e}"
|
||
|
||
def query(self, question: str, n_results: int = 5) -> Dict:
|
||
"""Полный цикл RAG: поиск + генерация ответа"""
|
||
print(f"\nВопрос: {question}")
|
||
|
||
# Проверяем доступность чат-модели
|
||
available_models = self.get_ollama_models()
|
||
chat_model_name = self.find_model(self.chat_model, available_models)
|
||
if not chat_model_name:
|
||
return {
|
||
"question": question,
|
||
"answer": f"Модель чата {self.chat_model} не найдена в Ollama",
|
||
"sources": []
|
||
}
|
||
|
||
# Обновляем имя чат-модели
|
||
self.chat_model = chat_model_name
|
||
|
||
print("Ищем релевантные документы...")
|
||
|
||
# Поиск документов
|
||
search_results = self.search(question, n_results)
|
||
|
||
if not search_results:
|
||
return {
|
||
"question": question,
|
||
"answer": "Не найдено релевантных документов для ответа на ваш вопрос.",
|
||
"sources": []
|
||
}
|
||
|
||
print(f"Найдено {len(search_results)} релевантных документов")
|
||
|
||
# Генерация ответа
|
||
print("Генерируем ответ...")
|
||
answer = self.generate_response(question, search_results)
|
||
|
||
# Формируем источники
|
||
sources = []
|
||
for doc in search_results:
|
||
sources.append({
|
||
"filename": doc['metadata'].get('filename', 'unknown'),
|
||
"title": doc['metadata'].get('title', ''),
|
||
"distance": doc.get('distance', 0)
|
||
})
|
||
|
||
return {
|
||
"question": question,
|
||
"answer": answer,
|
||
"sources": sources,
|
||
"context_docs": len(search_results)
|
||
}
|
||
|
||
def get_stats(self) -> Dict:
|
||
"""Получаем статистику RAG системы"""
|
||
return {
|
||
"total_documents": self.collection.count(),
|
||
"embedding_model": self.embed_model,
|
||
"chat_model": self.chat_model,
|
||
"database_path": str(self.db_path),
|
||
"source_folder": str(self.md_folder)
|
||
}
|
||
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(description="RAG System for Local Ollama")
|
||
parser.add_argument("--action", choices=["build", "query", "interactive", "stats"],
|
||
default="interactive", help="Действие для выполнения")
|
||
parser.add_argument("--question", type=str, help="Вопрос для поиска")
|
||
parser.add_argument("--md-folder", default="output_md", help="Папка с markdown файлами")
|
||
parser.add_argument("--embed-model", default="all-MiniLM-L6-v2", help="Модель для эмбеддингов (SentenceTransformer)")
|
||
parser.add_argument("--chat-model", default="gemma3n:e2b", help="Модель для чата")
|
||
parser.add_argument("--results", type=int, default=6, help="Количество результатов поиска")
|
||
|
||
args = parser.parse_args()
|
||
|
||
# Создаем RAG систему
|
||
rag = LocalRAGSystem(
|
||
md_folder=args.md_folder,
|
||
embed_model=args.embed_model,
|
||
chat_model=args.chat_model
|
||
)
|
||
|
||
if args.action == "build":
|
||
print("Строим RAG базу данных...")
|
||
count = rag.process_markdown_files()
|
||
print(f"Обработано {count} файлов")
|
||
|
||
elif args.action == "query":
|
||
if not args.question:
|
||
print("Укажите вопрос с помощью --question")
|
||
return
|
||
|
||
result = rag.query(args.question, args.results)
|
||
print(f"\nВопрос: {result['question']}")
|
||
print(f"\nОтвет:\n{result['answer']}")
|
||
print(f"\nИсточники:")
|
||
for source in result['sources']:
|
||
print(f"- {source['filename']}: {source['title']}")
|
||
|
||
elif args.action == "stats":
|
||
stats = rag.get_stats()
|
||
print("Статистика RAG системы:")
|
||
for key, value in stats.items():
|
||
print(f"- {key}: {value}")
|
||
|
||
elif args.action == "interactive":
|
||
print("\n=== Интерактивный режим RAG системы ===")
|
||
print("Введите 'exit' для выхода, 'stats' для статистики")
|
||
|
||
# Проверяем, есть ли данные в базе
|
||
if rag.collection.count() == 0:
|
||
print("\nБаза данных пуста. Строим RAG...")
|
||
count = rag.process_markdown_files()
|
||
if count == 0:
|
||
print("Не удалось построить базу данных. Завершение работы.")
|
||
return
|
||
|
||
stats = rag.get_stats()
|
||
print(f"\nДоступно документов: {stats['total_documents']}")
|
||
print(f"Модель эмбеддингов: {stats['embedding_model']}")
|
||
print(f"Модель чата: {stats['chat_model']}\n")
|
||
|
||
while True:
|
||
try:
|
||
question = input("\nВаш вопрос: ").strip()
|
||
|
||
if question.lower() == 'exit':
|
||
break
|
||
elif question.lower() == 'stats':
|
||
stats = rag.get_stats()
|
||
for key, value in stats.items():
|
||
print(f"- {key}: {value}")
|
||
continue
|
||
elif not question:
|
||
continue
|
||
|
||
result = rag.query(question, args.results)
|
||
print(f"\nОтвет:\n{result['answer']}")
|
||
|
||
print(f"\nИсточники ({len(result['sources'])}):")
|
||
for i, source in enumerate(result['sources'], 1):
|
||
print(f"{i}. {source['filename']}")
|
||
if source['title']:
|
||
print(f" Заголовок: {source['title']}")
|
||
|
||
except KeyboardInterrupt:
|
||
print("\nЗавершение работы...")
|
||
break
|
||
except Exception as e:
|
||
print(f"Ошибка: {e}")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|