1
0

Проба отсева документов по минимальному скору

This commit is contained in:
2025-09-08 09:15:38 +08:00
parent 0106d157d3
commit 1f54ab0409
2 changed files with 28 additions and 27 deletions

View File

@@ -219,8 +219,8 @@ python3 rag.py --help
### Ранжирование ### Ранжирование
- [`cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1`](https://hf.co/cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1) ☑️ - [`cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1`](https://hf.co/cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1)
- [`cross-encoder/ms-marco-MiniLM-L-6-v2`](https://hf.co/cross-encoder/ms-marco-MiniLM-L-6-v2) - [`cross-encoder/ms-marco-MiniLM-L-6-v2`](https://hf.co/cross-encoder/ms-marco-MiniLM-L-6-v2) ☑️
- [`cross-encoder/ms-marco-TinyBERT-L-2-v2`](https://hf.co/cross-encoder/ms-marco-TinyBERT-L-2-v2) - [`cross-encoder/ms-marco-TinyBERT-L-2-v2`](https://hf.co/cross-encoder/ms-marco-TinyBERT-L-2-v2)
- ... - ...

View File

@@ -6,10 +6,10 @@ import sys
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer, CrossEncoder from sentence_transformers import SentenceTransformer, CrossEncoder
DEFAULT_CHAT_MODEL = "phi4-mini:3.8b" DEFAULT_CHAT_MODEL = "openchat:7b"
DEFAULT_EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" DEFAULT_EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
DEFAULT_RANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1" # DEFAULT_RANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
# DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2" DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-TinyBERT-L-2-v2" # DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
DEFAULT_MD_FOLDER = "data" DEFAULT_MD_FOLDER = "data"
DEFAULT_OLLAMA_URL = "http://localhost:11434" DEFAULT_OLLAMA_URL = "http://localhost:11434"
@@ -24,6 +24,7 @@ DEFAULT_SHOW_STATS = False
DEFAULT_STREAM = False DEFAULT_STREAM = False
DEFAULT_INTERACTIVE = False DEFAULT_INTERACTIVE = False
DEFAULT_SHOW_PROMPT = False DEFAULT_SHOW_PROMPT = False
DEFAULT_MIN_RANK_SCORE = 0
class RagSystem: class RagSystem:
def __init__(self, def __init__(self,
@@ -44,7 +45,6 @@ class RagSystem:
if self.use_rank: if self.use_rank:
self.rank_model = CrossEncoder(rank_model) self.rank_model = CrossEncoder(rank_model)
self.conversation_history = [] self.conversation_history = []
self.load_chat_model() self.load_chat_model()
def load_chat_model(self): def load_chat_model(self):
@@ -68,18 +68,22 @@ class RagSystem:
}) })
return docs return docs
def rank_documents(self, query: str, documents: list, top_n: int = DEFAULT_TOP_N, min_score: float = DEFAULT_MIN_RANK_SCORE):
    """Re-rank retrieved documents with the cross-encoder and drop weak ones.

    Args:
        query: The user query each document is scored against.
        documents: Candidate documents; each is a dict whose
            ``doc["payload"]["text"]`` holds the text to score.
        top_n: Maximum number of documents to return.
        min_score: Minimum cross-encoder score a document must reach to be
            kept. The scores are floats, so the threshold is typed as a
            float (an int is still accepted).

    Returns:
        ``documents`` unchanged when ranking is disabled; otherwise a new
        list of at most ``top_n`` documents with ``rank_score >= min_score``,
        sorted by ``rank_score`` in descending order.

    Side effects:
        Every scored document gets a ``"rank_score"`` key, including the
        ones that are filtered out.
    """
    if not self.use_rank:
        return documents

    # Nothing to score: skip the model call entirely.
    if not documents:
        return []

    pairs = [[query, doc["payload"]["text"]] for doc in documents]
    scores = self.rank_model.predict(pairs)

    kept = []
    for doc, raw_score in zip(documents, scores):
        score = float(raw_score)
        # Store the score on every document so callers can inspect it
        # even for documents rejected by the threshold.
        doc["rank_score"] = score
        if score >= min_score:
            kept.append(doc)

    kept.sort(key=lambda item: item["rank_score"], reverse=True)
    return kept[:top_n]
def generate_answer(self, sys_prompt: str, user_prompt: str): def generate_answer(self, sys_prompt: str, user_prompt: str):
url = f"{self.ollama_url}/api/generate" url = f"{self.ollama_url}/api/generate"
@@ -87,7 +91,6 @@ class RagSystem:
"model": self.chat_model, "model": self.chat_model,
"system": sys_prompt, "system": sys_prompt,
"prompt": user_prompt, "prompt": user_prompt,
#"context": self.conversation_history,
"stream": False, "stream": False,
"options": { "options": {
"temperature": 0.5, "temperature": 0.5,
@@ -107,11 +110,10 @@ class RagSystem:
"model": self.chat_model, "model": self.chat_model,
"system": sys_prompt, "system": sys_prompt,
"prompt": user_prompt, "prompt": user_prompt,
#"context": self.conversation_history,
"stream": True, "stream": True,
"options": { "options": {
"temperature": 0.1, "temperature": 0.5,
"top_p": 0.2, # "top_p": 0.2,
}, },
} }
resp = requests.post(url, json=body, stream=True, timeout=900) resp = requests.post(url, json=body, stream=True, timeout=900)
@@ -119,6 +121,7 @@ class RagSystem:
raise RuntimeError(f"Ошибка генерации ответа: {resp.status_code} {resp.text}") raise RuntimeError(f"Ошибка генерации ответа: {resp.status_code} {resp.text}")
answer = "" answer = ""
self.response = None
for chunk in resp.iter_lines(): for chunk in resp.iter_lines():
if chunk: if chunk:
try: try:
@@ -139,27 +142,27 @@ class RagSystem:
answer += f" | Ошибка обработки чанка: {e}" answer += f" | Ошибка обработки чанка: {e}"
def get_prompt_eval_count(self):
    """Return ``prompt_eval_count`` from the last stored response.

    Returns 0 when no response has been stored (attribute missing or
    ``None``) or when the response has no ``prompt_eval_count`` entry.
    """
    # The response is subscripted by string key elsewhere, so it is treated
    # as a dict here; a missing or empty response means "no statistics".
    stats = getattr(self, "response", None) or {}
    return stats.get("prompt_eval_count") or 0
def get_prompt_eval_duration(self):
    """Return ``prompt_eval_duration`` from the last response, divided by 1e9.

    The division presumably converts nanoseconds to seconds (confirm
    against the Ollama API documentation). Returns 0 when no response has
    been stored or the response has no ``prompt_eval_duration`` entry.
    """
    # A missing or empty response means "no statistics".
    stats = getattr(self, "response", None) or {}
    raw = stats.get("prompt_eval_duration")
    if not raw:
        return 0
    return raw / (10 ** 9)
def get_eval_count(self):
    """Return ``eval_count`` from the last stored response.

    Returns 0 when no response has been stored (attribute missing or
    ``None``) or when the response has no ``eval_count`` entry.
    """
    # A missing or empty response means "no statistics".
    stats = getattr(self, "response", None) or {}
    return stats.get("eval_count") or 0
def get_eval_duration(self):
    """Return ``eval_duration`` from the last response, divided by 1e9.

    The division presumably converts nanoseconds to seconds (confirm
    against the Ollama API documentation). Returns 0 when no response has
    been stored or the response has no ``eval_duration`` entry.
    """
    # A missing or empty response means "no statistics".
    stats = getattr(self, "response", None) or {}
    raw = stats.get("eval_duration")
    if not raw:
        return 0
    return raw / (10 ** 9)
def get_total_duration(self):
    """Return ``total_duration`` from the last response, divided by 1e9.

    The division presumably converts nanoseconds to seconds (confirm
    against the Ollama API documentation). Returns 0 when no response has
    been stored or the response has no ``total_duration`` entry.
    """
    # A missing or empty response means "no statistics".
    stats = getattr(self, "response", None) or {}
    raw = stats.get("total_duration")
    if not raw:
        return 0
    return raw / (10 ** 9)
@@ -271,9 +274,9 @@ class App:
self.print_v(text=f"Найдено {len(context_docs)} документов") self.print_v(text=f"Найдено {len(context_docs)} документов")
return context_docs return context_docs
def rank_docs(self, docs: list = None, top_n = DEFAULT_TOP_N, min_score: float = DEFAULT_MIN_RANK_SCORE):
    """Re-rank ``docs`` for the current query and report progress.

    Args:
        docs: Documents to rank. ``None`` (the default) means an empty
            list; a fresh list is created per call so no list object is
            shared between calls.
        top_n: Maximum number of documents to keep after ranking.
        min_score: Minimum rank score a document must reach to be kept.

    Returns:
        The list returned by ``self.rag.rank_documents``.
    """
    if docs is None:
        docs = []

    self.print_v(text="\nРанжирование документов...")
    ranked_docs = self.rag.rank_documents(self.query, docs, top_n, min_score)
    self.print_v(text=f"После ранжирования осталось {len(ranked_docs)} документов")
    return ranked_docs
@@ -403,10 +406,11 @@ Context:
else: else:
break break
ranked_docs = self.rank_docs(context_docs, self.args.topn) ranked_docs = self.rank_docs(context_docs, self.args.topn, self.args.min_rank_score)
if not ranked_docs: if not ranked_docs:
if args.interactive: if args.interactive:
print("<<< Релевантные документы были отсеяны полностью") print("<<< Документы были отсеяны полностью")
#TODO сделать ещё 2 попытки перезапроса+реранка других документов без учёта нерелевантных context_docs
self.query = None self.query = None
self.args.query = None self.args.query = None
continue continue
@@ -456,10 +460,6 @@ Context:
print("\n*** Завершение работы") print("\n*** Завершение работы")
break break
except Exception as e:
print(f"Ошибка: {e}")
break
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
@@ -480,6 +480,7 @@ if __name__ == "__main__":
parser.add_argument("--topk", type=int, default=DEFAULT_TOP_K, help="Количество документов для поиска") parser.add_argument("--topk", type=int, default=DEFAULT_TOP_K, help="Количество документов для поиска")
parser.add_argument("--use-rank", default=DEFAULT_USE_RANK, action=argparse.BooleanOptionalAction, help="Включить ранжирование") parser.add_argument("--use-rank", default=DEFAULT_USE_RANK, action=argparse.BooleanOptionalAction, help="Включить ранжирование")
parser.add_argument("--rank-model", type=str, default=DEFAULT_RANK_MODEL, help="Модель ранжирования") parser.add_argument("--rank-model", type=str, default=DEFAULT_RANK_MODEL, help="Модель ранжирования")
parser.add_argument("--min-rank-score", type=int, default=DEFAULT_MIN_RANK_SCORE, help="Минимальный ранк документа")
parser.add_argument("--topn", type=int, default=DEFAULT_TOP_N, help="Количество документов после ранжирования") parser.add_argument("--topn", type=int, default=DEFAULT_TOP_N, help="Количество документов после ранжирования")
args = parser.parse_args() args = parser.parse_args()