diff --git a/rag/README.md b/rag/README.md index 142f28c..abf40df 100644 --- a/rag/README.md +++ b/rag/README.md @@ -219,8 +219,8 @@ python3 rag.py --help ### Ранжирование -- [`cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1`](https://hf.co/cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1) ☑️ -- [`cross-encoder/ms-marco-MiniLM-L-6-v2`](https://hf.co/cross-encoder/ms-marco-MiniLM-L-6-v2) +- [`cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1`](https://hf.co/cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1) +- [`cross-encoder/ms-marco-MiniLM-L-6-v2`](https://hf.co/cross-encoder/ms-marco-MiniLM-L-6-v2) ☑️ - [`cross-encoder/ms-marco-TinyBERT-L-2-v2`](https://hf.co/cross-encoder/ms-marco-TinyBERT-L-2-v2) - ... diff --git a/rag/rag.py b/rag/rag.py index 21367a4..09f35f9 100644 --- a/rag/rag.py +++ b/rag/rag.py @@ -6,10 +6,10 @@ import sys from qdrant_client import QdrantClient from sentence_transformers import SentenceTransformer, CrossEncoder -DEFAULT_CHAT_MODEL = "phi4-mini:3.8b" +DEFAULT_CHAT_MODEL = "openchat:7b" DEFAULT_EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" -DEFAULT_RANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1" -# DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2" +# DEFAULT_RANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1" +DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2" # DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-TinyBERT-L-2-v2" DEFAULT_MD_FOLDER = "data" DEFAULT_OLLAMA_URL = "http://localhost:11434" @@ -24,6 +24,7 @@ DEFAULT_SHOW_STATS = False DEFAULT_STREAM = False DEFAULT_INTERACTIVE = False DEFAULT_SHOW_PROMPT = False +DEFAULT_MIN_RANK_SCORE = 0 class RagSystem: def __init__(self, @@ -44,7 +45,6 @@ class RagSystem: if self.use_rank: self.rank_model = CrossEncoder(rank_model) self.conversation_history = [] - self.load_chat_model() def load_chat_model(self): @@ -68,18 +68,22 @@ class RagSystem: }) return docs - def rank_documents(self, query: str, documents: list, top_n: int = 
DEFAULT_TOP_N): + def rank_documents(self, query: str, documents: list, top_n: int = DEFAULT_TOP_N, min_score: int = DEFAULT_MIN_RANK_SCORE): if not self.use_rank: return documents pairs = [[query, doc["payload"]["text"]] for doc in documents] scores = self.rank_model.predict(pairs) + ranked_docs = [] for i, doc in enumerate(documents): - doc["rank_score"] = float(scores[i]) + score = float(scores[i]) + doc["rank_score"] = score + if score >= min_score: + ranked_docs.append(doc) - documents.sort(key=lambda x: x['rank_score'], reverse=True) - return documents[:top_n] + ranked_docs.sort(key=lambda x: x['rank_score'], reverse=True) + return ranked_docs[:top_n] def generate_answer(self, sys_prompt: str, user_prompt: str): url = f"{self.ollama_url}/api/generate" @@ -87,7 +91,6 @@ class RagSystem: "model": self.chat_model, "system": sys_prompt, "prompt": user_prompt, - #"context": self.conversation_history, "stream": False, "options": { "temperature": 0.5, @@ -107,11 +110,10 @@ class RagSystem: "model": self.chat_model, "system": sys_prompt, "prompt": user_prompt, - #"context": self.conversation_history, "stream": True, "options": { - "temperature": 0.1, - "top_p": 0.2, + "temperature": 0.5, + # "top_p": 0.2, }, } resp = requests.post(url, json=body, stream=True, timeout=900) @@ -119,6 +121,7 @@ class RagSystem: raise RuntimeError(f"Ошибка генерации ответа: {resp.status_code} {resp.text}") answer = "" + self.response = None for chunk in resp.iter_lines(): if chunk: try: @@ -139,27 +142,27 @@ class RagSystem: answer += f" | Ошибка обработки чанка: {e}" def get_prompt_eval_count(self): - if not self.response["prompt_eval_count"]: + if not self.response: return 0 return self.response["prompt_eval_count"] def get_prompt_eval_duration(self): - if not self.response["prompt_eval_duration"]: + if not self.response: return 0 return self.response["prompt_eval_duration"] / (10 ** 9) def get_eval_count(self): - if not self.response["eval_count"]: + if not self.response: return 0 
return self.response["eval_count"] def get_eval_duration(self): - if not self.response["eval_duration"]: + if not self.response: return 0 return self.response["eval_duration"] / (10 ** 9) def get_total_duration(self): - if not self.response["total_duration"]: + if not self.response: return 0 return self.response["total_duration"] / (10 ** 9) @@ -271,9 +274,9 @@ class App: self.print_v(text=f"Найдено {len(context_docs)} документов") return context_docs - def rank_docs(self, docs: list = [], top_n = DEFAULT_TOP_N): + def rank_docs(self, docs: list = [], top_n = DEFAULT_TOP_N, min_score: int = DEFAULT_MIN_RANK_SCORE): self.print_v(text="\nРанжирование документов...") - ranked_docs = self.rag.rank_documents(self.query, docs, top_n) + ranked_docs = self.rag.rank_documents(self.query, docs, top_n, min_score) self.print_v(text=f"После ранжирования осталось {len(ranked_docs)} документов") return ranked_docs @@ -403,10 +406,11 @@ Context: else: break - ranked_docs = self.rank_docs(context_docs, self.args.topn) + ranked_docs = self.rank_docs(context_docs, self.args.topn, self.args.min_rank_score) if not ranked_docs: if args.interactive: - print("<<< Релевантные документы были отсеяны полностью") + print("<<< Документы были отсеяны полностью") + #TODO сделать ещё 2 попытки перезапроса+реранка других документов без учёта нерелевантных context_docs self.query = None self.args.query = None continue @@ -456,10 +460,6 @@ Context: print("\n*** Завершение работы") break - except Exception as e: - print(f"Ошибка: {e}") - break - if __name__ == "__main__": import argparse @@ -480,6 +480,7 @@ if __name__ == "__main__": parser.add_argument("--topk", type=int, default=DEFAULT_TOP_K, help="Количество документов для поиска") parser.add_argument("--use-rank", default=DEFAULT_USE_RANK, action=argparse.BooleanOptionalAction, help="Включить ранжирование") parser.add_argument("--rank-model", type=str, default=DEFAULT_RANK_MODEL, help="Модель ранжирования") + 
parser.add_argument("--min-rank-score", type=int, default=DEFAULT_MIN_RANK_SCORE, help="Минимальный ранг документа") parser.add_argument("--topn", type=int, default=DEFAULT_TOP_N, help="Количество документов после ранжирования") args = parser.parse_args()