- wire Qdrant into the pipeline - use a proper embedding model - text vectorization - README and a bunch of small fixes
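"""Query a local RAG stack: Qdrant for retrieval, Ollama for generation.

Usage sketch (the script's file name is not given in the source; rag.py is
assumed here):

    python rag.py --query "your question"
    python rag.py --interactive --show-prompt
"""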
import argparse
import os
import sys

import requests
from sentence_transformers import SentenceTransformer


class LocalRAGSystem:
    def __init__(self,
                 md_folder: str = "input_md",
                 ollama_url: str = "http://localhost:11434",
                 qdrant_host: str = "localhost",
                 qdrant_port: int = 6333,
                 embed_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
                 chat_model: str = "qwen2.5:3b"):
        self.md_folder = md_folder
        self.ollama_url = ollama_url
        self.qdrant_host = qdrant_host
        self.qdrant_port = qdrant_port
        self.embed_model = embed_model
        self.chat_model = chat_model
        # Load the SentenceTransformer once at construction time.
        self.emb_model = SentenceTransformer(embed_model)

    def get_embedding(self, text: str):
        """Embed a single text; return a plain list so it is JSON-serializable."""
        return self.emb_model.encode(text, show_progress_bar=False).tolist()

    def search_qdrant(self, query: str, top_k: int = 6):
        """Vector search against the hard-coded rag_collection in Qdrant."""
        query_vec = self.get_embedding(query)
        url = f"http://{self.qdrant_host}:{self.qdrant_port}/collections/rag_collection/points/search"
        payload = {
            "vector": query_vec,
            "limit": top_k,  # current Qdrant API uses "limit"; "top" is the legacy name
            "with_payload": True
        }
        resp = requests.post(url, json=payload, timeout=60)
        if resp.status_code != 200:
            raise RuntimeError(f"Qdrant error: {resp.status_code} {resp.text}")
        return resp.json().get("result", [])
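
    # Each hit in the returned list is expected to look roughly like this
    # (shape inferred from the payload fields consumed in build_prompt below;
    # exact keys may vary across Qdrant versions):
    #   {"id": 1, "score": 0.83,
    #    "payload": {"text": "...", "filename": "doc.md", "url": "https://..."}}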

    def build_prompt(self, query: str, context_docs: list):
        """Assemble the LLM prompt and a human-readable source list."""
        query = query.strip()
        context = ""
        sources = "\nSources:\n"
        for idx, doc in enumerate(context_docs, start=1):
            payload = doc.get("payload", {})
            text = payload.get("text", "").strip()
            filename = payload.get("filename")
            url = payload.get("url") or ""
            if filename:
                title = filename
            else:
                # No filename in the payload: fall back to a short text snippet.
                snippet = text[:40].replace("\n", " ").strip()
                if len(text) > 40:
                    snippet += "..."
                title = snippet if snippet else "Empty text"
            context += f"\n--- Source [{idx}] ---\n{text}\n"
            sources += f"\n{idx}. {title}\n   {url}"

        if os.path.exists("sys_prompt.txt"):
            # An external template overrides the built-in prompt.
            with open("sys_prompt.txt", "r") as fp:
                prompt = fp.read().replace("{{context}}", context).replace("{{query}}", query)
        else:
            prompt = f"""
Please provide an answer based solely on the provided sources.
It is prohibited to generate an answer based on your pretrained data.
If uncertain, ask the user for clarification.
Respond in the same language as the user's query.
If there are no sources in context, clearly state that.
If the context is unreadable or of poor quality, inform the user and provide the best possible answer.
When referencing information from a source, cite the appropriate source(s) using their corresponding numbers.
Every answer should include at least one source citation.
Only cite a source when you are explicitly referencing it.

If none of the sources are helpful, you should indicate that.
For example:

--- Source 1 ---
The sky is red in the evening and blue in the morning.

--- Source 2 ---
Water is wet when the sky is red.

Query: When is water wet?
Answer: Water will be wet when the sky is red [2], which occurs in the evening [1].

Now it's your turn. Below are several numbered sources of information:
{context}

User query: {query}
Your answer:
"""
        return prompt, sources

    def generate_answer(self, query: str, context_docs: list):
        """Build the prompt, query Ollama, and append the source list to the answer."""
        prompt, sources = self.build_prompt(query, context_docs)
        url = f"{self.ollama_url}/api/generate"
        body = {
            "model": self.chat_model,
            "prompt": prompt,
            "stream": False
        }
        resp = requests.post(url, json=body, timeout=600)
        if resp.status_code != 200:
            return f"Answer generation error: {resp.status_code} {resp.text}"
        return resp.json().get("response", "").strip() + f"\n{sources}"
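
    # Note: if sys_prompt.txt exists next to the script, it fully replaces the
    # built-in prompt above, with {{context}} and {{query}} substituted verbatim.
    # A minimal template (a sketch, not shipped with the source) could be:
    #
    #   Answer strictly and only from the numbered sources below.
    #   {{context}}
    #   User query: {{query}}
    #   Your answer: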


def main():
    parser = argparse.ArgumentParser(description="RAG system built on Ollama and Qdrant")
    parser.add_argument("--query", type=str, help="Query to run against the RAG")
    parser.add_argument("--interactive", default=False, action=argparse.BooleanOptionalAction, help="Switch to interactive dialogue mode")
    parser.add_argument("--show-prompt", default=False, action=argparse.BooleanOptionalAction, help="Print the full prompt before the query is processed")
    parser.add_argument("--qdrant-host", default="localhost", help="Qdrant host")
    parser.add_argument("--qdrant-port", type=int, default=6333, help="Qdrant port")
    parser.add_argument("--ollama-url", default="http://localhost:11434", help="Ollama API URL")
    parser.add_argument("--emb-model", default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", help="Embedding model")
    parser.add_argument("--chat-model", default="qwen2.5:3b", help="Ollama generation model")
    parser.add_argument("--topk", type=int, default=6, help="Number of documents to retrieve")
    args = parser.parse_args()

    if not args.query and not args.interactive:
        print("Error: provide a query (--query) and/or use interactive mode (--interactive)")
        sys.exit(1)

    print(f"Ollama address: {args.ollama_url}")
    print(f"Qdrant address: {args.qdrant_host}:{args.qdrant_port}")
    print(f"Embedding model: {args.emb_model}")
    print(f"Chat model: {args.chat_model}")
    print(f"Documents to retrieve: {args.topk}")
    if os.path.exists("sys_prompt.txt"):
        print("The system prompt will be loaded from sys_prompt.txt!")

    if args.interactive:
        print("\nINTERACTIVE MODE")
        print("Enter a query (or 'exit' to quit)\n")
        question = input(">>> ").strip()
    else:
        question = args.query.strip()

    print("\nInitializing models...")
    rag = LocalRAGSystem(
        ollama_url=args.ollama_url,
        qdrant_host=args.qdrant_host,
        qdrant_port=args.qdrant_port,
        embed_model=args.emb_model,
        chat_model=args.chat_model
    )

    print("Models loaded. If an answer is poor, rephrase the query, pass a different --chat-model, or improve the RAG source data")

    while True:
        try:
            if not question:
                question = input(">>> ").strip()

            if not question:
                continue

            if question.lower() == "exit":
                print("\n*** Shutting down")
                break

            print("\nSearching for relevant documents...")
            results = rag.search_qdrant(question, top_k=args.topk)
            if not results:
                print("No relevant documents found.")
                if args.interactive:
                    # Reset the question, otherwise the same query loops forever.
                    question = None
                    continue
                else:
                    break

            print(f"Found {len(results)} relevant documents")

            if args.show_prompt:
                # build_prompt returns (prompt, sources); show only the prompt here.
                prompt, _ = rag.build_prompt(question, results)
                print("\nFull prompt:\n")
                print(prompt)

            print("Generating answer...")
            answer = rag.generate_answer(question, results)
            print(f"\n<<< {answer}\n---------------------------------------------------\n")
            question = None
        except KeyboardInterrupt:
            print("\n*** Shutting down")
            break
        except Exception as e:
            print(f"Error: {e}")


if __name__ == "__main__":
    main()
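
# ---------------------------------------------------------------------------
# Ingestion sketch (not part of this script). The query side above assumes a
# pre-populated "rag_collection"; the changelog mentions text vectorization
# happening elsewhere. A minimal upsert could look like this, assuming 384-dim
# cosine vectors (the output size of paraphrase-multilingual-MiniLM-L12-v2)
# and deterministic point IDs derived from the chunk text:
#
#   import hashlib
#   import uuid
#
#   import requests
#   from sentence_transformers import SentenceTransformer
#
#   model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
#   base = "http://localhost:6333/collections/rag_collection"
#   # Create the collection once before the first upsert.
#   requests.put(base, json={"vectors": {"size": 384, "distance": "Cosine"}})
#   text = "Some chunk of a markdown document."
#   point = {
#       "id": str(uuid.UUID(hashlib.md5(text.encode()).hexdigest())),
#       "vector": model.encode(text).tolist(),
#       "payload": {"text": text, "filename": "doc.md", "url": ""},
#   }
#   requests.put(f"{base}/points", json={"points": [point]})
# ---------------------------------------------------------------------------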
|