ollama/rag/rag.py
AnthonyAxenov f3672e6ffd Many small improvements
- renamed input_md => data
- the date, version and author of conf-page changes are now stored in the index
- this info is shown in the sources list
- statistics for the last answer can be printed
- the qdrant collection name can be specified
- minor wording tweaks
2025-08-29 08:54:43 +08:00

import os
import requests
import json
import time
from sentence_transformers import SentenceTransformer


class RagSystem:
    def __init__(self,
                 md_folder: str = "data",
                 ollama_url: str = "http://localhost:11434",
                 qdrant_host: str = "localhost",
                 qdrant_port: int = 6333,
                 embed_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
                 chat_model: str = "phi4-mini:3.8b"):
        self.md_folder = md_folder
        self.ollama_url = ollama_url
        self.qdrant_host = qdrant_host
        self.qdrant_port = qdrant_port
        self.chat_model = chat_model
        self.emb_model = SentenceTransformer(embed_model)
        self.prompt = ""
        self.conversation_history = []
        # initialized here so the get_* stat helpers work before the first request
        self.response = None
        self.load_chat_model()

    def load_chat_model(self):
        # POSTing just the model name (no prompt) asks Ollama to load the model
        # into memory, so the first real query does not pay the cold-start cost
        url = f"{self.ollama_url}/api/generate"
        body = {"model": self.chat_model}
        requests.post(url, json=body, timeout=600)

    def search_qdrant(self, query: str, top_k: int = 6, qdrant_collection="rag"):
        query_vec = self.emb_model.encode(query, show_progress_bar=False).tolist()
        url = f"http://{self.qdrant_host}:{self.qdrant_port}/collections/{qdrant_collection}/points/search"
        payload = {
            "vector": query_vec,
            "top": top_k,
            "with_payload": True,
            # "score_threshold": 0.6
        }
        resp = requests.post(url, json=payload)
        if resp.status_code != 200:
            raise RuntimeError(f"> Ошибка qdrant: {resp.status_code} {resp.text}")
        results = resp.json().get("result", [])
        return results
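
    # The fields consumed elsewhere in this file suggest each hit has roughly
    # this shape (assumed, not part of the original source; id/score are
    # standard qdrant fields, the payload keys depend on how the collection
    # was indexed):
    #   {"id": 17, "score": 0.83,
    #    "payload": {"text": "...", "filename": "...", "url": "...",
    #                "date": "...", "version": "...", "author": "..."}}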

    def prepare_sources(self, context_docs: list):
        sources = ""
        for idx, doc in enumerate(context_docs, start=1):
            text = doc['payload'].get("text", "").strip()
            sources = f"{sources}\n<source id=\"{idx}\">\n{text}\n</source>\n"
        return sources
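
    # For two hits the string above renders as (illustrative):
    #
    #   <source id="1">
    #   first chunk text
    #   </source>
    #
    #   <source id="2">
    #   second chunk text
    #   </source>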

    def prepare_prompt(self, query: str, context_docs: list):
        sources = self.prepare_sources(context_docs)
        if os.path.exists('sys_prompt.txt'):
            with open('sys_prompt.txt', 'r', encoding='utf-8') as fp:
                prompt_template = fp.read()
            return prompt_template.replace("{{sources}}", sources).replace("{{query}}", query)
        else:
            return f"""### Your role
You are a helpful assistant that can answer questions based on the provided sources.
### Your user
User is a human who is asking a question related to the provided sources.
### Your task
Please provide an answer based solely on the provided sources and the conversation history.
### Rules
- You **MUST** respond in the SAME language as the user's query.
- If uncertain, you **MUST** ask the user for clarification.
- If there are no sources in context, you **MUST** clearly state that.
- If none of the sources are helpful, you **MUST** clearly state that.
- If you are unsure about the answer, you **MUST** clearly state that.
- If the context is unreadable or of poor quality, you **MUST** inform the user and provide the best possible answer.
- When referencing information from a source, you **MUST** cite the appropriate source(s) using their corresponding numbers.
- **Only include inline citations using [id] (e.g., [1], [2]) when the <source> tag includes an id attribute.**
- You **MUST NEVER** add <source> or any XML/HTML tags in your response.
- You **MUST NEVER** cite a source whose <source> tag does not contain an id attribute.
- Every answer SHOULD include at least one source citation.
- Only cite a source when you are explicitly referencing it.
- You may also cite multiple sources if they are all relevant to the question.
- Ensure citations are concise and directly related to the information provided.
- You CAN format your responses using Markdown.
### Example of sources list:
```
<source id="1">The sky is red in the evening and blue in the morning.</source>
<source id="2">Water is wet when the sky is red.</source>
<query>When is water wet?</query>
```
Response:
```
Water will be wet when the sky is red [2], which occurs in the evening [1].
```
### Now let's start!
```
{sources}
<query>{query}</query>
```
Respond."""

    def generate_answer(self, prompt: str):
        url = f"{self.ollama_url}/api/generate"
        body = {
            "model": self.chat_model,
            "prompt": prompt,
            # note: "messages" is an /api/chat field; /api/generate has no such
            # parameter, so the conversation history likely has no effect here
            "messages": self.conversation_history,
            "stream": False,
            # "options": {
            #     "temperature": 0.4,
            #     "top_p": 0.1,
            # },
        }
        self.response = requests.post(url, json=body, timeout=900)
        if self.response.status_code != 200:
            return f"Ошибка генерации ответа: {self.response.status_code} {self.response.text}"
        return self.response.json().get("response", "").strip()

    def generate_answer_stream(self, prompt: str):
        url = f"{self.ollama_url}/api/generate"
        body = {
            "model": self.chat_model,
            "prompt": prompt,
            "messages": self.conversation_history,
            "stream": True
        }
        resp = requests.post(url, json=body, stream=True, timeout=900)
        if resp.status_code != 200:
            raise RuntimeError(f"Ошибка генерации ответа: {resp.status_code} {resp.text}")
        full_answer = ""
        for chunk in resp.iter_lines():
            if chunk:
                try:
                    decoded_chunk = chunk.decode('utf-8')
                    data = json.loads(decoded_chunk)
                    if "response" in data:
                        yield data["response"]
                        full_answer += data["response"]
                    elif "error" in data:
                        print(f"Stream error: {data['error']}")
                        break
                except json.JSONDecodeError:
                    print(f"Could not decode JSON from chunk: {chunk.decode('utf-8')}")
                except Exception as e:
                    print(f"Error processing chunk: {e}")

    def get_prompt_eval_count(self):
        if not self.response:
            return 0
        return self.response.json().get("prompt_eval_count", 0)

    def get_prompt_eval_duration(self):
        # Ollama reports durations in nanoseconds; convert to seconds
        if not self.response:
            return 0
        return self.response.json().get("prompt_eval_duration", 0) / (10 ** 9)

    def get_eval_count(self):
        if not self.response:
            return 0
        return self.response.json().get("eval_count", 0)

    def get_eval_duration(self):
        if not self.response:
            return 0
        return self.response.json().get("eval_duration", 0) / (10 ** 9)

    def get_total_duration(self):
        if not self.response:
            return 0
        return self.response.json().get("total_duration", 0) / (10 ** 9)

    def get_tps(self):
        # generation speed in tokens per second
        eval_count = self.get_eval_count()
        eval_duration = self.get_eval_duration()
        if eval_count == 0 or eval_duration == 0:
            return 0
        return eval_count / eval_duration


def print_sources(context_docs: list):
    print("\n\nИсточники:")
    for idx, doc in enumerate(context_docs, start=1):
        title = doc['payload'].get("filename", None)
        url = doc['payload'].get("url", None)
        date = doc['payload'].get("date", None)
        version = doc['payload'].get("version", None)
        author = doc['payload'].get("author", None)
        if url is None:
            url = "(нет веб-ссылки)"
        if date is None:
            date = "(неизвестно)"
        if version is None:
            version = "0"
        if author is None:
            author = "(неизвестен)"
        print(f"{idx}. {title}")
        print(f"   {url} (v{version} {author})")
        print(f"   актуальность на {date}")


def print_v(text: str, is_verbose: bool):
    if is_verbose:
        print(text)


def print_stats(rag: RagSystem):
    # PEC/PED = prompt eval count/duration, EC/ED = generation eval count/duration
    print("\n\nСтатистика:")
    print(f"* Time: {rag.get_total_duration()}s")
    print(f"* TPS: {rag.get_tps()}")
    print(f"* PEC: {rag.get_prompt_eval_count()}")
    print(f"* PED: {rag.get_prompt_eval_duration()}s")
    print(f"* EC: {rag.get_eval_count()}")
    print(f"* ED: {rag.get_eval_duration()}s\n")


def main():
    import sys
    import argparse
    parser = argparse.ArgumentParser(description="RAG-система с использованием Ollama и Qdrant")
    parser.add_argument("--query", type=str, help="Запрос к RAG")
    parser.add_argument("--interactive", default=False, action=argparse.BooleanOptionalAction, help="Перейти в интерактивный режим диалога")
    parser.add_argument("--show-prompt", default=False, action=argparse.BooleanOptionalAction, help="Показать полный промпт перед обработкой запроса")
    parser.add_argument("--qdrant-host", default="localhost", help="Qdrant host")
    parser.add_argument("--qdrant-port", type=int, default=6333, help="Qdrant port")
    parser.add_argument("--qdrant-collection", type=str, default="rag", help="Название коллекции для поиска документов")
    parser.add_argument("--ollama-url", default="http://localhost:11434", help="Ollama API URL")
    parser.add_argument("--emb-model", default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", help="Модель эмбеддинга")
    parser.add_argument("--chat-model", default="phi4-mini:3.8b", help="Модель генерации Ollama")
    parser.add_argument("--topk", type=int, default=6, help="Количество документов для поиска")
    parser.add_argument("--verbose", default=False, action=argparse.BooleanOptionalAction, help="Выводить промежуточные служебные сообщения")
    parser.add_argument("--show-stats", default=False, action=argparse.BooleanOptionalAction, help="Выводить статистику об ответе (не работает с --stream)")
    parser.add_argument("--stream", default=False, action=argparse.BooleanOptionalAction, help="Выводить ответ в потоковом режиме")
    args = parser.parse_args()
    if not args.query and not args.interactive:
        print("Ошибка: укажите запрос (--query) и/или используйте интерактивный режим (--interactive)")
        sys.exit(1)
print_v(f"Адрес ollama: {args.ollama_url}", args.verbose)
print_v(f"Адрес qdrant: {args.qdrant_host}:{args.qdrant_port}", args.verbose)
print_v(f"Модель эмбеддинга: {args.emb_model}", args.verbose)
print_v(f"Модель чата: {args.chat_model}", args.verbose)
print_v(f"Документов для поиска: {args.topk}", args.verbose)
print_v(f"Коллекция для поиска: {args.qdrant_collection}", args.verbose)
if os.path.exists('sys_prompt.txt'):
print_v("Будет использоваться sys_prompt.txt!", args.verbose)
print_v("\nПервая инициализация моделей...", args.verbose)
rag = RagSystem(
ollama_url=args.ollama_url,
qdrant_host=args.qdrant_host,
qdrant_port=args.qdrant_port,
embed_model=args.emb_model,
chat_model=args.chat_model
)
print_v(f"Модели загружены. Если ответ плохой, переформулируйте запрос, укажите --chat-model или улучшите исходные данные RAG", args.verbose)
    query = None
    if args.interactive:
        print_v("\nИНТЕРАКТИВНЫЙ РЕЖИМ", args.verbose)
        print_v("Можете вводить запрос (или 'exit' для выхода)\n", args.verbose)
    if args.query:
        query = args.query.strip()
        print(f">>> {query}")
    while True:
        try:
            if not query:
                query = input(">>> ").strip()
                if not query:
                    continue
            if query.lower() == "help":
                print("<<< Команды интерактивного режима:")
                print("save  -- сохранить диалог в файл")
                print("stats -- вывести статистику последнего ответа")
                print("exit  -- выход\n")
                query = None
                continue
            if query.strip().lower() == "stats":
                # the command promised by "help" above: stats of the last answer
                print_stats(rag)
                query = None
                continue
            if query.strip().lower() == "save":
                import datetime
                timestamp = int(time.time())
                dt = datetime.datetime.fromtimestamp(timestamp).strftime('%Y-%m-%dT%H:%M:%SZ')
                filename = f"chats/chat-{timestamp}.md"
                markdown_content = f"# История диалога от {dt}\n\n"
                markdown_content += "## Параметры диалога\n"
                markdown_content += f"```\nargs = {args}\n```\n"
                markdown_content += f"```\nemb_model = {rag.emb_model}\n```\n"
                for entry in rag.conversation_history:
                    if entry['role'] == 'user':
                        markdown_content += "## Пользователь\n\n"
                    elif entry['role'] == 'assistant':
                        markdown_content += "## Модель\n\n"
                        # only assistant entries carry the retrieved docs
                        docs = rag.prepare_sources(entry['docs']).replace("```", "")
                        markdown_content += f"```\n{docs}\n```\n\n"
                    markdown_content += f"{entry['content']}\n\n"
                os.makedirs('chats', exist_ok=True)
                with open(filename, 'w', encoding='utf-8') as fp:
                    fp.write(markdown_content)
                print(f"<<< Диалог сохранён в файл: {filename}\n")
                query = None
                continue
            if query.strip().lower() == "exit":
                print_v("\n*** Завершение работы", args.verbose)
                break
print_v("\nПоиск релевантных документов...", args.verbose)
context_docs = rag.search_qdrant(query, top_k=args.topk, qdrant_collection=args.qdrant_collection)
if not context_docs:
print("<<< Релевантные документы не найдены")
if args.interactive:
query = None
continue
else:
break
print_v(f"Найдено {len(context_docs)} релевантных документов", args.verbose)
# print_sources(context_docs)
prompt = rag.prepare_prompt(query=query, context_docs=context_docs)
if args.show_prompt:
print("\nПолный системный промпт: --------------------------")
print(f"{prompt}\n---------------------------------------------------")
print_v("\nГенерация ответа...\n", args.verbose)
if args.stream:
answer = "\n<<< "
print(answer, end='', flush=True)
try:
for message_part in rag.generate_answer_stream(prompt):
answer += message_part
print(message_part, end='', flush=True)
except RuntimeError as e:
answer = str(e)
print(f"\n{answer}\n===================================================\n")
else:
answer = rag.generate_answer(prompt)
print(f"<<< {answer}\n")
print_sources(context_docs)
if args.show_stats and not args.stream:
print_stats(rag)
            rag.conversation_history.append({
                "role": "user",
                "content": query,
            })
            rag.conversation_history.append({
                "role": "assistant",
                "docs": context_docs,
                "content": answer,
            })
            if args.interactive:
                query = None
            else:
                break
        except KeyboardInterrupt:
            print("\n*** Завершение работы")
            break
        except Exception as e:
            print(f"Ошибка: {e}")
            break


if __name__ == "__main__":
    main()
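
# Example invocations (illustrative values; adjust hosts/models to your setup):
#   python rag.py --query "Как настроить доступ?" --show-stats --verbose
#   python rag.py --interactive --stream --qdrant-collection rag \
#       --chat-model phi4-mini:3.8b --topk 6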