1
0

Исправлено поведение rag.py --show-prompt

This commit is contained in:
2025-08-25 21:00:09 +08:00
parent 22d23c1ca0
commit 3b15a6a19e
2 changed files with 56 additions and 52 deletions

View File

@@ -188,7 +188,7 @@ python3 rag.py --help
Если хочется уточнить роль генеративной модели, можно создать файл `sys_prompt.txt` и прописать туда всё необходимое, учитывая следующие правила: Если хочется уточнить роль генеративной модели, можно создать файл `sys_prompt.txt` и прописать туда всё необходимое, учитывая следующие правила:
1. Шаблон `{{context}}` будет заменён на цитаты документов, найденные в qdrant 1. Шаблон `{{sources}}` будет заменён на цитаты документов, найденные в qdrant
2. Шаблон `{{query}}` будет заменён на запрос пользователя 2. Шаблон `{{query}}` будет заменён на запрос пользователя
3. Если этих двух шаблонов не будет в промпте, результаты будут непредсказуемыми 3. Если этих двух шаблонов не будет в промпте, результаты будут непредсказуемыми
4. Каждая цитата в списке цитат формируется в формате: 4. Каждая цитата в списке цитат формируется в формате:

View File

@@ -19,6 +19,7 @@ class LocalRAGSystem:
self.embed_model = embed_model self.embed_model = embed_model
self.chat_model = chat_model self.chat_model = chat_model
self.emb_model = SentenceTransformer(embed_model) self.emb_model = SentenceTransformer(embed_model)
self.prompt = ""
def get_embedding(self, text: str): def get_embedding(self, text: str):
return self.emb_model.encode(text, show_progress_bar=False).tolist() return self.emb_model.encode(text, show_progress_bar=False).tolist()
@@ -37,31 +38,32 @@ class LocalRAGSystem:
results = resp.json().get("result", []) results = resp.json().get("result", [])
return results return results
def generate_answer(self, query: str, context_docs: list): def generate_answer(self, prompt: str):
query = query.strip() url = f"{self.ollama_url}/api/generate"
context = f"" body = {
sources = f"\nИсточники:\n" "model": self.chat_model,
"prompt": prompt,
"stream": False
}
resp = requests.post(url, json=body, timeout=600)
if resp.status_code != 200:
return f"Ошибка генерации ответа: {resp.status_code} {resp.text}"
return resp.json().get("response", "").strip()
def prepare_sources(self, context_docs: list):
sources = ""
for idx, doc in enumerate(context_docs, start=1): for idx, doc in enumerate(context_docs, start=1):
text = doc['payload'].get("text", "").strip() text = doc['payload'].get("text", "").strip()
filename = doc['payload'].get("filename", None) sources = f"{sources}\n--- Source [{idx}] ---\n{text}\n"
url = doc['payload'].get("url", None) return sources.strip()
if filename:
title = filename
else:
snippet = text[:40].replace("\n", " ").strip()
if len(text) > 40:
snippet += "..."
title = snippet if snippet else "Empty text"
if url is None:
url = ""
context = f"{context}\n--- Source [{idx}] ---\n{text}\n"
sources = f"{sources}\n{idx}. {title}\n {url}"
def prepare_prompt(self, query: str, context_docs: list):
sources = self.prepare_sources(context_docs)
if os.path.exists('sys_prompt.txt'): if os.path.exists('sys_prompt.txt'):
with open('sys_prompt.txt', 'r') as fp: with open('sys_prompt.txt', 'r') as fp:
prompt = fp.read().replace("{{context}}", context).replace("{{query}}", query) return fp.read().replace("{{sources}}", sources).replace("{{query}}", query)
else: else:
prompt = f""" return f"""
Please provide an answer based solely on the provided sources. Please provide an answer based solely on the provided sources.
It is prohibited to generate an answer based on your pretrained data. It is prohibited to generate an answer based on your pretrained data.
If uncertain, ask the user for clarification. If uncertain, ask the user for clarification.
@@ -91,17 +93,14 @@ class LocalRAGSystem:
Your answer: Your answer:
""" """
url = f"{self.ollama_url}/api/generate" def print_sources(context_docs: list):
body = { for idx, doc in enumerate(context_docs, start=1):
"model": self.chat_model, filename = doc['payload'].get("filename", None)
"prompt": prompt, url = doc['payload'].get("url", None)
"stream": False title = filename
} if url is None:
resp = requests.post(url, json=body, timeout=600) url = "(нет веб-ссылки)"
if resp.status_code != 200: print(f"{idx}. {title}\n {url}")
return f"Ошибка генерации ответа: {resp.status_code} {resp.text}"
return resp.json().get("response", "").strip() + f"\n{sources}"
def main(): def main():
import sys import sys
@@ -129,14 +128,7 @@ def main():
print(f"Модель чата: {args.chat_model}") print(f"Модель чата: {args.chat_model}")
print(f"Документов для поиска: {args.topk}") print(f"Документов для поиска: {args.topk}")
if os.path.exists('sys_prompt.txt'): if os.path.exists('sys_prompt.txt'):
print("Будет загружен системный промпт из sys_prompt.txt!") print("Будет использоваться sys_prompt.txt!")
if args.interactive:
print("\nИНТЕРАКТИВНЫЙ РЕЖИМ")
print("Можете вводить запрос (или 'exit' для выхода)\n")
question = input(">>> ").strip()
else:
question = args.query.strip()
print("\nПервая инициализация моделей...") print("\nПервая инициализация моделей...")
rag = LocalRAGSystem( rag = LocalRAGSystem(
@@ -146,45 +138,57 @@ def main():
embed_model=args.emb_model, embed_model=args.emb_model,
chat_model=args.chat_model chat_model=args.chat_model
) )
print(f"Модели загружены. Если ответ плохой, переформулируйте запрос, укажите --chat-model или улучшите исходные данные RAG") print(f"Модели загружены. Если ответ плохой, переформулируйте запрос, укажите --chat-model или улучшите исходные данные RAG")
if args.interactive:
print("\nИНТЕРАКТИВНЫЙ РЕЖИМ")
print("Можете вводить запрос (или 'exit' для выхода)\n")
if args.query:
query = args.query.strip()
print(f">>> {query}")
else:
query = input(">>> ").strip()
while True: while True:
try: try:
if not question or question == "": if not query or query == "":
question = input(">>> ").strip() query = input(">>> ").strip()
if not question or question == "": if not query or query == "":
continue continue
if question.lower() == "exit": if query.lower() == "exit":
print("\n*** Завершение работы") print("\n*** Завершение работы")
break break
print("\nПоиск релевантных документов...") print("\nПоиск релевантных документов...")
results = rag.search_qdrant(question, top_k=args.topk) context_docs = rag.search_qdrant(query, top_k=args.topk)
if not results: if not context_docs:
print("Релевантные документы не найдены.") print("Релевантные документы не найдены.")
if args.interactive: if args.interactive:
continue continue
else: else:
break break
print(f"Найдено {len(results)} релевантных документов") print(f"Найдено {len(context_docs)} релевантных документов:")
print_sources(context_docs)
prompt = rag.prepare_prompt(query=query, context_docs=context_docs)
if args.show_prompt: if args.show_prompt:
print("\nПолный системный промпт:\n") print("\nПолный системный промпт: --------------------------\n")
print(rag.prompt) print(f"{prompt}\n---------------------------------------------------\n")
print("Генерация ответа...") print("Генерация ответа...")
answer = rag.generate_answer(question, results) answer = rag.generate_answer(prompt)
print(f"\n<<< {answer}\n---------------------------------------------------\n") print(f"\n<<< {answer}\n===================================================\n")
question = None query = None
except KeyboardInterrupt: except KeyboardInterrupt:
print("\n*** Завершение работы") print("\n*** Завершение работы")
break break
except Exception as e: except Exception as e:
print(f"Ошибка: {e}") print(f"Ошибка: {e}")
break
if __name__ == "__main__": if __name__ == "__main__":
main() main()