1
0

Проба отсева документов по минимальному скору

This commit is contained in:
2025-09-08 09:15:38 +08:00
parent 0106d157d3
commit 1f54ab0409
2 changed files with 28 additions and 27 deletions

View File

@@ -219,8 +219,8 @@ python3 rag.py --help
### Ранжирование
- [`cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1`](https://hf.co/cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1) ☑️
- [`cross-encoder/ms-marco-MiniLM-L-6-v2`](https://hf.co/cross-encoder/ms-marco-MiniLM-L-6-v2)
- [`cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1`](https://hf.co/cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1)
- [`cross-encoder/ms-marco-MiniLM-L-6-v2`](https://hf.co/cross-encoder/ms-marco-MiniLM-L-6-v2) ☑️
- [`cross-encoder/ms-marco-TinyBERT-L-2-v2`](https://hf.co/cross-encoder/ms-marco-TinyBERT-L-2-v2)
- ...

View File

@@ -6,10 +6,10 @@ import sys
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer, CrossEncoder
DEFAULT_CHAT_MODEL = "phi4-mini:3.8b"
DEFAULT_CHAT_MODEL = "openchat:7b"
DEFAULT_EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
DEFAULT_RANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
# DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# DEFAULT_RANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
DEFAULT_MD_FOLDER = "data"
DEFAULT_OLLAMA_URL = "http://localhost:11434"
@@ -24,6 +24,7 @@ DEFAULT_SHOW_STATS = False
DEFAULT_STREAM = False
DEFAULT_INTERACTIVE = False
DEFAULT_SHOW_PROMPT = False
DEFAULT_MIN_RANK_SCORE = 0
class RagSystem:
def __init__(self,
@@ -44,7 +45,6 @@ class RagSystem:
if self.use_rank:
self.rank_model = CrossEncoder(rank_model)
self.conversation_history = []
self.load_chat_model()
def load_chat_model(self):
@@ -68,18 +68,22 @@ class RagSystem:
})
return docs
def rank_documents(self, query: str, documents: list, top_n: int = DEFAULT_TOP_N):
def rank_documents(self, query: str, documents: list, top_n: int = DEFAULT_TOP_N, min_score: int = DEFAULT_MIN_RANK_SCORE):
if not self.use_rank:
return documents
pairs = [[query, doc["payload"]["text"]] for doc in documents]
scores = self.rank_model.predict(pairs)
ranked_docs = []
for i, doc in enumerate(documents):
doc["rank_score"] = float(scores[i])
score = float(scores[i])
doc["rank_score"] = score
if score >= min_score:
ranked_docs.append(doc)
documents.sort(key=lambda x: x['rank_score'], reverse=True)
return documents[:top_n]
ranked_docs.sort(key=lambda x: x['rank_score'], reverse=True)
return ranked_docs[:top_n]
def generate_answer(self, sys_prompt: str, user_prompt: str):
url = f"{self.ollama_url}/api/generate"
@@ -87,7 +91,6 @@ class RagSystem:
"model": self.chat_model,
"system": sys_prompt,
"prompt": user_prompt,
#"context": self.conversation_history,
"stream": False,
"options": {
"temperature": 0.5,
@@ -107,11 +110,10 @@ class RagSystem:
"model": self.chat_model,
"system": sys_prompt,
"prompt": user_prompt,
#"context": self.conversation_history,
"stream": True,
"options": {
"temperature": 0.1,
"top_p": 0.2,
"temperature": 0.5,
# "top_p": 0.2,
},
}
resp = requests.post(url, json=body, stream=True, timeout=900)
@@ -119,6 +121,7 @@ class RagSystem:
raise RuntimeError(f"Ошибка генерации ответа: {resp.status_code} {resp.text}")
answer = ""
self.response = None
for chunk in resp.iter_lines():
if chunk:
try:
@@ -139,27 +142,27 @@ class RagSystem:
answer += f" | Ошибка обработки чанка: {e}"
def get_prompt_eval_count(self):
if not self.response["prompt_eval_count"]:
if not self.response:
return 0
return self.response["prompt_eval_count"]
def get_prompt_eval_duration(self):
if not self.response["prompt_eval_duration"]:
if not self.response:
return 0
return self.response["prompt_eval_duration"] / (10 ** 9)
def get_eval_count(self):
if not self.response["eval_count"]:
if not self.response:
return 0
return self.response["eval_count"]
def get_eval_duration(self):
if not self.response["eval_duration"]:
if not self.response:
return 0
return self.response["eval_duration"] / (10 ** 9)
def get_total_duration(self):
if not self.response["total_duration"]:
if not self.response:
return 0
return self.response["total_duration"] / (10 ** 9)
@@ -271,9 +274,9 @@ class App:
self.print_v(text=f"Найдено {len(context_docs)} документов")
return context_docs
def rank_docs(self, docs: list = [], top_n = DEFAULT_TOP_N):
def rank_docs(self, docs: list = [], top_n = DEFAULT_TOP_N, min_score: int = DEFAULT_MIN_RANK_SCORE):
self.print_v(text="\nРанжирование документов...")
ranked_docs = self.rag.rank_documents(self.query, docs, top_n)
ranked_docs = self.rag.rank_documents(self.query, docs, top_n, min_score)
self.print_v(text=f"После ранжирования осталось {len(ranked_docs)} документов")
return ranked_docs
@@ -403,10 +406,11 @@ Context:
else:
break
ranked_docs = self.rank_docs(context_docs, self.args.topn)
ranked_docs = self.rank_docs(context_docs, self.args.topn, self.args.min_rank_score)
if not ranked_docs:
if args.interactive:
print("<<< Релевантные документы были отсеяны полностью")
print("<<< Документы были отсеяны полностью")
#TODO сделать ещё 2 попытки перезапроса+реранка других документов без учёта нерелевантных context_docs
self.query = None
self.args.query = None
continue
@@ -456,10 +460,6 @@ Context:
print("\n*** Завершение работы")
break
except Exception as e:
print(f"Ошибка: {e}")
break
if __name__ == "__main__":
import argparse
@@ -480,6 +480,7 @@ if __name__ == "__main__":
parser.add_argument("--topk", type=int, default=DEFAULT_TOP_K, help="Количество документов для поиска")
parser.add_argument("--use-rank", default=DEFAULT_USE_RANK, action=argparse.BooleanOptionalAction, help="Включить ранжирование")
parser.add_argument("--rank-model", type=str, default=DEFAULT_RANK_MODEL, help="Модель ранжирования")
parser.add_argument("--min-rank-score", type=int, default=DEFAULT_MIN_RANK_SCORE, help="Минимальный ранк документа")
parser.add_argument("--topn", type=int, default=DEFAULT_TOP_N, help="Количество документов после ранжирования")
args = parser.parse_args()