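"""Command-line RAG chat client: retrieves context chunks from Qdrant,
optionally reranks them with a cross-encoder, and asks an Ollama-hosted
chat model to answer strictly from the retrieved sources."""
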
import os
import requests
import json
import time
import sys
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer, CrossEncoder

DEFAULT_CHAT_MODEL = "openchat:7b"
DEFAULT_EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
DEFAULT_MD_FOLDER = "data"
DEFAULT_OLLAMA_URL = "http://localhost:11434"
DEFAULT_QDRANT_HOST = "localhost"
DEFAULT_QDRANT_PORT = 6333
DEFAULT_QDRANT_COLLECTION = "rag"
DEFAULT_TOP_K = 30
DEFAULT_USE_RANK = False
DEFAULT_TOP_N = 8
DEFAULT_VERBOSE = False
DEFAULT_SHOW_STATS = False
DEFAULT_STREAM = False
DEFAULT_INTERACTIVE = False
DEFAULT_SHOW_PROMPT = False
DEFAULT_MIN_RANK_SCORE = 0

class RagSystem:
    def __init__(self,
                 ollama_url: str = DEFAULT_OLLAMA_URL,
                 qdrant_host: str = DEFAULT_QDRANT_HOST,
                 qdrant_port: int = DEFAULT_QDRANT_PORT,
                 embed_model: str = DEFAULT_EMBED_MODEL,
                 rank_model: str = DEFAULT_RANK_MODEL,
                 use_rank: bool = DEFAULT_USE_RANK,
                 chat_model: str = DEFAULT_CHAT_MODEL):
        self.ollama_url = ollama_url
        self.qdrant_host = qdrant_host
        self.qdrant_port = qdrant_port
        self.chat_model = chat_model
        self.emb_model = SentenceTransformer(embed_model)
        # Use the instance attributes rather than the global `args`, so the
        # class stays usable outside the CLI entry point.
        self.qdrant = QdrantClient(host=self.qdrant_host, port=self.qdrant_port)
        self.use_rank = use_rank
        # Keep the attribute defined even when reranking is off, so code that
        # introspects it (e.g. process_save) does not raise AttributeError.
        self.rank_model = CrossEncoder(rank_model) if self.use_rank else None
        self.conversation_history = []
        # Raw JSON of the last Ollama response; set by the generate methods
        # and read by the stats getters below.
        self.response = None

    def check_chat_model(self):
        response = requests.get(f"{self.ollama_url}/api/tags")
        if response.status_code != 200:
            return False
        for model in response.json().get("models", []):
            if model["name"] == self.chat_model:
                return True
        return False

    def install_chat_model(self, model: str = DEFAULT_CHAT_MODEL):
        try:
            response = requests.post(f"{self.ollama_url}/api/pull", json={"model": model})
            if response.status_code == 200:
                print(f"Model {model} installed successfully")
            else:
                print(f"Model installation error: {response.text}")
        except Exception as e:
            print(f"Model installation error: {str(e)}")

    def load_chat_model(self):
        # A generate request without a prompt makes Ollama load the model into
        # memory, so the first real query is not slowed down by model loading.
        requests.post(f"{self.ollama_url}/api/generate", json={"model": self.chat_model}, timeout=600)

    def search_qdrant(self, query: str, doc_count: int = DEFAULT_TOP_K, collection_name: str = DEFAULT_QDRANT_COLLECTION):
        query_vec = self.emb_model.encode(query, show_progress_bar=False).tolist()
        results = self.qdrant.query_points(
            collection_name=collection_name,
            query=query_vec,
            limit=doc_count,
            # score_threshold=0.5,
        )
        docs = []
        for point in results.points:
            docs.append({
                "payload": point.payload,
                "score": point.score,
            })
        return docs
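
    # Downstream code expects each point's payload to carry at least a "text"
    # field, plus optional "filename", "url", "date", "version" and "author"
    # metadata (see App.prepare_ctx_sources / App.prepare_cli_sources below).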

    def rank_documents(self, query: str, documents: list, top_n: int = DEFAULT_TOP_N, min_score: float = DEFAULT_MIN_RANK_SCORE):
        if not self.use_rank:
            return documents

        pairs = [[query, doc["payload"]["text"]] for doc in documents]
        scores = self.rank_model.predict(pairs)

        ranked_docs = []
        for i, doc in enumerate(documents):
            score = float(scores[i])
            doc["rank_score"] = score
            if score >= min_score:
                ranked_docs.append(doc)

        ranked_docs.sort(key=lambda x: x['rank_score'], reverse=True)
        return ranked_docs[:top_n]
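
    # The ms-marco cross-encoder outputs unbounded relevance logits, where
    # irrelevant passages typically score negative; the default min_score of 0
    # therefore keeps only pairs the model leans "relevant" on.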

    def generate_answer(self, sys_prompt: str, user_prompt: str):
        url = f"{self.ollama_url}/api/generate"
        body = {
            "model": self.chat_model,
            "system": sys_prompt,
            "prompt": user_prompt,
            "stream": False,
            "options": {
                "temperature": 0.5,
                # "top_p": 0.2,
            },
        }

        response = requests.post(url, json=body, timeout=900)
        if response.status_code != 200:
            return f"Answer generation error: {response.status_code} {response.text}"
        self.response = response.json()
        return self.response["response"]
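
    # Note: /api/generate is stateless, and conversation_history is only used
    # by App.process_save to export the dialog; past turns are never sent back
    # to the model.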

    def generate_answer_stream(self, sys_prompt: str, user_prompt: str):
        url = f"{self.ollama_url}/api/generate"
        body = {
            "model": self.chat_model,
            "system": sys_prompt,
            "prompt": user_prompt,
            "stream": True,
            "options": {
                "temperature": 0.5,
                # "top_p": 0.2,
            },
        }
        resp = requests.post(url, json=body, stream=True, timeout=900)
        if resp.status_code != 200:
            raise RuntimeError(f"Answer generation error: {resp.status_code} {resp.text}")

        self.response = None
        # Ollama streams newline-delimited JSON objects; yield tokens as they
        # arrive, and surface errors as part of the stream so they are visible.
        for chunk in resp.iter_lines():
            if chunk:
                try:
                    data = json.loads(chunk.decode('utf-8'))
                    if "response" in data:
                        yield data["response"]
                        if data.get("done") is True:
                            # The final chunk carries the token/timing stats.
                            self.response = data
                            break
                    elif "error" in data:
                        yield f" | Answer streaming error: {data['error']}"
                        break
                except json.JSONDecodeError as e:
                    yield f" | Chunk decoding error: {chunk.decode('utf-8')} - {e}"
                except Exception as e:
                    yield f" | Chunk processing error: {e}"

    def get_prompt_eval_count(self):
        if not self.response:
            return 0
        return self.response.get("prompt_eval_count", 0)

    def get_prompt_eval_duration(self):
        if not self.response:
            return 0
        # Ollama reports durations in nanoseconds; convert to seconds.
        return self.response.get("prompt_eval_duration", 0) / (10 ** 9)

    def get_eval_count(self):
        if not self.response:
            return 0
        return self.response.get("eval_count", 0)

    def get_eval_duration(self):
        if not self.response:
            return 0
        return self.response.get("eval_duration", 0) / (10 ** 9)

    def get_total_duration(self):
        if not self.response:
            return 0
        return self.response.get("total_duration", 0) / (10 ** 9)

    def get_tps(self):
        # Tokens per second over the generation phase only.
        eval_count = self.get_eval_count()
        eval_duration = self.get_eval_duration()
        if eval_count == 0 or eval_duration == 0:
            return 0
        return eval_count / eval_duration


class App:
    def __init__(self, args):
        if not args.query and not args.interactive:
            print("Error: provide a query (--query) and/or use interactive mode (--interactive)")
            sys.exit(1)

        self.args = args
        self.print_v(text=f"Interactive dialog mode: {args.interactive}")
        self.print_v(text=f"Streaming output: {args.stream}")
        if self.is_custom_sys_prompt():
            self.print_v(text=f"System prompt: {args.sys_prompt}")
        else:
            self.print_v(text="System prompt: default")
        self.print_v(text=f"Show system prompt before the query: {args.show_prompt}")
        self.print_v(text=f"Print service messages: {args.verbose}")
        self.print_v(text=f"Print answer statistics: {args.show_stats}")
        self.print_v(text=f"Qdrant host address: {args.qdrant_host}")
        self.print_v(text=f"Qdrant port number: {args.qdrant_port}")
        self.print_v(text=f"Collection used for document search: {args.qdrant_collection}")
        self.print_v(text=f"Ollama API URL: {args.ollama_url}")
        self.print_v(text=f"Ollama generation model: {args.chat_model}")
        self.print_v(text=f"Embedding model: {args.emb_model}")
        self.print_v(text=f"Number of documents to retrieve: {args.topk}")
        self.print_v(text=f"Reranking enabled: {args.use_rank}")
        self.print_v(text=f"Reranking model: {args.rank_model}")
        self.print_v(text=f"Number of documents kept after reranking: {args.topn}")
        self.init_rag()

    def print_v(self, text: str = "\n"):
        # Print only when --verbose is set.
        if self.args.verbose:
            print(text)

    def init_rag(self):
        self.print_v(text="\nInitializing models...")
        self.rag = RagSystem(
            ollama_url=self.args.ollama_url,
            qdrant_host=self.args.qdrant_host,
            qdrant_port=self.args.qdrant_port,
            embed_model=self.args.emb_model,
            rank_model=self.args.rank_model,
            use_rank=self.args.use_rank,
            chat_model=self.args.chat_model
        )
        if not self.rag.check_chat_model():
            print(f"Installing model {self.args.chat_model} ...")
            self.rag.install_chat_model(self.args.chat_model)
        self.rag.load_chat_model()
        self.print_v(text="Models loaded. If an answer is poor, rephrase the query, try another --chat-model, or improve the underlying RAG data")

    def init_query(self):
        self.query = None
        if self.args.interactive:
            self.print_v(text="\nINTERACTIVE MODE")
            self.print_v(text="Type a query (or 'exit' to quit)\n")

        if self.args.query:
            self.query = self.args.query.strip()
            print(f">>> {self.query}")
        elif self.args.interactive:
            self.query = input(">>> ").strip()

    def process_help(self):
        print("<<< Interactive mode commands:")
        print("save -- save the dialog to a file")
        print("stats -- statistics for the last answer")
        print("exit -- quit\n")
        self.query = None
        self.args.query = None

    def process_save(self):
        import datetime
        timestamp = int(time.time())
        dt = datetime.datetime.fromtimestamp(timestamp).strftime('%Y-%m-%dT%H:%M:%SZ')
        filename = f"chats/chat-{timestamp}-{self.args.chat_model}.md"

        markdown_content = f"# Dialog history from {dt}\n\n"
        markdown_content += "## Dialog parameters\n"
        markdown_content += f"```\nargs = {self.args}\n```\n"
        markdown_content += f"```\nemb_model = {self.rag.emb_model}\n```\n"
        markdown_content += f"```\nrank_model = {self.rag.rank_model}\n```\n"

        for entry in self.rag.conversation_history:
            if entry['role'] == 'user':
                markdown_content += "## User\n\n"
            elif entry['role'] == 'assistant':
                markdown_content += "## Model\n\n"
                # prepare_ctx_sources is defined on App, not on RagSystem.
                docs = self.prepare_ctx_sources(entry['docs']).replace("```", "")
                markdown_content += f"```\n{docs}\n```\n\n"
            markdown_content += f"{entry['content']}\n\n"

        os.makedirs('chats', exist_ok=True)
        with open(filename, 'w') as fp:
            fp.write(markdown_content)

        print(f"<<< Dialog saved to file: {filename}\n")
        self.query = None

    def find_docs(self, query: str, top_k: int, collection_name: str):
        self.print_v(text="\nSearching for documents...")
        context_docs = self.rag.search_qdrant(query, top_k, collection_name)
        self.print_v(text=f"Found {len(context_docs)} documents")
        return context_docs

    def rank_docs(self, docs: list, top_n: int = DEFAULT_TOP_N, min_score: float = DEFAULT_MIN_RANK_SCORE):
        self.print_v(text="\nReranking documents...")
        ranked_docs = self.rag.rank_documents(self.query, docs, top_n, min_score)
        self.print_v(text=f"{len(ranked_docs)} documents left after reranking")
        return ranked_docs

    def prepare_ctx_sources(self, docs: list):
        sources = ""
        for idx, doc in enumerate(docs, start=1):
            text = doc['payload'].get("text", "").strip()
            sources = f"{sources}\n<source id=\"{idx}\">\n{text}\n</source>\n"
        return sources

    def prepare_cli_sources(self, docs: list):
        sources = "\nSources:\n"
        for idx, doc in enumerate(docs, start=1):
            title = doc['payload'].get("filename") or "(no filename)"
            url = doc['payload'].get("url") or "(no web link)"
            date = doc['payload'].get("date") or "(unknown)"
            version = doc['payload'].get("version") or "0"
            author = doc['payload'].get("author") or "(unknown)"

            sources += f"{idx}. {title}\n"
            sources += f"   {url}\n"
            sources += f"   Version {version} by {author}, current as of {date}\n"
            # rank_score is only present when reranking is enabled; checking
            # for presence also avoids hiding a legitimate score of 0.0.
            if doc.get('rank_score') is not None:
                sources += f"   score = {doc['score']} | rank_score = {doc['rank_score']}\n"
            else:
                sources += f"   score = {doc['score']}\n"
        return sources
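
    # `score` is Qdrant's vector-similarity score from search_qdrant;
    # `rank_score` is the cross-encoder score added by rank_documents.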

    def prepare_sys_prompt(self, query: str, docs: list):
        if self.is_custom_sys_prompt():
            with open(self.args.sys_prompt, 'r') as fp:
                prompt_tpl = fp.read()
        else:
            prompt_tpl = """You are a helpful assistant that can answer questions based on the provided context.
Your user is the person asking the source-related question.
Your job is to answer the question based on the context alone.
If the context doesn't provide much information, answer "I don't know."
Adhere to this in all languages.

Context:

-----------------------------------------
{{sources}}
-----------------------------------------
"""

        sources = self.prepare_ctx_sources(docs)
        return prompt_tpl.replace("{{sources}}", sources).replace("{{query}}", query)
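
    # A custom template passed via --sys-prompt may use the {{sources}} and
    # {{query}} placeholders; both are substituted here. {{query}} is optional,
    # since the query is also sent separately as the user prompt.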

    def show_prompt(self, sys_prompt: str):
        print("\n================ System prompt ==================")
        print(f"{sys_prompt}\n============ End of system prompt ==============\n")

    def process_query(self, sys_prompt: str, user_prompt: str, streaming: bool = DEFAULT_STREAM):
        answer = ""
        try:
            if streaming:
                self.print_v(text="\nGenerating streamed answer (^C to stop)...\n")
                print("<<< ", end='', flush=True)
                for token in self.rag.generate_answer_stream(sys_prompt, user_prompt):
                    answer += token
                    print(token, end='', flush=True)
            else:
                self.print_v(text="\nGenerating answer (^C to stop)...\n")
                answer = self.rag.generate_answer(sys_prompt, user_prompt)
                print(f"<<< {answer}\n")
        except RuntimeError as e:
            # A failed streaming request becomes the answer text instead of
            # crashing the loop.
            answer = str(e)

        print("\n===================================================")
        return answer

    def is_custom_sys_prompt(self):
        return self.args.sys_prompt and os.path.exists(self.args.sys_prompt)

    def print_stats(self):
        # PEC/PED = prompt eval count/duration; EC/ED = eval count/duration.
        # Query-reset bookkeeping is left to the callers so that the main loop
        # can still read self.query after printing stats.
        print(f"* Time: {self.rag.get_total_duration()}s")
        print(f"* TPS: {self.rag.get_tps()}")
        print(f"* PEC: {self.rag.get_prompt_eval_count()}")
        print(f"* PED: {self.rag.get_prompt_eval_duration()}s")
        print(f"* EC: {self.rag.get_eval_count()}")
        print(f"* ED: {self.rag.get_eval_duration()}s\n")

    def process(self):
        while True:
            try:
                self.init_query()

                if not self.query:
                    if self.args.interactive:
                        continue
                    # Nothing left to process in one-shot mode.
                    break

                if self.query.lower() == "help":
                    self.process_help()
                    continue

                if self.query.strip().lower() == "save":
                    self.process_save()
                    continue

                if self.query.strip().lower() == "stats":
                    print("\n<<< Statistics:")
                    self.print_stats()
                    self.query = None
                    self.args.query = None
                    continue

                if self.query.strip().lower() == "exit":
                    self.print_v(text="\n*** Shutting down")
                    sys.exit(0)

                context_docs = self.find_docs(self.query, self.args.topk, self.args.qdrant_collection)
                if not context_docs:
                    if self.args.interactive:
                        print("<<< No relevant documents found")
                        self.query = None
                        self.args.query = None
                        continue
                    else:
                        break

                ranked_docs = self.rank_docs(context_docs, self.args.topn, self.args.min_rank_score)
                if not ranked_docs:
                    if self.args.interactive:
                        print("<<< All documents were filtered out")
                        # TODO: retry up to 2 more times, re-querying and reranking
                        # other documents while excluding the irrelevant context_docs
                        self.query = None
                        self.args.query = None
                        continue
                    else:
                        break

                sys_prompt = self.prepare_sys_prompt(self.query, ranked_docs)
                if self.args.show_prompt:
                    self.show_prompt(sys_prompt)

                try:
                    answer = self.process_query(sys_prompt, self.query, self.args.stream)
                except KeyboardInterrupt:
                    print("\n*** Answer generation interrupted")
                    self.query = None
                    self.args.query = None
                    print(self.prepare_cli_sources(ranked_docs))
                    if self.args.show_stats:
                        print("\nStatistics:")
                        self.print_stats()
                    if not self.args.interactive:
                        break
                    continue

                print(self.prepare_cli_sources(ranked_docs))

                if self.args.show_stats:
                    print("\nStatistics:")
                    self.print_stats()

                self.rag.conversation_history.append({
                    "role": "user",
                    "content": self.query,
                })

                self.rag.conversation_history.append({
                    "role": "assistant",
                    "docs": ranked_docs,
                    "content": answer,
                })

                if self.args.interactive:
                    self.query = None
                    self.args.query = None
                else:
                    break

            except KeyboardInterrupt:
                print("\n*** Shutting down")
                break


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="A RAG system built on Ollama and Qdrant")
    parser.add_argument("--query", type=str, help="Query for the RAG system")
    parser.add_argument("--interactive", default=DEFAULT_INTERACTIVE, action=argparse.BooleanOptionalAction, help="Enable interactive dialog mode")
    parser.add_argument("--stream", default=DEFAULT_STREAM, action=argparse.BooleanOptionalAction, help="Enable streaming output")
    parser.add_argument("--sys-prompt", type=str, help="Path to a system prompt template file")
    parser.add_argument("--show-prompt", default=DEFAULT_SHOW_PROMPT, action=argparse.BooleanOptionalAction, help="Show the system prompt before the query")
    parser.add_argument("--verbose", default=DEFAULT_VERBOSE, action=argparse.BooleanOptionalAction, help="Print service messages")
    parser.add_argument("--show-stats", default=DEFAULT_SHOW_STATS, action=argparse.BooleanOptionalAction, help="Print answer statistics (does not work with --stream)")
    parser.add_argument("--qdrant-host", default=DEFAULT_QDRANT_HOST, help="Qdrant host address")
    parser.add_argument("--qdrant-port", type=int, default=DEFAULT_QDRANT_PORT, help="Qdrant port number")
    parser.add_argument("--qdrant-collection", type=str, default=DEFAULT_QDRANT_COLLECTION, help="Collection used for document search")
    parser.add_argument("--ollama-url", default=DEFAULT_OLLAMA_URL, help="Ollama API URL")
    parser.add_argument("--chat-model", default=DEFAULT_CHAT_MODEL, help="Ollama generation model")
    parser.add_argument("--emb-model", default=DEFAULT_EMBED_MODEL, help="Embedding model")
    parser.add_argument("--topk", type=int, default=DEFAULT_TOP_K, help="Number of documents to retrieve")
    parser.add_argument("--use-rank", default=DEFAULT_USE_RANK, action=argparse.BooleanOptionalAction, help="Enable reranking")
    parser.add_argument("--rank-model", type=str, default=DEFAULT_RANK_MODEL, help="Reranking model")
    parser.add_argument("--min-rank-score", type=float, default=DEFAULT_MIN_RANK_SCORE, help="Minimum rank score for a document to be kept")
    parser.add_argument("--topn", type=int, default=DEFAULT_TOP_N, help="Number of documents kept after reranking")
    args = parser.parse_args()

    app = App(args)
    app.process()
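
# Example invocations (illustrative; the script filename is assumed):
#
#   One-shot query with reranking and stats:
#     python rag_chat.py --query "How do I configure backups?" --use-rank --show-stats
#
#   Interactive session with streaming output:
#     python rag_chat.py --interactive --stream --verbose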