Доработка под sentence_transformers
This commit is contained in:
152
rag/3_rag.py
152
rag/3_rag.py
@@ -1,16 +1,14 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
RAG System for Local Ollama
|
RAG System with Local Embeddings
|
||||||
Создает и использует RAG на основе markdown файлов для работы с локальной Ollama
|
Создает и использует RAG на основе markdown файлов с локальными эмбеддингами
|
||||||
Скрипт сгенерирован claude-sonnet-4
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
import pickle
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Tuple, Any
|
from typing import List, Dict, Any
|
||||||
import requests
|
import requests
|
||||||
import argparse
|
import argparse
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -20,12 +18,14 @@ try:
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import chromadb
|
import chromadb
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Устанавливаем необходимые зависимости...")
|
print("Устанавливаем необходимые зависимости...")
|
||||||
os.system("pip install chromadb numpy requests")
|
os.system("pip install chromadb numpy sentence-transformers")
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import chromadb
|
import chromadb
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
|
||||||
class LocalRAGSystem:
|
class LocalRAGSystem:
|
||||||
@@ -34,7 +34,7 @@ class LocalRAGSystem:
|
|||||||
db_path: str = "ready_rag",
|
db_path: str = "ready_rag",
|
||||||
ollama_url: str = "http://localhost:11434",
|
ollama_url: str = "http://localhost:11434",
|
||||||
embed_model: str = "nomic-embed-text",
|
embed_model: str = "nomic-embed-text",
|
||||||
chat_model: str = "phi4-mini:3.8b"):
|
chat_model: str = "phi:2.7b"):
|
||||||
|
|
||||||
self.md_folder = Path(md_folder)
|
self.md_folder = Path(md_folder)
|
||||||
self.db_path = Path(db_path)
|
self.db_path = Path(db_path)
|
||||||
@@ -42,15 +42,45 @@ class LocalRAGSystem:
|
|||||||
self.embed_model = embed_model
|
self.embed_model = embed_model
|
||||||
self.chat_model = chat_model
|
self.chat_model = chat_model
|
||||||
|
|
||||||
|
# Инициализируем модель для эмбеддингов
|
||||||
|
print(f"Загрузка модели эмбеддингов: {embed_model}...")
|
||||||
|
self.embedding_model = SentenceTransformer(embed_model)
|
||||||
|
print(f"Модель {embed_model} загружена")
|
||||||
|
|
||||||
# Создаем папку для базы данных
|
# Создаем папку для базы данных
|
||||||
self.db_path.mkdir(exist_ok=True)
|
self.db_path.mkdir(exist_ok=True)
|
||||||
|
|
||||||
# Инициализируем ChromaDB
|
# Инициализируем ChromaDB (удаляем старую коллекцию при необходимости)
|
||||||
self.chroma_client = chromadb.PersistentClient(path=str(self.db_path))
|
self.chroma_client = chromadb.PersistentClient(path=str(self.db_path))
|
||||||
self.collection = self.chroma_client.get_or_create_collection(
|
|
||||||
name="md_documents",
|
# Получаем размерность текущей модели эмбеддингов
|
||||||
metadata={"description": "RAG collection for markdown documents"}
|
embedding_dimension = self.embedding_model.get_sentence_embedding_dimension()
|
||||||
)
|
|
||||||
|
# Пытаемся получить коллекцию
|
||||||
|
try:
|
||||||
|
self.collection = self.chroma_client.get_collection(
|
||||||
|
name="md_documents"
|
||||||
|
)
|
||||||
|
# Проверяем совпадение размерности
|
||||||
|
if self.collection.metadata.get("embedding_dimension") != str(embedding_dimension):
|
||||||
|
print("Размерность эмбеддингов изменилась, пересоздаем коллекцию...")
|
||||||
|
self.chroma_client.delete_collection(name="md_documents")
|
||||||
|
self.collection = self.chroma_client.create_collection(
|
||||||
|
name="md_documents",
|
||||||
|
metadata={
|
||||||
|
"description": "RAG collection for markdown documents",
|
||||||
|
"embedding_dimension": str(embedding_dimension)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
# Коллекция не существует, создаем новую
|
||||||
|
self.collection = self.chroma_client.create_collection(
|
||||||
|
name="md_documents",
|
||||||
|
metadata={
|
||||||
|
"description": "RAG collection for markdown documents",
|
||||||
|
"embedding_dimension": str(embedding_dimension)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
print(f"RAG система инициализирована:")
|
print(f"RAG система инициализирована:")
|
||||||
print(f"- Папка с MD файлами: {self.md_folder}")
|
print(f"- Папка с MD файлами: {self.md_folder}")
|
||||||
@@ -96,7 +126,7 @@ class LocalRAGSystem:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 100) -> List[str]:
|
def chunk_text(self, text: str, chunk_size: int = 1500, overlap: int = 100) -> List[str]:
|
||||||
"""Разбиваем текст на чанки с перекрытием"""
|
"""Разбиваем текст на чанки с перекрытием"""
|
||||||
chunks = []
|
chunks = []
|
||||||
start = 0
|
start = 0
|
||||||
@@ -143,7 +173,7 @@ class LocalRAGSystem:
|
|||||||
if '# Краткое описание' in text:
|
if '# Краткое описание' in text:
|
||||||
desc_match = re.search(r'# Краткое описание\n(.*?)(?=\n#|\n$)', text, re.DOTALL)
|
desc_match = re.search(r'# Краткое описание\n(.*?)(?=\n#|\n$)', text, re.DOTALL)
|
||||||
if desc_match:
|
if desc_match:
|
||||||
metadata['description'] = desc_match.group(1).strip()[:500]
|
metadata['description'] = desc_match.group(1).strip()[:2000]
|
||||||
|
|
||||||
# Ищем требования
|
# Ищем требования
|
||||||
if '# Требования' in text:
|
if '# Требования' in text:
|
||||||
@@ -156,24 +186,13 @@ class LocalRAGSystem:
|
|||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def get_embedding(self, text: str) -> List[float]:
|
def get_embedding(self, text: str) -> List[float]:
|
||||||
"""Получаем эмбеддинг через Ollama"""
|
"""Генерируем эмбеддинг локально с помощью SentenceTransformer"""
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
# Генерируем эмбеддинг
|
||||||
f"{self.ollama_url}/api/embeddings",
|
embedding = self.embedding_model.encode(text, show_progress_bar=False)
|
||||||
json={
|
return embedding.tolist()
|
||||||
"model": self.embed_model,
|
|
||||||
"prompt": text
|
|
||||||
},
|
|
||||||
timeout=600
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
return response.json()["embedding"]
|
|
||||||
else:
|
|
||||||
print(f"Ошибка получения эмбеддинга: {response.status_code}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Ошибка при получении эмбеддинга: {e}")
|
print(f"Ошибка при генерации эмбеддинга: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def process_markdown_files(self) -> int:
|
def process_markdown_files(self) -> int:
|
||||||
@@ -189,49 +208,19 @@ class LocalRAGSystem:
|
|||||||
|
|
||||||
print(f"Найдено {len(md_files)} markdown файлов")
|
print(f"Найдено {len(md_files)} markdown файлов")
|
||||||
|
|
||||||
# Проверяем подключение к Ollama
|
# Проверяем подключение к Ollama (только для чат-модели)
|
||||||
if not self.check_ollama_connection():
|
if not self.check_ollama_connection():
|
||||||
print(f"Не удается подключиться к Ollama по адресу {self.ollama_url}")
|
print(f"Не удается подключиться к Ollama по адресу {self.ollama_url}")
|
||||||
print("Убедитесь, что Ollama запущена и доступна")
|
print("Убедитесь, что Ollama запущена и доступна для генерации ответов")
|
||||||
return 0
|
# Продолжаем работу, так как эмбеддинги локальные
|
||||||
|
print("Эмбеддинги будут генерироваться локально")
|
||||||
# Проверяем наличие модели эмбеддингов
|
|
||||||
available_models = self.get_ollama_models()
|
|
||||||
embed_model_name = self.find_model(self.embed_model, available_models)
|
|
||||||
|
|
||||||
if not embed_model_name:
|
|
||||||
print(f"Модель эмбеддингов {self.embed_model} не найдена в Ollama")
|
|
||||||
print(f"Доступные модели: {available_models}")
|
|
||||||
print(f"Загружаем модель {self.embed_model}...")
|
|
||||||
|
|
||||||
# Пытаемся загрузить модель
|
|
||||||
try:
|
|
||||||
response = requests.post(
|
|
||||||
f"{self.ollama_url}/api/pull",
|
|
||||||
json={"name": self.embed_model},
|
|
||||||
timeout=300
|
|
||||||
)
|
|
||||||
if response.status_code != 200:
|
|
||||||
print(f"Не удается загрузить модель {self.embed_model}")
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
# После загрузки обновляем список моделей
|
|
||||||
available_models = self.get_ollama_models()
|
|
||||||
embed_model_name = self.find_model(self.embed_model, available_models)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Ошибка при загрузке модели: {e}")
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# Обновляем имя модели эмбеддингов
|
|
||||||
if embed_model_name:
|
|
||||||
self.embed_model = embed_model_name
|
|
||||||
print(f"Используем модель эмбеддингов: {self.embed_model}")
|
|
||||||
|
|
||||||
|
total_files = len(md_files)
|
||||||
processed_count = 0
|
processed_count = 0
|
||||||
total_chunks = 0
|
total_chunks = 0
|
||||||
|
|
||||||
for md_file in md_files:
|
for idx, md_file in enumerate(md_files, 1):
|
||||||
print(f"\nОбрабатываем: {md_file.name}")
|
print(f"\n[{idx}/{total_files}] Обрабатываем файл: {md_file.name}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(md_file, 'r', encoding='utf-8') as f:
|
with open(md_file, 'r', encoding='utf-8') as f:
|
||||||
@@ -285,10 +274,10 @@ class LocalRAGSystem:
|
|||||||
total_chunks += 1
|
total_chunks += 1
|
||||||
|
|
||||||
processed_count += 1
|
processed_count += 1
|
||||||
print(f" Успешно обработан файл {md_file.name}")
|
print(f" Успешно обработан")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" Ошибка при обработке {md_file.name}: {e}")
|
print(f" Ошибка при обработке: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"\nОбработка завершена:")
|
print(f"\nОбработка завершена:")
|
||||||
@@ -297,7 +286,7 @@ class LocalRAGSystem:
|
|||||||
|
|
||||||
return processed_count
|
return processed_count
|
||||||
|
|
||||||
def search(self, query: str, n_results: int = 5) -> List[Dict]:
|
def search(self, query: str, n_results: int = 10) -> List[Dict]:
|
||||||
"""Поиск релевантных документов"""
|
"""Поиск релевантных документов"""
|
||||||
if self.collection.count() == 0:
|
if self.collection.count() == 0:
|
||||||
return []
|
return []
|
||||||
@@ -331,9 +320,11 @@ class LocalRAGSystem:
|
|||||||
context = ""
|
context = ""
|
||||||
for i, doc in enumerate(context_docs, 1):
|
for i, doc in enumerate(context_docs, 1):
|
||||||
context += f"\n--- Документ {i} (файл: {doc['metadata'].get('filename', 'unknown')}) ---\n"
|
context += f"\n--- Документ {i} (файл: {doc['metadata'].get('filename', 'unknown')}) ---\n"
|
||||||
context += doc['document'][:1000] + ("..." if len(doc['document']) > 1000 else "")
|
context += doc['document'][:2000] + ("..." if len(doc['document']) > 2000 else "")
|
||||||
context += "\n"
|
context += "\n"
|
||||||
|
|
||||||
|
print(f"\nКонтекст: {context}")
|
||||||
|
|
||||||
# Формируем промпт
|
# Формируем промпт
|
||||||
prompt = f"""На основе предоставленного контекста ответь на вопрос на русском языке. Если ответа нет в контексте, скажи об этом.
|
prompt = f"""На основе предоставленного контекста ответь на вопрос на русском языке. Если ответа нет в контексте, скажи об этом.
|
||||||
|
|
||||||
@@ -367,18 +358,8 @@ class LocalRAGSystem:
|
|||||||
"""Полный цикл RAG: поиск + генерация ответа"""
|
"""Полный цикл RAG: поиск + генерация ответа"""
|
||||||
print(f"\nВопрос: {question}")
|
print(f"\nВопрос: {question}")
|
||||||
|
|
||||||
# Проверяем доступность моделей
|
# Проверяем доступность чат-модели
|
||||||
available_models = self.get_ollama_models()
|
available_models = self.get_ollama_models()
|
||||||
|
|
||||||
# Находим правильные имена моделей
|
|
||||||
embed_model_name = self.find_model(self.embed_model, available_models)
|
|
||||||
if not embed_model_name:
|
|
||||||
return {
|
|
||||||
"question": question,
|
|
||||||
"answer": f"Модель эмбеддингов {self.embed_model} не найдена в Ollama",
|
|
||||||
"sources": []
|
|
||||||
}
|
|
||||||
|
|
||||||
chat_model_name = self.find_model(self.chat_model, available_models)
|
chat_model_name = self.find_model(self.chat_model, available_models)
|
||||||
if not chat_model_name:
|
if not chat_model_name:
|
||||||
return {
|
return {
|
||||||
@@ -387,8 +368,7 @@ class LocalRAGSystem:
|
|||||||
"sources": []
|
"sources": []
|
||||||
}
|
}
|
||||||
|
|
||||||
# Обновляем имена моделей
|
# Обновляем имя чат-модели
|
||||||
self.embed_model = embed_model_name
|
|
||||||
self.chat_model = chat_model_name
|
self.chat_model = chat_model_name
|
||||||
|
|
||||||
print("Ищем релевантные документы...")
|
print("Ищем релевантные документы...")
|
||||||
@@ -442,8 +422,8 @@ def main():
|
|||||||
default="interactive", help="Действие для выполнения")
|
default="interactive", help="Действие для выполнения")
|
||||||
parser.add_argument("--question", type=str, help="Вопрос для поиска")
|
parser.add_argument("--question", type=str, help="Вопрос для поиска")
|
||||||
parser.add_argument("--md-folder", default="output_md", help="Папка с markdown файлами")
|
parser.add_argument("--md-folder", default="output_md", help="Папка с markdown файлами")
|
||||||
parser.add_argument("--embed-model", default="nomic-embed-text", help="Модель для эмбеддингов")
|
parser.add_argument("--embed-model", default="all-MiniLM-L6-v2", help="Модель для эмбеддингов (SentenceTransformer)")
|
||||||
parser.add_argument("--chat-model", default="phi4-mini:3.8b", help="Модель для чата")
|
parser.add_argument("--chat-model", default="gemma3n:e2b", help="Модель для чата")
|
||||||
parser.add_argument("--results", type=int, default=6, help="Количество результатов поиска")
|
parser.add_argument("--results", type=int, default=6, help="Количество результатов поиска")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
Reference in New Issue
Block a user