import os import requests import json import time import sys from qdrant_client import QdrantClient from sentence_transformers import SentenceTransformer, CrossEncoder DEFAULT_CHAT_MODEL = "openchat:7b" DEFAULT_EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2" DEFAULT_MD_FOLDER = "data" DEFAULT_OLLAMA_URL = "http://localhost:11434" DEFAULT_QDRANT_HOST = "localhost" DEFAULT_QDRANT_PORT = 6333 DEFAULT_QDRANT_COLLECTION = "rag" DEFAULT_TOP_K = 30 DEFAULT_USE_RANK = False DEFAULT_TOP_N = 8 DEFAULT_VERBOSE = False DEFAULT_SHOW_STATS = False DEFAULT_STREAM = False DEFAULT_INTERACTIVE = False DEFAULT_SHOW_PROMPT = False DEFAULT_MIN_RANK_SCORE = 0 class RagSystem: def __init__(self, ollama_url: str = DEFAULT_OLLAMA_URL, qdrant_host: str = DEFAULT_QDRANT_HOST, qdrant_port: int = DEFAULT_QDRANT_PORT, embed_model: str = DEFAULT_EMBED_MODEL, rank_model: str = DEFAULT_RANK_MODEL, use_rank: bool = DEFAULT_USE_RANK, chat_model: str = DEFAULT_CHAT_MODEL): self.ollama_url = ollama_url self.qdrant_host = qdrant_host self.qdrant_port = qdrant_port self.chat_model = chat_model self.emb_model = SentenceTransformer(embed_model) self.qdrant = QdrantClient(host=args.qdrant_host, port=args.qdrant_port) self.use_rank = use_rank if self.use_rank: self.rank_model = CrossEncoder(rank_model) self.conversation_history = [] def check_chat_model(self): response = requests.get(f"{self.ollama_url}/api/tags") if response.status_code != 200: return False for model in response.json().get("models", []): if model["name"] == self.chat_model: return True return False def install_chat_model(self, model: str = DEFAULT_CHAT_MODEL): try: response = requests.post(f"{self.ollama_url}/api/pull", json={"model": model}) if response.status_code == 200: print(f"Модель {self.chat_model} установлена успешно") else: print(f"Ошибка установки модели: {response.text}") except Exception as e: print(f"Ошибка проверки модели: {str(e)}") def load_chat_model(self): requests.post(f"{self.ollama_url}/api/generate", json={"model": self.chat_model}, timeout=600) def search_qdrant(self, query: str, doc_count: int = DEFAULT_TOP_K, collection_name = DEFAULT_QDRANT_COLLECTION): query_vec = self.emb_model.encode(query, show_progress_bar=False).tolist() results = self.qdrant.query_points( collection_name=collection_name, query=query_vec, limit=doc_count, # score_threshold=0.5, ) docs = [] for point in results.points: docs.append({ "payload": point.payload, "score": point.score, }) return docs def rank_documents(self, query: str, documents: list, top_n: int = DEFAULT_TOP_N, min_score: int = DEFAULT_MIN_RANK_SCORE): if not self.use_rank: return documents pairs = [[query, doc["payload"]["text"]] for doc in documents] scores = self.rank_model.predict(pairs) ranked_docs = [] for i, doc in enumerate(documents): score = float(scores[i]) doc["rank_score"] = score if score >= min_score: ranked_docs.append(doc) ranked_docs.sort(key=lambda x: x['rank_score'], reverse=True) return ranked_docs[:top_n] def generate_answer(self, sys_prompt: str, user_prompt: str): url = f"{self.ollama_url}/api/generate" body = { "model": self.chat_model, "system": sys_prompt, "prompt": user_prompt, "stream": False, "options": { "temperature": 0.5, # "top_p": 0.2, }, } response = requests.post(url, json=body, timeout=900) if response.status_code != 200: return f"Ошибка генерации ответа: {response.status_code} {response.text}" self.response = response.json() return self.response["response"] def generate_answer_stream(self, sys_prompt: str, user_prompt: str): url = f"{self.ollama_url}/api/generate" body = { "model": self.chat_model, "system": sys_prompt, "prompt": user_prompt, "stream": True, "options": { "temperature": 0.5, # "top_p": 0.2, }, } resp = requests.post(url, json=body, stream=True, timeout=900) if resp.status_code != 200: raise RuntimeError(f"Ошибка генерации ответа: {resp.status_code} {resp.text}") answer = "" self.response = None for chunk in resp.iter_lines(): if chunk: try: decoded_chunk = chunk.decode('utf-8') data = json.loads(decoded_chunk) if "response" in data: yield data["response"] answer += data["response"] if "done" in data and data["done"] is True: self.response = data break elif "error" in data: answer += f" | Ошибка стриминга ответа: {data['error']}" break except json.JSONDecodeError as e: answer += f" | Ошибка конвертации чанка: {chunk.decode('utf-8')} - {e}" except Exception as e: answer += f" | Ошибка обработки чанка: {e}" def get_prompt_eval_count(self): if not self.response: return 0 return self.response["prompt_eval_count"] def get_prompt_eval_duration(self): if not self.response: return 0 return self.response["prompt_eval_duration"] / (10 ** 9) def get_eval_count(self): if not self.response: return 0 return self.response["eval_count"] def get_eval_duration(self): if not self.response: return 0 return self.response["eval_duration"] / (10 ** 9) def get_total_duration(self): if not self.response: return 0 return self.response["total_duration"] / (10 ** 9) def get_tps(self): eval_count = self.get_eval_count() eval_duration = self.get_eval_duration() if eval_count == 0 or eval_duration == 0: return 0 return eval_count / eval_duration class App: def __init__( self, args: list = [] ): if not args.query and not args.interactive: print("Ошибка: укажите запрос (--query) и/или используйте интерактивный режим (--interactive)") sys.exit(1) self.args = args self.print_v(text=f"Включить интерактивный режим диалога: {args.interactive}") self.print_v(text=f"Включить потоковый вывод: {args.stream}") if self.is_custom_sys_prompt(): self.print_v(text=f"Системный промпт: {args.sys_prompt}") else: self.print_v(text=f"Системный промпт: по умолчанию") self.print_v(text=f"Показать сист. промпт перед запросом: {args.show_prompt}") self.print_v(text=f"Выводить служебные сообщения: {args.verbose}") self.print_v(text=f"Выводить статистику об ответе: {args.show_stats}") self.print_v(text=f"Адрес хоста Qdrant: {args.qdrant_host}") self.print_v(text=f"Номер порта Qdrant: {args.qdrant_port}") self.print_v(text=f"Название коллекции для поиска документов: {args.qdrant_collection}") self.print_v(text=f"Ollama API URL: {args.ollama_url}") self.print_v(text=f"Модель генерации Ollama: {args.chat_model}") self.print_v(text=f"Модель эмбеддинга: {args.emb_model}") self.print_v(text=f"Количество документов для поиска: {args.topk}") self.print_v(text=f"Включить ранжирование: {args.use_rank}") self.print_v(text=f"Модель ранжирования: {args.rank_model}") self.print_v(text=f"Количество документов после ранжирования: {args.topn}") self.init_rag() def print_v(self, text: str = "\n"): if self.args.verbose: print(f"{text}") def init_rag(self): self.print_v(text="\nИнициализация моделей...") self.rag = RagSystem( ollama_url = self.args.ollama_url, qdrant_host = self.args.qdrant_host, qdrant_port = self.args.qdrant_port, embed_model = self.args.emb_model, rank_model = self.args.rank_model, use_rank = self.args.use_rank, chat_model = self.args.chat_model ) if not self.rag.check_chat_model(): print(f"Установка модели {self.args.chat_model} ...") self.rag.install_chat_model(self.args.chat_model) self.rag.load_chat_model() self.print_v(text=f"Модели загружены. Если ответ плохой, переформулируйте запрос, укажите --chat-model или улучшите исходные данные RAG") def init_query(self): self.query = None if args.interactive: self.print_v(text="\nИНТЕРАКТИВНЫЙ РЕЖИМ") self.print_v(text="Можете вводить запрос (или 'exit' для выхода)\n") if self.args.query: self.query = self.args.query.strip() print(f">>> {self.query}") elif args.interactive: self.query = input(">>> ").strip() def process_help(self): print("<<< Команды интерактивного режима:") print("save -- сохранить диалог в файл") print("stats -- статистика последнего ответа") print("exit -- выход\n") self.query = None self.args.query = None def process_save(self): import datetime timestamp = int(time.time()) dt = datetime.datetime.fromtimestamp(timestamp).strftime('%Y-%m-%dT%H:%M:%SZ') filename = f"chats/chat-{timestamp}-{self.args.chat_model}.md" markdown_content = f"# История диалога от {dt}\n\n" markdown_content += f"## Параметры диалога\n" markdown_content += f"```\nargs = {self.args}\n```\n" markdown_content += f"```\nemb_model = {self.rag.emb_model}\n```\n" markdown_content += f"```\nrank_model = {self.rag.rank_model}\n```\n" for entry in self.rag.conversation_history: if entry['role'] == 'user': markdown_content += f"## Пользователь\n\n" elif entry['role'] == 'assistant': markdown_content += f"## Модель\n\n" docs = self.rag.prepare_ctx_sources(entry['docs']).replace("```", "") markdown_content += f"```\n{docs}\n```\n\n" markdown_content += f"{entry['content']}\n\n" os.makedirs('chats', exist_ok=True) with open(filename, 'w') as fp: fp.write(markdown_content) print(f"<<< Диалог сохранён в файл: {filename}\n") self.query = None def find_docs(self, query: str, top_k: int, collection_name: str): self.print_v(text="\nПоиск документов...") context_docs = self.rag.search_qdrant(query, top_k, collection_name) self.print_v(text=f"Найдено {len(context_docs)} документов") return context_docs def rank_docs(self, docs: list = [], top_n = DEFAULT_TOP_N, min_score: int = DEFAULT_MIN_RANK_SCORE): self.print_v(text="\nРанжирование документов...") ranked_docs = self.rag.rank_documents(self.query, docs, top_n, min_score) self.print_v(text=f"После ранжирования осталось {len(ranked_docs)} документов") return ranked_docs def prepare_ctx_sources(self, docs: list): sources = "" for idx, doc in enumerate(docs, start=1): text = doc['payload'].get("text", "").strip() sources = f"{sources}\n\n{text}\n\n" return sources def prepare_cli_sources(self, docs: list): sources = "\nИсточники:\n" for idx, doc in enumerate(docs, start=1): title = doc['payload'].get("filename", None) url = doc['payload'].get("url", None) date = doc['payload'].get("date", None) version = doc['payload'].get("version", None) author = doc['payload'].get("author", None) if url is None: url = "(нет веб-ссылки)" if date is None: date = "(неизвестно)" if version is None: version = "0" if author is None: author = "(неизвестен)" sources += f"{idx}. {title}\n" sources += f" {url}\n" sources += f" Версия {version} от {author}, актуальная на {date}\n" if doc['rank_score']: sources += f" score = {doc['score']} | rank_score = {doc['rank_score']}\n" else: sources += f" score = {doc['score']}\n" return sources def prepare_sys_prompt(self, query: str, docs: list): if self.is_custom_sys_prompt(): with open(self.args.sys_prompt, 'r') as fp: prompt_tpl = fp.read() else: prompt_tpl = """You are a helpful assistant that can answer questions based on the provided context. Your user is the person asking the source-related question. Your job is to answer the question based on the context alone. If the context doesn't provide much information, answer "I don't know." Adhere to this in all languages. Context: ----------------------------------------- {{sources}} ----------------------------------------- """ sources = self.prepare_ctx_sources(docs) return prompt_tpl.replace("{{sources}}", sources).replace("{{query}}", query) def show_prompt(self, sys_prompt: str): print("\n================ Системный промпт ==================") print(f"{sys_prompt}\n============ Конец системного промпта ==============\n") def process_query(self, sys_prompt: str, user_prompt: str, streaming: bool = DEFAULT_STREAM): answer = "" # try: if streaming: self.print_v(text="\nГенерация потокового ответа (^C для остановки)...\n") print(f"<<< ", end='', flush=True) for token in self.rag.generate_answer_stream(sys_prompt, user_prompt): answer += token print(token, end='', flush=True) else: self.print_v(text="\nГенерация ответа (^C для остановки)...\n") answer = self.rag.generate_answer(sys_prompt, user_prompt) print(f"<<< {answer}\n") # except RuntimeError as e: # answer = str(e) print(f"\n===================================================") return answer def is_custom_sys_prompt(self): return self.args.sys_prompt and os.path.exists(self.args.sys_prompt) def print_stats(self): print(f"* Time: {self.rag.get_total_duration()}s") print(f"* TPS: {self.rag.get_tps()}") print(f"* PEC: {self.rag.get_prompt_eval_count()}") print(f"* PED: {self.rag.get_prompt_eval_duration()}s") print(f"* EC: {self.rag.get_eval_count()}") print(f"* ED: {self.rag.get_eval_duration()}s\n") self.query = None self.args.query = None def process(self): while True: try: self.init_query() if not self.query or self.query == "": continue if self.query.lower() == "help": self.process_help() continue if self.query.strip().lower() == "save": self.process_save() continue if self.query.strip().lower() == "stats": print("\n<<< Статистика:") self.print_stats() continue if self.query.strip().lower() == "exit": self.print_v(text="\n*** Завершение работы") sys.exit(0) context_docs = self.find_docs(self.query, self.args.topk, self.args.qdrant_collection) if not context_docs: if args.interactive: print("<<< Релевантные документы не найдены") self.query = None self.args.query = None continue else: break ranked_docs = self.rank_docs(context_docs, self.args.topn, self.args.min_rank_score) if not ranked_docs: if args.interactive: print("<<< Документы были отсеяны полностью") #TODO сделать ещё 2 попытки перезапроса+реранка других документов без учёта нерелевантных context_docs self.query = None self.args.query = None continue else: break sys_prompt = self.prepare_sys_prompt(self.query, ranked_docs) if self.args.show_prompt: self.show_prompt(sys_prompt) try: answer = self.process_query(sys_prompt, self.query, self.args.stream) except KeyboardInterrupt: print("\n*** Генерация ответа прервана") self.query = None self.args.query = None print(self.prepare_cli_sources(ranked_docs)) if self.args.show_stats: print("\nСтатистика:") self.print_stats() continue print(self.prepare_cli_sources(ranked_docs)) if self.args.show_stats: print("\nСтатистика:") self.print_stats() self.rag.conversation_history.append({ "role": "user", "content": self.query, }) self.rag.conversation_history.append({ "role": "assistant", "docs": ranked_docs, "content": answer, }) if args.interactive: self.query = None self.args.query = None else: break except KeyboardInterrupt: print("\n*** Завершение работы") break if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="RAG-система с использованием Ollama и Qdrant") parser.add_argument("--query", type=str, help="Запрос к RAG") parser.add_argument("--interactive", default=DEFAULT_INTERACTIVE, action=argparse.BooleanOptionalAction, help="Включить интерактивный режим диалога") parser.add_argument("--stream", default=DEFAULT_STREAM, action=argparse.BooleanOptionalAction, help="Включить потоковый вывод") parser.add_argument("--sys-prompt", type=str, help="Путь к файлу шаблона системного промпта") parser.add_argument("--show-prompt", default=DEFAULT_SHOW_PROMPT, action=argparse.BooleanOptionalAction, help="Показать сист. промпт перед запросом") parser.add_argument("--verbose", default=DEFAULT_VERBOSE, action=argparse.BooleanOptionalAction, help="Выводить служебные сообщения") parser.add_argument("--show-stats", default=DEFAULT_SHOW_STATS, action=argparse.BooleanOptionalAction, help="Выводить статистику об ответе (не работает с --stream)") parser.add_argument("--qdrant-host", default=DEFAULT_QDRANT_HOST, help="Адрес хоста Qdrant") parser.add_argument("--qdrant-port", type=int, default=DEFAULT_QDRANT_PORT, help="Номер порта Qdrant") parser.add_argument("--qdrant-collection", type=str, default=DEFAULT_QDRANT_COLLECTION, help="Название коллекции для поиска документов") parser.add_argument("--ollama-url", default=DEFAULT_OLLAMA_URL, help="Ollama API URL") parser.add_argument("--chat-model", default=DEFAULT_CHAT_MODEL, help="Модель генерации Ollama") parser.add_argument("--emb-model", default=DEFAULT_EMBED_MODEL, help="Модель эмбеддинга") parser.add_argument("--topk", type=int, default=DEFAULT_TOP_K, help="Количество документов для поиска") parser.add_argument("--use-rank", default=DEFAULT_USE_RANK, action=argparse.BooleanOptionalAction, help="Включить ранжирование") parser.add_argument("--rank-model", type=str, default=DEFAULT_RANK_MODEL, help="Модель ранжирования") parser.add_argument("--min-rank-score", type=int, default=DEFAULT_MIN_RANK_SCORE, help="Минимальный ранк документа") parser.add_argument("--topn", type=int, default=DEFAULT_TOP_N, help="Количество документов после ранжирования") args = parser.parse_args() app = App(args) app.process()