Almost a complete rework of the entire RAG:
- bring Qdrant into the loop
- use a proper embedding model
- text vectorization
- README and a ton of small fixes
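Only the query path appears in this diff; the indexing step that actually fills Qdrant is not shown. For orientation, here is a minimal sketch of what that vectorization pass could look like for the collection rag.py queries below. It assumes the same embedding model, one point per markdown file (a real indexer would chunk the text), and Qdrant's standard REST endpoints; everything in it is illustrative, not the committed indexer:

import os
import requests
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
qdrant = "http://localhost:6333"

# Recreate the collection; this MiniLM model outputs 384-dimensional vectors.
requests.delete(f"{qdrant}/collections/rag_collection", timeout=30)
requests.put(f"{qdrant}/collections/rag_collection", timeout=30,
             json={"vectors": {"size": 384, "distance": "Cosine"}})

points = []
for idx, name in enumerate(sorted(os.listdir("input_md"))):
    with open(os.path.join("input_md", name), "r", encoding="utf-8") as fp:
        text = fp.read()
    points.append({
        "id": idx,
        "vector": model.encode(text, show_progress_bar=False).tolist(),
        # "text" and "filename" are the payload keys generate_answer() reads back.
        "payload": {"text": text, "filename": name},
    })

requests.put(f"{qdrant}/collections/rag_collection/points?wait=true", timeout=120,
             json={"points": points})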
rag/rag.py (Normal file, 190 lines)
@@ -0,0 +1,190 @@
import argparse
import os
import sys

import requests
from sentence_transformers import SentenceTransformer


class LocalRAGSystem:
    def __init__(self,
                 md_folder: str = "input_md",
                 ollama_url: str = "http://localhost:11434",
                 qdrant_host: str = "localhost",
                 qdrant_port: int = 6333,
                 embed_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
                 chat_model: str = "qwen2.5:3b"):
        self.md_folder = md_folder
        self.ollama_url = ollama_url
        self.qdrant_host = qdrant_host
        self.qdrant_port = qdrant_port
        self.embed_model = embed_model
        self.chat_model = chat_model
        # Load the embedding model once and reuse it for every query.
        self.emb_model = SentenceTransformer(embed_model)
        # Keeps the last prompt sent to the LLM, so --show-prompt can print it.
        self.prompt = ""

    def get_embedding(self, text: str):
        # tolist() turns the numpy vector into plain floats for the JSON payloads.
        return self.emb_model.encode(text, show_progress_bar=False).tolist()

    def search_qdrant(self, query: str, top_k: int = 6):
        query_vec = self.get_embedding(query)
        url = f"http://{self.qdrant_host}:{self.qdrant_port}/collections/rag_collection/points/search"
        payload = {
            "vector": query_vec,
            "limit": top_k,  # Qdrant 1.x expects "limit"; "top" is the pre-1.0 field name
            "with_payload": True
        }
        resp = requests.post(url, json=payload, timeout=60)
        if resp.status_code != 200:
            raise RuntimeError(f"Qdrant error: {resp.status_code} {resp.text}")
        return resp.json().get("result", [])

    def generate_answer(self, query: str, context_docs: list):
        query = query.strip()
        context = ""
        sources = "\nSources:\n"
        for idx, doc in enumerate(context_docs, start=1):
            text = doc["payload"].get("text", "").strip()
            filename = doc["payload"].get("filename")
            url = doc["payload"].get("url") or ""
            if filename:
                title = filename
            else:
                # No filename in the payload: use a short text snippet as the title.
                snippet = text[:40].replace("\n", " ").strip()
                if len(text) > 40:
                    snippet += "..."
                title = snippet if snippet else "Empty text"
            context += f"\n--- Source [{idx}] ---\n{text}\n"
            sources += f"\n{idx}. {title}\n    {url}"

        # A sys_prompt.txt file (with {{context}} and {{query}} placeholders)
        # overrides the built-in prompt below.
        if os.path.exists("sys_prompt.txt"):
            with open("sys_prompt.txt", "r", encoding="utf-8") as fp:
                prompt = fp.read().replace("{{context}}", context).replace("{{query}}", query)
        else:
            prompt = f"""
Please provide an answer based solely on the provided sources.
It is prohibited to generate an answer based on your pretrained data.
If uncertain, ask the user for clarification.
Respond in the same language as the user's query.
If there are no sources in the context, clearly state that.
If the context is unreadable or of poor quality, inform the user and provide the best possible answer.
When referencing information from a source, cite the appropriate source(s) using their corresponding numbers.
Every answer should include at least one source citation.
Only cite a source when you are explicitly referencing it.

If none of the sources are helpful, you should indicate that.
For example:

--- Source [1] ---
The sky is red in the evening and blue in the morning.

--- Source [2] ---
Water is wet when the sky is red.

Query: When is water wet?
Answer: Water will be wet when the sky is red [2], which occurs in the evening [1].

Now it's your turn. Below are several numbered sources of information:
{context}

User query: {query}
Your answer:
"""

        # Keep the final prompt around so the caller can inspect it (--show-prompt).
        self.prompt = prompt

        url = f"{self.ollama_url}/api/generate"
        body = {
            "model": self.chat_model,
            "prompt": prompt,
            "stream": False
        }
        resp = requests.post(url, json=body, timeout=600)
        if resp.status_code != 200:
            return f"Failed to generate an answer: {resp.status_code} {resp.text}"
        return resp.json().get("response", "").strip() + f"\n{sources}"


def main():
    parser = argparse.ArgumentParser(description="A RAG system built on Ollama and Qdrant")
    parser.add_argument("--query", type=str, help="Query for the RAG system")
    parser.add_argument("--interactive", default=False, action=argparse.BooleanOptionalAction, help="Start an interactive dialogue session")
    parser.add_argument("--show-prompt", default=False, action=argparse.BooleanOptionalAction, help="Print the full prompt that was sent to the model")
    parser.add_argument("--qdrant-host", default="localhost", help="Qdrant host")
    parser.add_argument("--qdrant-port", type=int, default=6333, help="Qdrant port")
    parser.add_argument("--ollama-url", default="http://localhost:11434", help="Ollama API URL")
    parser.add_argument("--emb-model", default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", help="Embedding model")
    parser.add_argument("--chat-model", default="qwen2.5:3b", help="Ollama generation model")
    parser.add_argument("--topk", type=int, default=6, help="Number of documents to retrieve")
    args = parser.parse_args()

    if not args.query and not args.interactive:
        print("Error: provide a query (--query) or use interactive mode (--interactive)")
        sys.exit(1)

    print(f"Ollama URL: {args.ollama_url}")
    print(f"Qdrant address: {args.qdrant_host}:{args.qdrant_port}")
    print(f"Embedding model: {args.emb_model}")
    print(f"Chat model: {args.chat_model}")
    print(f"Documents to retrieve: {args.topk}")
    if os.path.exists("sys_prompt.txt"):
        print("The system prompt will be loaded from sys_prompt.txt!")

    if args.interactive:
        print("\nINTERACTIVE MODE")
        print("Enter a query (or 'exit' to quit)\n")
        question = input(">>> ").strip()
    else:
        question = args.query.strip()

    print("\nInitializing the models...")
    rag = LocalRAGSystem(
        ollama_url=args.ollama_url,
        qdrant_host=args.qdrant_host,
        qdrant_port=args.qdrant_port,
        embed_model=args.emb_model,
        chat_model=args.chat_model
    )

    print("Models loaded. If an answer is poor, rephrase the query, try a different --chat-model, or improve the RAG source data")

    while True:
        try:
            if not question:
                question = input(">>> ").strip()
            if not question:
                continue

            if question.lower() == "exit":
                print("\n*** Shutting down")
                break

            print("\nSearching for relevant documents...")
            results = rag.search_qdrant(question, top_k=args.topk)
            if not results:
                print("No relevant documents found.")
                if args.interactive:
                    # Reset the question, otherwise the loop would retry the same search forever.
                    question = None
                    continue
                else:
                    break

            print(f"Found {len(results)} relevant documents")

            print("Generating the answer...")
            answer = rag.generate_answer(question, results)

            # The prompt only exists after generate_answer() has built it.
            if args.show_prompt:
                print("\nFull prompt sent to the model:\n")
                print(rag.prompt)

            print(f"\n<<< {answer}\n---------------------------------------------------\n")
            if not args.interactive:
                # Single-query mode: answer once and exit.
                break
            question = None
        except KeyboardInterrupt:
            print("\n*** Shutting down")
            break
        except Exception as e:
            print(f"Error: {e}")


if __name__ == "__main__":
    main()
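For reference, the class can also be driven programmatically instead of through the CLI above; a sketch, assuming Qdrant and Ollama are running with the defaults shown, the collection is already populated, and the import happens from inside the rag/ directory (so rag/rag.py is importable as rag):

from rag import LocalRAGSystem

rag_sys = LocalRAGSystem()
docs = rag_sys.search_qdrant("How do I configure the service?", top_k=3)
if docs:
    print(rag_sys.generate_answer("How do I configure the service?", docs))
else:
    print("No relevant documents found.")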