diff --git a/rag/chat.py b/rag/chat.py
new file mode 100644
index 0000000..ccca65e
--- /dev/null
+++ b/rag/chat.py
@@ -0,0 +1,505 @@
+import argparse
+import datetime
+import json
+import os
+import sys
+import time
+
+import requests
+from qdrant_client import QdrantClient
+from sentence_transformers import SentenceTransformer, CrossEncoder
+
+DEFAULT_CHAT_MODEL = "openchat:7b"
+DEFAULT_EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+# DEFAULT_RANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
+DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+# DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
+DEFAULT_MD_FOLDER = "data"
+DEFAULT_OLLAMA_URL = "http://localhost:11434"
+DEFAULT_QDRANT_HOST = "localhost"
+DEFAULT_QDRANT_PORT = 6333
+DEFAULT_QDRANT_COLLECTION = "rag"
+DEFAULT_TOP_K = 30
+DEFAULT_USE_RANK = False
+DEFAULT_TOP_N = 8
+DEFAULT_VERBOSE = False
+DEFAULT_SHOW_STATS = False
+DEFAULT_STREAM = False
+DEFAULT_INTERACTIVE = False
+DEFAULT_SHOW_PROMPT = False
+DEFAULT_MIN_RANK_SCORE = 0.0
+
+class RagSystem:
+    def __init__(self,
+                 ollama_url: str = DEFAULT_OLLAMA_URL,
+                 qdrant_host: str = DEFAULT_QDRANT_HOST,
+                 qdrant_port: int = DEFAULT_QDRANT_PORT,
+                 embed_model: str = DEFAULT_EMBED_MODEL,
+                 rank_model: str = DEFAULT_RANK_MODEL,
+                 use_rank: bool = DEFAULT_USE_RANK,
+                 chat_model: str = DEFAULT_CHAT_MODEL):
+        self.ollama_url = ollama_url
+        self.qdrant_host = qdrant_host
+        self.qdrant_port = qdrant_port
+        self.chat_model = chat_model
+        self.emb_model = SentenceTransformer(embed_model)
+        self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
+        self.use_rank = use_rank
+        # Load the cross-encoder only when reranking is enabled.
+        self.rank_model = CrossEncoder(rank_model) if use_rank else None
+        self.conversation_history = []
+        self.response = {}
+
+        self.load_chat_model()
+
+    def load_chat_model(self):
+        # A request without a prompt makes Ollama preload the model into memory.
+        url = f"{self.ollama_url}/api/generate"
+        body = {"model": self.chat_model}
+        requests.post(url, json=body, timeout=600)
+
+    def search_qdrant(self, query: str, doc_count: int = DEFAULT_TOP_K, collection_name: str = DEFAULT_QDRANT_COLLECTION):
+        query_vec = self.emb_model.encode(query, show_progress_bar=False).tolist()
+        results = self.qdrant.query_points(
+            collection_name=collection_name,
+            query=query_vec,
+            limit=doc_count,
+            # score_threshold=0.5,
+        )
+        docs = []
+        for point in results.points:
+            docs.append({
+                "payload": point.payload,
+                "score": point.score,
+            })
+        return docs
+
+    def rank_documents(self, query: str, documents: list, top_n: int = DEFAULT_TOP_N, min_score: float = DEFAULT_MIN_RANK_SCORE):
+        if not self.use_rank:
+            return documents
+
+        pairs = [[query, doc["payload"]["text"]] for doc in documents]
+        scores = self.rank_model.predict(pairs)
+
+        ranked_docs = []
+        for i, doc in enumerate(documents):
+            score = float(scores[i])
+            doc["rank_score"] = score
+            if score >= min_score:
+                ranked_docs.append(doc)
+
+        ranked_docs.sort(key=lambda x: x['rank_score'], reverse=True)
+        return ranked_docs[:top_n]
+
+    def generate_answer(self, sys_prompt: str, user_prompt: str):
+        url = f"{self.ollama_url}/api/chat"
+        body = {
+            "model": self.chat_model,
+            # "system": sys_prompt,
+            # "prompt": user_prompt,
+            "messages": self.conversation_history,
+            "stream": False,
+            "options": {
+                "temperature": 0.5,
+                # "top_p": 0.2,
+            },
+        }
+
+        response = requests.post(url, json=body, timeout=900)
+        if response.status_code != 200:
+            return f"Error generating answer: {response.status_code} {response.text}"
+        self.response = response.json()
+        return self.response["message"]["content"]
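+
+    # Note: /api/chat streams newline-delimited JSON; each chunk carries a
+    # partial message.content, and the final chunk has done=true plus the
+    # timing counters (eval_count, eval_duration, ...) that the get_* helpers
+    # below read from self.response.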
+    def generate_answer_stream(self, sys_prompt: str, user_prompt: str):
+        url = f"{self.ollama_url}/api/chat"
+        body = {
+            "model": self.chat_model,
+            # "system": sys_prompt,
+            # "prompt": user_prompt,
+            "messages": self.conversation_history,
+            "stream": True,
+            "options": {
+                "temperature": 0.5,
+                # "top_p": 0.2,
+            },
+        }
+        resp = requests.post(url, json=body, stream=True, timeout=900)
+        if resp.status_code != 200:
+            raise RuntimeError(f"Error generating answer: {resp.status_code} {resp.text}")
+
+        for chunk in resp.iter_lines():
+            if not chunk:
+                continue
+            try:
+                data = json.loads(chunk.decode('utf-8'))
+                if "error" in data:
+                    yield f" | Streaming error: {data['error']}"
+                    break
+                content = data.get("message", {}).get("content")
+                if content:
+                    yield content
+                if data.get("done"):
+                    self.response = data
+                    break
+            except json.JSONDecodeError as e:
+                yield f" | Error decoding chunk: {chunk.decode('utf-8')} - {e}"
+            except Exception as e:
+                yield f" | Error processing chunk: {e}"
+
+    def get_prompt_eval_count(self):
+        return self.response.get("prompt_eval_count") or 0
+
+    def get_prompt_eval_duration(self):
+        return (self.response.get("prompt_eval_duration") or 0) / (10 ** 9)
+
+    def get_eval_count(self):
+        return self.response.get("eval_count") or 0
+
+    def get_eval_duration(self):
+        return (self.response.get("eval_duration") or 0) / (10 ** 9)
+
+    def get_total_duration(self):
+        return (self.response.get("total_duration") or 0) / (10 ** 9)
+
+    def get_tps(self):
+        eval_count = self.get_eval_count()
+        eval_duration = self.get_eval_duration()
+        if eval_count == 0 or eval_duration == 0:
+            return 0
+        return eval_count / eval_duration
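+
+# A minimal sketch of driving RagSystem directly (assumes Ollama and Qdrant are
+# running with the defaults above and a populated "rag" collection; the model
+# answers from conversation_history, so the system prompt goes there first):
+#
+#   rag = RagSystem(use_rank=True)
+#   question = "how do I configure the service?"
+#   docs = rag.rank_documents(question, rag.search_qdrant(question))
+#   rag.conversation_history.append({"role": "system", "content": "Context: ..."})
+#   rag.conversation_history.append({"role": "user", "content": question})
+#   print(rag.generate_answer("", ""))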
промпт перед запросом: {args.show_prompt}") + self.print_v(text=f"Выводить служебные сообщения: {args.verbose}") + self.print_v(text=f"Выводить статистику об ответе: {args.show_stats}") + self.print_v(text=f"Адрес хоста Qdrant: {args.qdrant_host}") + self.print_v(text=f"Номер порта Qdrant: {args.qdrant_port}") + self.print_v(text=f"Название коллекции для поиска документов: {args.qdrant_collection}") + self.print_v(text=f"Ollama API URL: {args.ollama_url}") + self.print_v(text=f"Модель генерации Ollama: {args.chat_model}") + self.print_v(text=f"Модель эмбеддинга: {args.emb_model}") + self.print_v(text=f"Количество документов для поиска: {args.topk}") + self.print_v(text=f"Включить ранжирование: {args.use_rank}") + self.print_v(text=f"Модель ранжирования: {args.rank_model}") + self.print_v(text=f"Количество документов после ранжирования: {args.topn}") + self.init_rag() + + def print_v(self, text: str = "\n"): + if self.args.verbose: + print(f"{text}") + + def init_rag(self): + self.print_v(text="\nИнициализация моделей...") + self.rag = RagSystem( + ollama_url = self.args.ollama_url, + qdrant_host = self.args.qdrant_host, + qdrant_port = self.args.qdrant_port, + embed_model = self.args.emb_model, + rank_model = self.args.rank_model, + use_rank = self.args.use_rank, + chat_model = self.args.chat_model + ) + self.print_v(text=f"Модели загружены. Если ответ плохой, переформулируйте запрос, укажите --chat-model или улучшите исходные данные RAG") + + def init_query(self): + self.query = None + if args.interactive: + self.print_v(text="\nИНТЕРАКТИВНЫЙ РЕЖИМ") + self.print_v(text="Можете вводить запрос (или 'exit' для выхода)\n") + + if self.args.query: + self.query = self.args.query.strip() + print(f">>> {self.query}") + elif args.interactive: + self.query = input(">>> ").strip() + + def process_help(self): + print("<<< Команды итерактивного режима:") + print("save -- сохранить диалог в файл") + print("exit -- выход\n") + self.query = None + self.args.query = None + + def process_save(self): + import datetime + timestamp = int(time.time()) + dt = datetime.datetime.fromtimestamp(timestamp).strftime('%Y-%m-%dT%H:%M:%SZ') + filename = f"chats/chat-{timestamp}-{self.args.chat_model}.md" + + markdown_content = f"# История диалога от {dt}\n\n" + markdown_content += f"## Параметры диалога\n" + markdown_content += f"```\nargs = {self.args}\n```\n" + markdown_content += f"```\nemb_model = {self.rag.emb_model}\n```\n" + markdown_content += f"```\nrank_model = {self.rag.rank_model}\n```\n" + + for entry in self.rag.conversation_history: + if entry['role'] == 'user': + markdown_content += f"## Пользователь\n\n" + elif entry['role'] == 'assistant': + markdown_content += f"## Модель\n\n" + docs = self.rag.prepare_ctx_sources(entry['docs']).replace("```", "") + markdown_content += f"```\n{docs}\n```\n\n" + markdown_content += f"{entry['content']}\n\n" + + os.makedirs('chats', exist_ok=True) + with open(filename, 'w') as fp: + fp.write(markdown_content) + + print(f"<<< Диалог сохранён в файл: {filename}\n") + self.query = None + + def find_docs(self, query: str, top_k: int, collection_name: str): + self.print_v(text="\nПоиск документов...") + context_docs = self.rag.search_qdrant(query, top_k, collection_name) + self.print_v(text=f"Найдено {len(context_docs)} документов") + return context_docs + + def rank_docs(self, docs: list = [], top_n = DEFAULT_TOP_N, min_score: int = DEFAULT_MIN_RANK_SCORE): + self.print_v(text="\nРанжирование документов...") + ranked_docs = self.rag.rank_documents(self.query, docs, top_n, 
+    def rank_docs(self, docs: list, top_n: int = DEFAULT_TOP_N, min_score: float = DEFAULT_MIN_RANK_SCORE):
+        self.print_v(text="\nReranking documents...")
+        ranked_docs = self.rag.rank_documents(self.query, docs, top_n, min_score)
+        self.print_v(text=f"{len(ranked_docs)} documents remain after reranking")
+        return ranked_docs
+
+    def prepare_ctx_sources(self, docs: list):
+        sources = ""
+        for doc in docs:
+            text = doc['payload'].get("text", "").strip()
+            sources = f"{sources}\n\n{text}\n\n"
+        return sources
+
+    def prepare_cli_sources(self, docs: list):
+        sources = "\nSources:\n"
+        for idx, doc in enumerate(docs, start=1):
+            title = doc['payload'].get("filename", None)
+            url = doc['payload'].get("url", None)
+            date = doc['payload'].get("date", None)
+            version = doc['payload'].get("version", None)
+            author = doc['payload'].get("author", None)
+
+            if url is None:
+                url = "(no web link)"
+            if date is None:
+                date = "(unknown)"
+            if version is None:
+                version = "0"
+            if author is None:
+                author = "(unknown)"
+
+            sources += f"{idx}. {title}\n"
+            sources += f"   {url}\n"
+            sources += f"   Version {version} by {author}, current as of {date}\n"
+            if doc.get('rank_score') is not None:
+                sources += f"   score = {doc['score']} | rank_score = {doc['rank_score']}\n"
+            else:
+                sources += f"   score = {doc['score']}\n"
+        return sources
+
+    def prepare_sys_prompt(self, query: str, docs: list):
+        if self.is_custom_sys_prompt():
+            with open(self.args.sys_prompt, 'r') as fp:
+                prompt = fp.read()
+        else:
+            prompt = """You are a helpful assistant that can answer questions based on the provided context.
+Your user is the person asking the source-related question.
+Your job is to answer the question based on the context alone.
+If the context doesn't provide much information, answer "I don't know."
+Adhere to this in all languages.
+
+Context:
+
+-----------------------------------------
+{{sources}}
+-----------------------------------------
+"""
+
+        sources = self.prepare_ctx_sources(docs)
+        return prompt.replace("{{sources}}", sources).replace("{{query}}", query)
+
+    def show_prompt(self, sys_prompt: str):
+        print("\n================ System prompt ==================")
+        print(f"{sys_prompt}\n============ End of system prompt ==============\n")
+
+    def process_query(self, sys_prompt: str, user_prompt: str, streaming: bool = DEFAULT_STREAM):
+        answer = ""
+        if streaming:
+            self.print_v(text="\nGenerating a streamed answer (^C to stop)...\n")
+            print("<<< ", end='', flush=True)
+            for token in self.rag.generate_answer_stream(sys_prompt, user_prompt):
+                answer += token
+                print(token, end='', flush=True)
+        else:
+            self.print_v(text="\nGenerating an answer (^C to stop)...\n")
+            answer = self.rag.generate_answer(sys_prompt, user_prompt)
+            print(f"<<< {answer}\n")
+
+        print("\n===================================================")
+        return answer
+
+    def is_custom_sys_prompt(self):
+        return self.args.sys_prompt and os.path.exists(self.args.sys_prompt)
+
+    def print_stats(self):
+        print(f"* Time: {self.rag.get_total_duration()}s")
+        print(f"* TPS: {self.rag.get_tps()}")
+        print(f"* PEC: {self.rag.get_prompt_eval_count()}")
+        print(f"* PED: {self.rag.get_prompt_eval_duration()}s")
+        print(f"* EC: {self.rag.get_eval_count()}")
+        print(f"* ED: {self.rag.get_eval_duration()}s\n")
+        self.query = None
+        self.args.query = None
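+
+    # Main REPL loop: read a query -> retrieve -> (optionally) rerank ->
+    # build the system prompt -> generate -> print sources and stats;
+    # repeats while --interactive, otherwise handles one query and exits.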
Статистика:") + self.print_stats() + continue + + if self.query.strip().lower() == "exit": + self.print_v(text="\n*** Завершение работы") + sys.exit(0) + + context_docs = self.find_docs(self.query, self.args.topk, self.args.qdrant_collection) + if not context_docs: + if args.interactive: + print("<<< Релевантные документы не найдены") + self.query = None + self.args.query = None + continue + else: + break + + ranked_docs = self.rank_docs(context_docs, self.args.topn, self.args.min_rank_score) + if not ranked_docs: + if args.interactive: + print("<<< Документы были отсеяны полностью") + #TODO сделать ещё 2 попытки перезапроса+реранка других документов без учёта нерелевантных context_docs + self.query = None + self.args.query = None + continue + else: + break + + # ctx = self.prepare_ctx_sources(ranked_docs) + sys_prompt = self.prepare_sys_prompt(self.query, ranked_docs) + if self.args.show_prompt: + self.show_prompt(sys_prompt) + + # self.rag.conversation_history.append({ + # "role": "system", + # "content": sys_prompt, + # }) + + self.rag.conversation_history.append({ + "role": "system", + "content": sys_prompt, + }) + + self.rag.conversation_history.append({ + "role": "user", + "content": self.query, + }) + + try: + answer = self.process_query(sys_prompt, self.query, self.args.stream) + except KeyboardInterrupt: + print("\n*** Генерация ответа прервана") + self.query = None + self.args.query = None + print(self.prepare_cli_sources(ranked_docs)) + if self.args.show_stats: + print("\nСтатистика:") + self.print_stats() + continue + + print(self.prepare_cli_sources(ranked_docs)) + + if self.args.show_stats: + print("\nСтатистика:") + self.print_stats() + + self.rag.conversation_history.append({ + "role": "assistant", + "docs": ranked_docs, + "content": answer, + }) + + if args.interactive: + self.query = None + self.args.query = None + else: + break + + except KeyboardInterrupt: + print("\n*** Завершение работы") + break + + except Exception as e: + print(f"Ошибка: {e}") + break + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="RAG-система с использованием Ollama и Qdrant") + parser.add_argument("--query", type=str, help="Запрос к RAG") + parser.add_argument("--interactive", default=DEFAULT_INTERACTIVE, action=argparse.BooleanOptionalAction, help="Включить интерактивный режим диалога") + parser.add_argument("--stream", default=DEFAULT_STREAM, action=argparse.BooleanOptionalAction, help="Включить потоковый вывод") + parser.add_argument("--sys-prompt", type=str, help="Путь к файлу шаблона системного промпта") + parser.add_argument("--show-prompt", default=DEFAULT_SHOW_PROMPT, action=argparse.BooleanOptionalAction, help="Показать сист. 
промпт перед запросом") + parser.add_argument("--verbose", default=DEFAULT_VERBOSE, action=argparse.BooleanOptionalAction, help="Выводить служебные сообщения") + parser.add_argument("--show-stats", default=DEFAULT_SHOW_STATS, action=argparse.BooleanOptionalAction, help="Выводить статистику об ответе (не работает с --stream)") + parser.add_argument("--qdrant-host", default=DEFAULT_QDRANT_HOST, help="Адрес хоста Qdrant") + parser.add_argument("--qdrant-port", type=int, default=DEFAULT_QDRANT_PORT, help="Номер порта Qdrant") + parser.add_argument("--qdrant-collection", type=str, default=DEFAULT_QDRANT_COLLECTION, help="Название коллекции для поиска документов") + parser.add_argument("--ollama-url", default=DEFAULT_OLLAMA_URL, help="Ollama API URL") + parser.add_argument("--chat-model", default=DEFAULT_CHAT_MODEL, help="Модель генерации Ollama") + parser.add_argument("--emb-model", default=DEFAULT_EMBED_MODEL, help="Модель эмбеддинга") + parser.add_argument("--topk", type=int, default=DEFAULT_TOP_K, help="Количество документов для поиска") + parser.add_argument("--use-rank", default=DEFAULT_USE_RANK, action=argparse.BooleanOptionalAction, help="Включить ранжирование") + parser.add_argument("--rank-model", type=str, default=DEFAULT_RANK_MODEL, help="Модель ранжирования") + parser.add_argument("--min-rank-score", type=int, default=DEFAULT_MIN_RANK_SCORE, help="Минимальный ранк документа") + parser.add_argument("--topn", type=int, default=DEFAULT_TOP_N, help="Количество документов после ранжирования") + args = parser.parse_args() + + app = App(args) + app.process()