1
0

Доработка под sentence_transformers

This commit is contained in:
2025-08-23 11:55:24 +08:00
parent a01f903714
commit c6e498a0c8

View File

@@ -1,16 +1,14 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
RAG System for Local Ollama RAG System with Local Embeddings
Создает и использует RAG на основе markdown файлов для работы с локальной Ollama Создает и использует RAG на основе markdown файлов с локальными эмбеддингами
Скрипт сгенерирован claude-sonnet-4
""" """
import os import os
import json import json
import hashlib import hashlib
import pickle
from pathlib import Path from pathlib import Path
from typing import List, Dict, Tuple, Any from typing import List, Dict, Any
import requests import requests
import argparse import argparse
from datetime import datetime from datetime import datetime
@@ -20,12 +18,14 @@ try:
import numpy as np import numpy as np
import chromadb import chromadb
from chromadb.config import Settings from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
except ImportError: except ImportError:
print("Устанавливаем необходимые зависимости...") print("Устанавливаем необходимые зависимости...")
os.system("pip install chromadb numpy requests") os.system("pip install chromadb numpy sentence-transformers")
import numpy as np import numpy as np
import chromadb import chromadb
from chromadb.config import Settings from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
class LocalRAGSystem: class LocalRAGSystem:
@@ -34,7 +34,7 @@ class LocalRAGSystem:
db_path: str = "ready_rag", db_path: str = "ready_rag",
ollama_url: str = "http://localhost:11434", ollama_url: str = "http://localhost:11434",
embed_model: str = "nomic-embed-text", embed_model: str = "nomic-embed-text",
chat_model: str = "phi4-mini:3.8b"): chat_model: str = "phi:2.7b"):
self.md_folder = Path(md_folder) self.md_folder = Path(md_folder)
self.db_path = Path(db_path) self.db_path = Path(db_path)
@@ -42,15 +42,45 @@ class LocalRAGSystem:
self.embed_model = embed_model self.embed_model = embed_model
self.chat_model = chat_model self.chat_model = chat_model
# Инициализируем модель для эмбеддингов
print(f"Загрузка модели эмбеддингов: {embed_model}...")
self.embedding_model = SentenceTransformer(embed_model)
print(f"Модель {embed_model} загружена")
# Создаем папку для базы данных # Создаем папку для базы данных
self.db_path.mkdir(exist_ok=True) self.db_path.mkdir(exist_ok=True)
# Инициализируем ChromaDB # Инициализируем ChromaDB (удаляем старую коллекцию при необходимости)
self.chroma_client = chromadb.PersistentClient(path=str(self.db_path)) self.chroma_client = chromadb.PersistentClient(path=str(self.db_path))
self.collection = self.chroma_client.get_or_create_collection(
name="md_documents", # Получаем размерность текущей модели эмбеддингов
metadata={"description": "RAG collection for markdown documents"} embedding_dimension = self.embedding_model.get_sentence_embedding_dimension()
)
# Пытаемся получить коллекцию
try:
self.collection = self.chroma_client.get_collection(
name="md_documents"
)
# Проверяем совпадение размерности
if self.collection.metadata.get("embedding_dimension") != str(embedding_dimension):
print("Размерность эмбеддингов изменилась, пересоздаем коллекцию...")
self.chroma_client.delete_collection(name="md_documents")
self.collection = self.chroma_client.create_collection(
name="md_documents",
metadata={
"description": "RAG collection for markdown documents",
"embedding_dimension": str(embedding_dimension)
}
)
except:
# Коллекция не существует, создаем новую
self.collection = self.chroma_client.create_collection(
name="md_documents",
metadata={
"description": "RAG collection for markdown documents",
"embedding_dimension": str(embedding_dimension)
}
)
print(f"RAG система инициализирована:") print(f"RAG система инициализирована:")
print(f"- Папка с MD файлами: {self.md_folder}") print(f"- Папка с MD файлами: {self.md_folder}")
@@ -96,7 +126,7 @@ class LocalRAGSystem:
return None return None
def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 100) -> List[str]: def chunk_text(self, text: str, chunk_size: int = 1500, overlap: int = 100) -> List[str]:
"""Разбиваем текст на чанки с перекрытием""" """Разбиваем текст на чанки с перекрытием"""
chunks = [] chunks = []
start = 0 start = 0
@@ -143,7 +173,7 @@ class LocalRAGSystem:
if '# Краткое описание' in text: if '# Краткое описание' in text:
desc_match = re.search(r'# Краткое описание\n(.*?)(?=\n#|\n$)', text, re.DOTALL) desc_match = re.search(r'# Краткое описание\n(.*?)(?=\n#|\n$)', text, re.DOTALL)
if desc_match: if desc_match:
metadata['description'] = desc_match.group(1).strip()[:500] metadata['description'] = desc_match.group(1).strip()[:2000]
# Ищем требования # Ищем требования
if '# Требования' in text: if '# Требования' in text:
@@ -156,24 +186,13 @@ class LocalRAGSystem:
return metadata return metadata
def get_embedding(self, text: str) -> List[float]: def get_embedding(self, text: str) -> List[float]:
"""Получаем эмбеддинг через Ollama""" """Генерируем эмбеддинг локально с помощью SentenceTransformer"""
try: try:
response = requests.post( # Генерируем эмбеддинг
f"{self.ollama_url}/api/embeddings", embedding = self.embedding_model.encode(text, show_progress_bar=False)
json={ return embedding.tolist()
"model": self.embed_model,
"prompt": text
},
timeout=600
)
if response.status_code == 200:
return response.json()["embedding"]
else:
print(f"Ошибка получения эмбеддинга: {response.status_code}")
return None
except Exception as e: except Exception as e:
print(f"Ошибка при получении эмбеддинга: {e}") print(f"Ошибка при генерации эмбеддинга: {e}")
return None return None
def process_markdown_files(self) -> int: def process_markdown_files(self) -> int:
@@ -189,49 +208,19 @@ class LocalRAGSystem:
print(f"Найдено {len(md_files)} markdown файлов") print(f"Найдено {len(md_files)} markdown файлов")
# Проверяем подключение к Ollama # Проверяем подключение к Ollama (только для чат-модели)
if not self.check_ollama_connection(): if not self.check_ollama_connection():
print(f"Не удается подключиться к Ollama по адресу {self.ollama_url}") print(f"Не удается подключиться к Ollama по адресу {self.ollama_url}")
print("Убедитесь, что Ollama запущена и доступна") print("Убедитесь, что Ollama запущена и доступна для генерации ответов")
return 0 # Продолжаем работу, так как эмбеддинги локальные
print("Эмбеддинги будут генерироваться локально")
# Проверяем наличие модели эмбеддингов
available_models = self.get_ollama_models()
embed_model_name = self.find_model(self.embed_model, available_models)
if not embed_model_name:
print(f"Модель эмбеддингов {self.embed_model} не найдена в Ollama")
print(f"Доступные модели: {available_models}")
print(f"Загружаем модель {self.embed_model}...")
# Пытаемся загрузить модель
try:
response = requests.post(
f"{self.ollama_url}/api/pull",
json={"name": self.embed_model},
timeout=300
)
if response.status_code != 200:
print(f"Не удается загрузить модель {self.embed_model}")
return 0
else:
# После загрузки обновляем список моделей
available_models = self.get_ollama_models()
embed_model_name = self.find_model(self.embed_model, available_models)
except Exception as e:
print(f"Ошибка при загрузке модели: {e}")
return 0
# Обновляем имя модели эмбеддингов
if embed_model_name:
self.embed_model = embed_model_name
print(f"Используем модель эмбеддингов: {self.embed_model}")
total_files = len(md_files)
processed_count = 0 processed_count = 0
total_chunks = 0 total_chunks = 0
for md_file in md_files: for idx, md_file in enumerate(md_files, 1):
print(f"\nОбрабатываем: {md_file.name}") print(f"\n[{idx}/{total_files}] Обрабатываем файл: {md_file.name}")
try: try:
with open(md_file, 'r', encoding='utf-8') as f: with open(md_file, 'r', encoding='utf-8') as f:
@@ -285,10 +274,10 @@ class LocalRAGSystem:
total_chunks += 1 total_chunks += 1
processed_count += 1 processed_count += 1
print(f" Успешно обработан файл {md_file.name}") print(f" Успешно обработан")
except Exception as e: except Exception as e:
print(f" Ошибка при обработке {md_file.name}: {e}") print(f" Ошибка при обработке: {e}")
continue continue
print(f"\nОбработка завершена:") print(f"\nОбработка завершена:")
@@ -297,7 +286,7 @@ class LocalRAGSystem:
return processed_count return processed_count
def search(self, query: str, n_results: int = 5) -> List[Dict]: def search(self, query: str, n_results: int = 10) -> List[Dict]:
"""Поиск релевантных документов""" """Поиск релевантных документов"""
if self.collection.count() == 0: if self.collection.count() == 0:
return [] return []
@@ -331,9 +320,11 @@ class LocalRAGSystem:
context = "" context = ""
for i, doc in enumerate(context_docs, 1): for i, doc in enumerate(context_docs, 1):
context += f"\n--- Документ {i} (файл: {doc['metadata'].get('filename', 'unknown')}) ---\n" context += f"\n--- Документ {i} (файл: {doc['metadata'].get('filename', 'unknown')}) ---\n"
context += doc['document'][:1000] + ("..." if len(doc['document']) > 1000 else "") context += doc['document'][:2000] + ("..." if len(doc['document']) > 2000 else "")
context += "\n" context += "\n"
print(f"\nКонтекст: {context}")
# Формируем промпт # Формируем промпт
prompt = f"""На основе предоставленного контекста ответь на вопрос на русском языке. Если ответа нет в контексте, скажи об этом. prompt = f"""На основе предоставленного контекста ответь на вопрос на русском языке. Если ответа нет в контексте, скажи об этом.
@@ -367,18 +358,8 @@ class LocalRAGSystem:
"""Полный цикл RAG: поиск + генерация ответа""" """Полный цикл RAG: поиск + генерация ответа"""
print(f"\nВопрос: {question}") print(f"\nВопрос: {question}")
# Проверяем доступность моделей # Проверяем доступность чат-модели
available_models = self.get_ollama_models() available_models = self.get_ollama_models()
# Находим правильные имена моделей
embed_model_name = self.find_model(self.embed_model, available_models)
if not embed_model_name:
return {
"question": question,
"answer": f"Модель эмбеддингов {self.embed_model} не найдена в Ollama",
"sources": []
}
chat_model_name = self.find_model(self.chat_model, available_models) chat_model_name = self.find_model(self.chat_model, available_models)
if not chat_model_name: if not chat_model_name:
return { return {
@@ -387,8 +368,7 @@ class LocalRAGSystem:
"sources": [] "sources": []
} }
# Обновляем имена моделей # Обновляем имя чат-модели
self.embed_model = embed_model_name
self.chat_model = chat_model_name self.chat_model = chat_model_name
print("Ищем релевантные документы...") print("Ищем релевантные документы...")
@@ -442,8 +422,8 @@ def main():
default="interactive", help="Действие для выполнения") default="interactive", help="Действие для выполнения")
parser.add_argument("--question", type=str, help="Вопрос для поиска") parser.add_argument("--question", type=str, help="Вопрос для поиска")
parser.add_argument("--md-folder", default="output_md", help="Папка с markdown файлами") parser.add_argument("--md-folder", default="output_md", help="Папка с markdown файлами")
parser.add_argument("--embed-model", default="nomic-embed-text", help="Модель для эмбеддингов") parser.add_argument("--embed-model", default="all-MiniLM-L6-v2", help="Модель для эмбеддингов (SentenceTransformer)")
parser.add_argument("--chat-model", default="phi4-mini:3.8b", help="Модель для чата") parser.add_argument("--chat-model", default="gemma3n:e2b", help="Модель для чата")
parser.add_argument("--results", type=int, default=6, help="Количество результатов поиска") parser.add_argument("--results", type=int, default=6, help="Количество результатов поиска")
args = parser.parse_args() args = parser.parse_args()