Проба отсева документов по минимальному скору
This commit is contained in:
@@ -219,8 +219,8 @@ python3 rag.py --help
|
|||||||
|
|
||||||
### Ранжирование
|
### Ранжирование
|
||||||
|
|
||||||
- [`cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1`](https://hf.co/cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1) ☑️
|
- [`cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1`](https://hf.co/cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1)
|
||||||
- [`cross-encoder/ms-marco-MiniLM-L-6-v2`](https://hf.co/cross-encoder/ms-marco-MiniLM-L-6-v2)
|
- [`cross-encoder/ms-marco-MiniLM-L-6-v2`](https://hf.co/cross-encoder/ms-marco-MiniLM-L-6-v2) ☑️
|
||||||
- [`cross-encoder/ms-marco-TinyBERT-L-2-v2`](https://hf.co/cross-encoder/ms-marco-TinyBERT-L-2-v2)
|
- [`cross-encoder/ms-marco-TinyBERT-L-2-v2`](https://hf.co/cross-encoder/ms-marco-TinyBERT-L-2-v2)
|
||||||
- ...
|
- ...
|
||||||
|
|
||||||
|
|||||||
51
rag/rag.py
51
rag/rag.py
@@ -6,10 +6,10 @@ import sys
|
|||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
from sentence_transformers import SentenceTransformer, CrossEncoder
|
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||||
|
|
||||||
DEFAULT_CHAT_MODEL = "phi4-mini:3.8b"
|
DEFAULT_CHAT_MODEL = "openchat:7b"
|
||||||
DEFAULT_EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
DEFAULT_EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
||||||
DEFAULT_RANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
|
# DEFAULT_RANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
|
||||||
# DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||||
# DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
|
# DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
|
||||||
DEFAULT_MD_FOLDER = "data"
|
DEFAULT_MD_FOLDER = "data"
|
||||||
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||||
@@ -24,6 +24,7 @@ DEFAULT_SHOW_STATS = False
|
|||||||
DEFAULT_STREAM = False
|
DEFAULT_STREAM = False
|
||||||
DEFAULT_INTERACTIVE = False
|
DEFAULT_INTERACTIVE = False
|
||||||
DEFAULT_SHOW_PROMPT = False
|
DEFAULT_SHOW_PROMPT = False
|
||||||
|
DEFAULT_MIN_RANK_SCORE = 0
|
||||||
|
|
||||||
class RagSystem:
|
class RagSystem:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@@ -44,7 +45,6 @@ class RagSystem:
|
|||||||
if self.use_rank:
|
if self.use_rank:
|
||||||
self.rank_model = CrossEncoder(rank_model)
|
self.rank_model = CrossEncoder(rank_model)
|
||||||
self.conversation_history = []
|
self.conversation_history = []
|
||||||
|
|
||||||
self.load_chat_model()
|
self.load_chat_model()
|
||||||
|
|
||||||
def load_chat_model(self):
|
def load_chat_model(self):
|
||||||
@@ -68,18 +68,22 @@ class RagSystem:
|
|||||||
})
|
})
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def rank_documents(self, query: str, documents: list, top_n: int = DEFAULT_TOP_N):
|
def rank_documents(self, query: str, documents: list, top_n: int = DEFAULT_TOP_N, min_score: int = DEFAULT_MIN_RANK_SCORE):
|
||||||
if not self.use_rank:
|
if not self.use_rank:
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
pairs = [[query, doc["payload"]["text"]] for doc in documents]
|
pairs = [[query, doc["payload"]["text"]] for doc in documents]
|
||||||
scores = self.rank_model.predict(pairs)
|
scores = self.rank_model.predict(pairs)
|
||||||
|
|
||||||
|
ranked_docs = []
|
||||||
for i, doc in enumerate(documents):
|
for i, doc in enumerate(documents):
|
||||||
doc["rank_score"] = float(scores[i])
|
score = float(scores[i])
|
||||||
|
doc["rank_score"] = score
|
||||||
|
if score >= min_score:
|
||||||
|
ranked_docs.append(doc)
|
||||||
|
|
||||||
documents.sort(key=lambda x: x['rank_score'], reverse=True)
|
ranked_docs.sort(key=lambda x: x['rank_score'], reverse=True)
|
||||||
return documents[:top_n]
|
return ranked_docs[:top_n]
|
||||||
|
|
||||||
def generate_answer(self, sys_prompt: str, user_prompt: str):
|
def generate_answer(self, sys_prompt: str, user_prompt: str):
|
||||||
url = f"{self.ollama_url}/api/generate"
|
url = f"{self.ollama_url}/api/generate"
|
||||||
@@ -87,7 +91,6 @@ class RagSystem:
|
|||||||
"model": self.chat_model,
|
"model": self.chat_model,
|
||||||
"system": sys_prompt,
|
"system": sys_prompt,
|
||||||
"prompt": user_prompt,
|
"prompt": user_prompt,
|
||||||
#"context": self.conversation_history,
|
|
||||||
"stream": False,
|
"stream": False,
|
||||||
"options": {
|
"options": {
|
||||||
"temperature": 0.5,
|
"temperature": 0.5,
|
||||||
@@ -107,11 +110,10 @@ class RagSystem:
|
|||||||
"model": self.chat_model,
|
"model": self.chat_model,
|
||||||
"system": sys_prompt,
|
"system": sys_prompt,
|
||||||
"prompt": user_prompt,
|
"prompt": user_prompt,
|
||||||
#"context": self.conversation_history,
|
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"options": {
|
"options": {
|
||||||
"temperature": 0.1,
|
"temperature": 0.5,
|
||||||
"top_p": 0.2,
|
# "top_p": 0.2,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
resp = requests.post(url, json=body, stream=True, timeout=900)
|
resp = requests.post(url, json=body, stream=True, timeout=900)
|
||||||
@@ -119,6 +121,7 @@ class RagSystem:
|
|||||||
raise RuntimeError(f"Ошибка генерации ответа: {resp.status_code} {resp.text}")
|
raise RuntimeError(f"Ошибка генерации ответа: {resp.status_code} {resp.text}")
|
||||||
|
|
||||||
answer = ""
|
answer = ""
|
||||||
|
self.response = None
|
||||||
for chunk in resp.iter_lines():
|
for chunk in resp.iter_lines():
|
||||||
if chunk:
|
if chunk:
|
||||||
try:
|
try:
|
||||||
@@ -139,27 +142,27 @@ class RagSystem:
|
|||||||
answer += f" | Ошибка обработки чанка: {e}"
|
answer += f" | Ошибка обработки чанка: {e}"
|
||||||
|
|
||||||
def get_prompt_eval_count(self):
|
def get_prompt_eval_count(self):
|
||||||
if not self.response["prompt_eval_count"]:
|
if not self.response:
|
||||||
return 0
|
return 0
|
||||||
return self.response["prompt_eval_count"]
|
return self.response["prompt_eval_count"]
|
||||||
|
|
||||||
def get_prompt_eval_duration(self):
|
def get_prompt_eval_duration(self):
|
||||||
if not self.response["prompt_eval_duration"]:
|
if not self.response:
|
||||||
return 0
|
return 0
|
||||||
return self.response["prompt_eval_duration"] / (10 ** 9)
|
return self.response["prompt_eval_duration"] / (10 ** 9)
|
||||||
|
|
||||||
def get_eval_count(self):
|
def get_eval_count(self):
|
||||||
if not self.response["eval_count"]:
|
if not self.response:
|
||||||
return 0
|
return 0
|
||||||
return self.response["eval_count"]
|
return self.response["eval_count"]
|
||||||
|
|
||||||
def get_eval_duration(self):
|
def get_eval_duration(self):
|
||||||
if not self.response["eval_duration"]:
|
if not self.response:
|
||||||
return 0
|
return 0
|
||||||
return self.response["eval_duration"] / (10 ** 9)
|
return self.response["eval_duration"] / (10 ** 9)
|
||||||
|
|
||||||
def get_total_duration(self):
|
def get_total_duration(self):
|
||||||
if not self.response["total_duration"]:
|
if not self.response:
|
||||||
return 0
|
return 0
|
||||||
return self.response["total_duration"] / (10 ** 9)
|
return self.response["total_duration"] / (10 ** 9)
|
||||||
|
|
||||||
@@ -271,9 +274,9 @@ class App:
|
|||||||
self.print_v(text=f"Найдено {len(context_docs)} документов")
|
self.print_v(text=f"Найдено {len(context_docs)} документов")
|
||||||
return context_docs
|
return context_docs
|
||||||
|
|
||||||
def rank_docs(self, docs: list = [], top_n = DEFAULT_TOP_N):
|
def rank_docs(self, docs: list = [], top_n = DEFAULT_TOP_N, min_score: int = DEFAULT_MIN_RANK_SCORE):
|
||||||
self.print_v(text="\nРанжирование документов...")
|
self.print_v(text="\nРанжирование документов...")
|
||||||
ranked_docs = self.rag.rank_documents(self.query, docs, top_n)
|
ranked_docs = self.rag.rank_documents(self.query, docs, top_n, min_score)
|
||||||
self.print_v(text=f"После ранжирования осталось {len(ranked_docs)} документов")
|
self.print_v(text=f"После ранжирования осталось {len(ranked_docs)} документов")
|
||||||
return ranked_docs
|
return ranked_docs
|
||||||
|
|
||||||
@@ -403,10 +406,11 @@ Context:
|
|||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
ranked_docs = self.rank_docs(context_docs, self.args.topn)
|
ranked_docs = self.rank_docs(context_docs, self.args.topn, self.args.min_rank_score)
|
||||||
if not ranked_docs:
|
if not ranked_docs:
|
||||||
if args.interactive:
|
if args.interactive:
|
||||||
print("<<< Релевантные документы были отсеяны полностью")
|
print("<<< Документы были отсеяны полностью")
|
||||||
|
#TODO сделать ещё 2 попытки перезапроса+реранка других документов без учёта нерелевантных context_docs
|
||||||
self.query = None
|
self.query = None
|
||||||
self.args.query = None
|
self.args.query = None
|
||||||
continue
|
continue
|
||||||
@@ -456,10 +460,6 @@ Context:
|
|||||||
print("\n*** Завершение работы")
|
print("\n*** Завершение работы")
|
||||||
break
|
break
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Ошибка: {e}")
|
|
||||||
break
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
@@ -480,6 +480,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--topk", type=int, default=DEFAULT_TOP_K, help="Количество документов для поиска")
|
parser.add_argument("--topk", type=int, default=DEFAULT_TOP_K, help="Количество документов для поиска")
|
||||||
parser.add_argument("--use-rank", default=DEFAULT_USE_RANK, action=argparse.BooleanOptionalAction, help="Включить ранжирование")
|
parser.add_argument("--use-rank", default=DEFAULT_USE_RANK, action=argparse.BooleanOptionalAction, help="Включить ранжирование")
|
||||||
parser.add_argument("--rank-model", type=str, default=DEFAULT_RANK_MODEL, help="Модель ранжирования")
|
parser.add_argument("--rank-model", type=str, default=DEFAULT_RANK_MODEL, help="Модель ранжирования")
|
||||||
|
parser.add_argument("--min-rank-score", type=int, default=DEFAULT_MIN_RANK_SCORE, help="Минимальный ранк документа")
|
||||||
parser.add_argument("--topn", type=int, default=DEFAULT_TOP_N, help="Количество документов после ранжирования")
|
parser.add_argument("--topn", type=int, default=DEFAULT_TOP_N, help="Количество документов после ранжирования")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user