- renamed input_md => data
- added info about the date, version, and author of conf-page changes to the index
- this info is now shown in the sources listing
- added output of statistics for the last answer
- the qdrant collection name can now be specified
- minor wording tweaks

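# A minimal sketch of how this script might be invoked (the file name rag.py is
# an assumption; the flags match the argparse definitions in main() below):
#
#   python rag.py --query "how do I reset my password?" --show-stats --verbose
#   python rag.py --interactive --qdrant-collection rag --chat-model phi4-mini:3.8b
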
import os
import requests
import json
import time
from sentence_transformers import SentenceTransformer


class RagSystem:
    def __init__(self,
                 md_folder: str = "data",
                 ollama_url: str = "http://localhost:11434",
                 qdrant_host: str = "localhost",
                 qdrant_port: int = 6333,
                 embed_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
                 chat_model: str = "phi4-mini:3.8b"):
        self.md_folder = md_folder
        self.ollama_url = ollama_url
        self.qdrant_host = qdrant_host
        self.qdrant_port = qdrant_port
        self.chat_model = chat_model
        self.emb_model = SentenceTransformer(embed_model)
        self.prompt = ""
        # Last raw /api/generate response; initialized to None so the
        # get_*() stats helpers below can be called before the first answer.
        self.response = None
        self.conversation_history = []
        self.load_chat_model()

    def load_chat_model(self):
        # POSTing only the model name (no prompt) asks Ollama to load the
        # model into memory, so the first real query is not slowed down.
        url = f"{self.ollama_url}/api/generate"
        body = {"model": self.chat_model}
        requests.post(url, json=body, timeout=600)

    def search_qdrant(self, query: str, top_k: int = 6, qdrant_collection="rag"):
        query_vec = self.emb_model.encode(query, show_progress_bar=False).tolist()
        url = f"http://{self.qdrant_host}:{self.qdrant_port}/collections/{qdrant_collection}/points/search"
        payload = {
            "vector": query_vec,
            "top": top_k,
            "with_payload": True,
            # "score_threshold": 0.6
        }
        resp = requests.post(url, json=payload, timeout=60)
        if resp.status_code != 200:
            raise RuntimeError(f"> Qdrant error: {resp.status_code} {resp.text}")
        results = resp.json().get("result", [])
        return results

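    # For reference, each hit returned by search_qdrant() follows Qdrant's REST
    # search schema; the payload keys shown are the ones this script reads, and
    # their presence depends on how the collection was indexed:
    #
    #   {"id": 17, "score": 0.83,
    #    "payload": {"text": "...", "filename": "...", "url": "...",
    #                "date": "...", "version": "...", "author": "..."}}
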
    def prepare_sources(self, context_docs: list):
        sources = ""
        for idx, doc in enumerate(context_docs, start=1):
            text = doc['payload'].get("text", "").strip()
            sources = f"{sources}\n<source id=\"{idx}\">\n{text}\n</source>\n"
        return sources

    def prepare_prompt(self, query: str, context_docs: list):
        sources = self.prepare_sources(context_docs)
        if os.path.exists('sys_prompt.txt'):
            with open('sys_prompt.txt', 'r', encoding='utf-8') as fp:
                prompt_template = fp.read()
            return prompt_template.replace("{{sources}}", sources).replace("{{query}}", query)
        else:
            return f"""### Your role
You are a helpful assistant that can answer questions based on the provided sources.

### Your user
User is a human who is asking a question related to the provided sources.

### Your task
Please provide an answer based solely on the provided sources and the conversation history.

### Rules
- You **MUST** respond in the SAME language as the user's query.
- If uncertain, you **MUST** ask the user for clarification.
- If there are no sources in context, you **MUST** clearly state that.
- If none of the sources are helpful, you **MUST** clearly state that.
- If you are unsure about the answer, you **MUST** clearly state that.
- If the context is unreadable or of poor quality, you **MUST** inform the user and provide the best possible answer.
- When referencing information from a source, you **MUST** cite the appropriate source(s) using their corresponding numbers.
- **Only include inline citations using [id] (e.g., [1], [2]) when the <source> tag includes an id attribute.**
- You MUST NOT add <source> or any XML/HTML tags in your response.
- You MUST NOT cite if the <source> tag does not contain an id attribute.
- Every answer SHOULD include at least one source citation.
- Only cite a source when you are explicitly referencing it.
- You may also cite multiple sources if they are all relevant to the question.
- Ensure citations are concise and directly related to the information provided.
- You CAN format your responses using Markdown.

### Example of sources list:

```
<source id="1">The sky is red in the evening and blue in the morning.</source>
<source id="2">Water is wet when the sky is red.</source>
<query>When is water wet?</query>
```
Response:
```
Water will be wet when the sky is red [2], which occurs in the evening [1].
```

### Now let's start!

```
{sources}
<query>{query}</query>
```

Respond."""

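    # If you provide your own sys_prompt.txt, it only needs the two
    # placeholders substituted above; a minimal sketch (content is up to you):
    #
    #   Answer using only these sources:
    #   {{sources}}
    #   Question: {{query}}
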
    def generate_answer(self, prompt: str):
        url = f"{self.ollama_url}/api/generate"
        body = {
            "model": self.chat_model,
            "prompt": prompt,
            # Note: /api/generate takes a single prompt; "messages" belongs to
            # the /api/chat endpoint, so the history sent here is likely ignored.
            "messages": self.conversation_history,
            "stream": False,
            # "options": {
            #     "temperature": 0.4,
            #     "top_p": 0.1,
            # },
        }
        self.response = requests.post(url, json=body, timeout=900)
        if self.response.status_code != 200:
            return f"Answer generation error: {self.response.status_code} {self.response.text}"
        return self.response.json().get("response", "").strip()

    def generate_answer_stream(self, prompt: str):
        url = f"{self.ollama_url}/api/generate"
        body = {
            "model": self.chat_model,
            "prompt": prompt,
            "messages": self.conversation_history,
            "stream": True
        }
        resp = requests.post(url, json=body, stream=True, timeout=900)
        if resp.status_code != 200:
            raise RuntimeError(f"Answer generation error: {resp.status_code} {resp.text}")
        # Each line of the streaming response is a standalone JSON object.
        for chunk in resp.iter_lines():
            if chunk:
                try:
                    decoded_chunk = chunk.decode('utf-8')
                    data = json.loads(decoded_chunk)
                    if "response" in data:
                        yield data["response"]
                    elif "error" in data:
                        print(f"Stream error: {data['error']}")
                        break
                except json.JSONDecodeError:
                    print(f"Could not decode JSON from chunk: {chunk.decode('utf-8')}")
                except Exception as e:
                    print(f"Error processing chunk: {e}")

    def get_prompt_eval_count(self):
        if not self.response:
            return 0
        return self.response.json().get("prompt_eval_count", 0)

    def get_prompt_eval_duration(self):
        if not self.response:
            return 0
        # Ollama reports durations in nanoseconds; convert to seconds.
        return self.response.json().get("prompt_eval_duration", 0) / (10 ** 9)

    def get_eval_count(self):
        if not self.response:
            return 0
        return self.response.json().get("eval_count", 0)

    def get_eval_duration(self):
        if not self.response:
            return 0
        return self.response.json().get("eval_duration", 0) / (10 ** 9)

    def get_total_duration(self):
        if not self.response:
            return 0
        return self.response.json().get("total_duration", 0) / (10 ** 9)

    def get_tps(self):
        # Generated tokens per second of pure generation time.
        eval_count = self.get_eval_count()
        eval_duration = self.get_eval_duration()
        if eval_count == 0 or eval_duration == 0:
            return 0
        return eval_count / eval_duration


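# The stats helpers above read the final JSON object Ollama returns for a
# non-streaming /api/generate call, which per the Ollama API docs looks
# roughly like this (all durations in nanoseconds):
#
#   {"model": "...", "response": "...", "done": true,
#    "total_duration": ..., "prompt_eval_count": ..., "prompt_eval_duration": ...,
#    "eval_count": ..., "eval_duration": ...}
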
def print_sources(context_docs: list):
    print("\n\nSources:")
    for idx, doc in enumerate(context_docs, start=1):
        title = doc['payload'].get("filename", None)
        url = doc['payload'].get("url", None)
        date = doc['payload'].get("date", None)
        version = doc['payload'].get("version", None)
        author = doc['payload'].get("author", None)

        if url is None:
            url = "(no web link)"
        if date is None:
            date = "(unknown)"
        if version is None:
            version = "0"
        if author is None:
            author = "(unknown)"

        print(f"{idx}. {title}")
        print(f"   {url} (v{version} {author})")
        print(f"   up to date as of {date}")

def print_v(text: str, is_verbose: bool):
    if is_verbose:
        print(text)

def print_stats(rag: RagSystem):
    # PEC/PED = prompt eval count/duration, EC/ED = eval (generation) count/duration.
    print("\n\nStatistics:")
    print(f"* Time: {rag.get_total_duration()}s")
    print(f"* TPS: {rag.get_tps()}")
    print(f"* PEC: {rag.get_prompt_eval_count()}")
    print(f"* PED: {rag.get_prompt_eval_duration()}s")
    print(f"* EC: {rag.get_eval_count()}")
    print(f"* ED: {rag.get_eval_duration()}s\n")

def main():
    import sys
    import argparse

    parser = argparse.ArgumentParser(description="RAG system built on Ollama and Qdrant")
    parser.add_argument("--query", type=str, help="Query for the RAG system")
    parser.add_argument("--interactive", default=False, action=argparse.BooleanOptionalAction, help="Enter interactive dialogue mode")
    parser.add_argument("--show-prompt", default=False, action=argparse.BooleanOptionalAction, help="Show the full prompt before processing the query")
    parser.add_argument("--qdrant-host", default="localhost", help="Qdrant host")
    parser.add_argument("--qdrant-port", type=int, default=6333, help="Qdrant port")
    parser.add_argument("--qdrant-collection", type=str, default="rag", help="Name of the collection to search for documents")
    parser.add_argument("--ollama-url", default="http://localhost:11434", help="Ollama API URL")
    parser.add_argument("--emb-model", default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", help="Embedding model")
    parser.add_argument("--chat-model", default="phi4-mini:3.8b", help="Ollama generation model")
    parser.add_argument("--topk", type=int, default=6, help="Number of documents to retrieve")
    parser.add_argument("--verbose", default=False, action=argparse.BooleanOptionalAction, help="Print intermediate service messages")
    parser.add_argument("--show-stats", default=False, action=argparse.BooleanOptionalAction, help="Print statistics about the answer (does not work with --stream)")
    parser.add_argument("--stream", default=False, action=argparse.BooleanOptionalAction, help="Stream the answer as it is generated")
    args = parser.parse_args()

    if not args.query and not args.interactive:
        print("Error: specify a query (--query) and/or use interactive mode (--interactive)")
        sys.exit(1)

    print_v(f"Ollama address: {args.ollama_url}", args.verbose)
    print_v(f"Qdrant address: {args.qdrant_host}:{args.qdrant_port}", args.verbose)
    print_v(f"Embedding model: {args.emb_model}", args.verbose)
    print_v(f"Chat model: {args.chat_model}", args.verbose)
    print_v(f"Documents to retrieve: {args.topk}", args.verbose)
    print_v(f"Collection to search: {args.qdrant_collection}", args.verbose)
    if os.path.exists('sys_prompt.txt'):
        print_v("sys_prompt.txt will be used!", args.verbose)

    print_v("\nInitializing models for the first time...", args.verbose)
    rag = RagSystem(
        ollama_url=args.ollama_url,
        qdrant_host=args.qdrant_host,
        qdrant_port=args.qdrant_port,
        embed_model=args.emb_model,
        chat_model=args.chat_model
    )
    print_v("Models loaded. If an answer is poor, rephrase the query, set --chat-model, or improve the RAG source data", args.verbose)

    query = None
    if args.interactive:
        print_v("\nINTERACTIVE MODE", args.verbose)
        print_v("Enter a query (or 'exit' to quit)\n", args.verbose)

    if args.query:
        query = args.query.strip()
        print(f">>> {query}")

    while True:
        try:
            if not query or query == "":
                query = input(">>> ").strip()

            if not query or query == "":
                continue

            if query.lower() == "help":
                print("<<< Interactive mode commands:")
                print("save -- save the dialogue to a file")
                print("stats -- print statistics for the last answer")
                print("exit -- quit\n")
                query = None
                continue

            # The stats command promised by the help text above.
            if query.strip().lower() == "stats":
                print_stats(rag)
                query = None
                continue

            if query.strip().lower() == "save":
                import datetime
                timestamp = int(time.time())
                # Use an explicit UTC timestamp so the trailing 'Z' is accurate.
                dt = datetime.datetime.fromtimestamp(timestamp, datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
                filename = f"chats/chat-{timestamp}.md"

                markdown_content = f"# Dialogue history from {dt}\n\n"
                markdown_content += "## Dialogue parameters\n"
                markdown_content += f"```\nargs = {args}\n```\n"
                markdown_content += f"```\nemb_model = {rag.emb_model}\n```\n"

                for entry in rag.conversation_history:
                    if entry['role'] == 'user':
                        markdown_content += "## User\n\n"
                    elif entry['role'] == 'assistant':
                        markdown_content += "## Model\n\n"
                        docs = rag.prepare_sources(entry['docs']).replace("```", "")
                        markdown_content += f"```\n{docs}\n```\n\n"
                    markdown_content += f"{entry['content']}\n\n"

                os.makedirs('chats', exist_ok=True)
                with open(filename, 'w', encoding='utf-8') as fp:
                    fp.write(markdown_content)

                print(f"<<< Dialogue saved to file: {filename}\n")
                query = None
                continue

            if query.strip().lower() == "exit":
                print_v("\n*** Shutting down", args.verbose)
                break

            print_v("\nSearching for relevant documents...", args.verbose)
            context_docs = rag.search_qdrant(query, top_k=args.topk, qdrant_collection=args.qdrant_collection)
            if not context_docs:
                print("<<< No relevant documents found")
                if args.interactive:
                    query = None
                    continue
                else:
                    break

            print_v(f"Found {len(context_docs)} relevant documents", args.verbose)
            # print_sources(context_docs)

            prompt = rag.prepare_prompt(query=query, context_docs=context_docs)
            if args.show_prompt:
                print("\nFull system prompt: --------------------------")
                print(f"{prompt}\n---------------------------------------------------")

            print_v("\nGenerating an answer...\n", args.verbose)

            if args.stream:
                answer = "\n<<< "
                print(answer, end='', flush=True)
                try:
                    for message_part in rag.generate_answer_stream(prompt):
                        answer += message_part
                        print(message_part, end='', flush=True)
                except RuntimeError as e:
                    answer = str(e)
                print(f"\n{answer}\n===================================================\n")
            else:
                answer = rag.generate_answer(prompt)
                print(f"<<< {answer}\n")

            print_sources(context_docs)
            if args.show_stats and not args.stream:
                print_stats(rag)

            rag.conversation_history.append({
                "role": "user",
                "content": query,
            })

            rag.conversation_history.append({
                "role": "assistant",
                "docs": context_docs,
                "content": answer,
            })

            if args.interactive:
                query = None
            else:
                break

        except KeyboardInterrupt:
            print("\n*** Shutting down")
            break

        except Exception as e:
            print(f"Error: {e}")
            break


if __name__ == "__main__":
    main()