Trying out filtering documents by a minimum score
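In short: the reranking step now drops every retrieved document whose cross-encoder score falls below a configurable threshold (--min-rank-score, default 0) before the top-N cut is applied. Below is a minimal standalone sketch of that filter, not part of the commit; the model name and score handling follow rag/rag.py, while the function name, sample query and documents are invented for illustration.

# Illustrative sketch of the new min-score filtering (assumption: standalone demo, not the project's code).
from sentence_transformers import CrossEncoder

def rank_with_min_score(query, documents, top_n=5, min_score=0):
    # New DEFAULT_RANK_MODEL from the diff below
    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    pairs = [[query, doc["payload"]["text"]] for doc in documents]
    scores = model.predict(pairs)
    kept = []
    for doc, score in zip(documents, scores):
        doc["rank_score"] = float(score)
        if doc["rank_score"] >= min_score:  # documents under the threshold are discarded
            kept.append(doc)
    kept.sort(key=lambda d: d["rank_score"], reverse=True)
    return kept[:top_n]

# Made-up example input
docs = [{"payload": {"text": "Qdrant stores document vectors."}},
        {"payload": {"text": "Completely unrelated text."}}]
print(rank_with_min_score("What does Qdrant store?", docs, top_n=2, min_score=0))

Note that ms-marco cross-encoders return raw logits rather than probabilities, so scores for irrelevant pairs are typically negative and even the default threshold of 0 can filter noticeably.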
@@ -219,8 +219,8 @@ python3 rag.py --help
 
 ### Ранжирование
 
-- [`cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1`](https://hf.co/cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1) ☑️
-- [`cross-encoder/ms-marco-MiniLM-L-6-v2`](https://hf.co/cross-encoder/ms-marco-MiniLM-L-6-v2)
+- [`cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1`](https://hf.co/cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1)
+- [`cross-encoder/ms-marco-MiniLM-L-6-v2`](https://hf.co/cross-encoder/ms-marco-MiniLM-L-6-v2) ☑️
 - [`cross-encoder/ms-marco-TinyBERT-L-2-v2`](https://hf.co/cross-encoder/ms-marco-TinyBERT-L-2-v2)
 - ...
 
rag/rag.py (51 changed lines)
@@ -6,10 +6,10 @@ import sys
 from qdrant_client import QdrantClient
 from sentence_transformers import SentenceTransformer, CrossEncoder
 
-DEFAULT_CHAT_MODEL = "phi4-mini:3.8b"
+DEFAULT_CHAT_MODEL = "openchat:7b"
 DEFAULT_EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
-DEFAULT_RANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
-# DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+# DEFAULT_RANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
+DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
 # DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
 DEFAULT_MD_FOLDER = "data"
 DEFAULT_OLLAMA_URL = "http://localhost:11434"
@@ -24,6 +24,7 @@ DEFAULT_SHOW_STATS = False
 DEFAULT_STREAM = False
 DEFAULT_INTERACTIVE = False
 DEFAULT_SHOW_PROMPT = False
+DEFAULT_MIN_RANK_SCORE = 0
 
 class RagSystem:
     def __init__(self,
@@ -44,7 +45,6 @@ class RagSystem:
         if self.use_rank:
             self.rank_model = CrossEncoder(rank_model)
         self.conversation_history = []
-
         self.load_chat_model()
 
     def load_chat_model(self):
@@ -68,18 +68,22 @@ class RagSystem:
             })
         return docs
 
-    def rank_documents(self, query: str, documents: list, top_n: int = DEFAULT_TOP_N):
+    def rank_documents(self, query: str, documents: list, top_n: int = DEFAULT_TOP_N, min_score: int = DEFAULT_MIN_RANK_SCORE):
         if not self.use_rank:
             return documents
 
         pairs = [[query, doc["payload"]["text"]] for doc in documents]
         scores = self.rank_model.predict(pairs)
 
+        ranked_docs = []
         for i, doc in enumerate(documents):
-            doc["rank_score"] = float(scores[i])
+            score = float(scores[i])
+            doc["rank_score"] = score
+            if score >= min_score:
+                ranked_docs.append(doc)
 
-        documents.sort(key=lambda x: x['rank_score'], reverse=True)
-        return documents[:top_n]
+        ranked_docs.sort(key=lambda x: x['rank_score'], reverse=True)
+        return ranked_docs[:top_n]
 
     def generate_answer(self, sys_prompt: str, user_prompt: str):
         url = f"{self.ollama_url}/api/generate"
@@ -87,7 +91,6 @@ class RagSystem:
             "model": self.chat_model,
             "system": sys_prompt,
             "prompt": user_prompt,
-            #"context": self.conversation_history,
             "stream": False,
             "options": {
                 "temperature": 0.5,
@@ -107,11 +110,10 @@ class RagSystem:
             "model": self.chat_model,
             "system": sys_prompt,
             "prompt": user_prompt,
-            #"context": self.conversation_history,
             "stream": True,
             "options": {
-                "temperature": 0.1,
-                "top_p": 0.2,
+                "temperature": 0.5,
+                # "top_p": 0.2,
             },
         }
         resp = requests.post(url, json=body, stream=True, timeout=900)
@@ -119,6 +121,7 @@ class RagSystem:
             raise RuntimeError(f"Ошибка генерации ответа: {resp.status_code} {resp.text}")
 
         answer = ""
+        self.response = None
         for chunk in resp.iter_lines():
             if chunk:
                 try:
@@ -139,27 +142,27 @@ class RagSystem:
                 answer += f" | Ошибка обработки чанка: {e}"
 
     def get_prompt_eval_count(self):
-        if not self.response["prompt_eval_count"]:
+        if not self.response:
             return 0
         return self.response["prompt_eval_count"]
 
     def get_prompt_eval_duration(self):
-        if not self.response["prompt_eval_duration"]:
+        if not self.response:
             return 0
         return self.response["prompt_eval_duration"] / (10 ** 9)
 
     def get_eval_count(self):
-        if not self.response["eval_count"]:
+        if not self.response:
             return 0
         return self.response["eval_count"]
 
     def get_eval_duration(self):
-        if not self.response["eval_duration"]:
+        if not self.response:
             return 0
         return self.response["eval_duration"] / (10 ** 9)
 
     def get_total_duration(self):
-        if not self.response["total_duration"]:
+        if not self.response:
             return 0
         return self.response["total_duration"] / (10 ** 9)
 
@@ -271,9 +274,9 @@ class App:
         self.print_v(text=f"Найдено {len(context_docs)} документов")
         return context_docs
 
-    def rank_docs(self, docs: list = [], top_n = DEFAULT_TOP_N):
+    def rank_docs(self, docs: list = [], top_n = DEFAULT_TOP_N, min_score: int = DEFAULT_MIN_RANK_SCORE):
         self.print_v(text="\nРанжирование документов...")
-        ranked_docs = self.rag.rank_documents(self.query, docs, top_n)
+        ranked_docs = self.rag.rank_documents(self.query, docs, top_n, min_score)
         self.print_v(text=f"После ранжирования осталось {len(ranked_docs)} документов")
         return ranked_docs
 
@@ -403,10 +406,11 @@ Context:
            else:
                break
 
-        ranked_docs = self.rank_docs(context_docs, self.args.topn)
+        ranked_docs = self.rank_docs(context_docs, self.args.topn, self.args.min_rank_score)
         if not ranked_docs:
             if args.interactive:
-                print("<<< Релевантные документы были отсеяны полностью")
+                print("<<< Документы были отсеяны полностью")
+                #TODO сделать ещё 2 попытки перезапроса+реранка других документов без учёта нерелевантных context_docs
                 self.query = None
                 self.args.query = None
                 continue
@@ -456,10 +460,6 @@ Context:
                 print("\n*** Завершение работы")
                 break
 
-            except Exception as e:
-                print(f"Ошибка: {e}")
-                break
-
 if __name__ == "__main__":
     import argparse
 
@@ -480,6 +480,7 @@ if __name__ == "__main__":
     parser.add_argument("--topk", type=int, default=DEFAULT_TOP_K, help="Количество документов для поиска")
     parser.add_argument("--use-rank", default=DEFAULT_USE_RANK, action=argparse.BooleanOptionalAction, help="Включить ранжирование")
    parser.add_argument("--rank-model", type=str, default=DEFAULT_RANK_MODEL, help="Модель ранжирования")
+    parser.add_argument("--min-rank-score", type=int, default=DEFAULT_MIN_RANK_SCORE, help="Минимальный ранк документа")
     parser.add_argument("--topn", type=int, default=DEFAULT_TOP_N, help="Количество документов после ранжирования")
     args = parser.parse_args()
 
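Usage note: the threshold is exposed on the CLI as --min-rank-score (int, default 0). A hypothetical invocation with reranking enabled might be `python3 rag.py --use-rank --min-rank-score 1 --topn 3`; the flag combination is chosen for illustration and only the flags visible in this diff are assumed to exist. When every retrieved document ends up below the threshold, interactive mode now reports that all documents were filtered out, resets the query, and continues instead of generating an answer from an empty context.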