Исправлено поведение rag.py --show-prompt
This commit is contained in:
@@ -188,7 +188,7 @@ python3 rag.py --help
|
||||
|
||||
Если хочется уточнить роль генеративной модели, можно создать файл `sys_prompt.txt` и прописать туда всё необходимое, учитывая следующие правила:
|
||||
|
||||
1. Шаблон `{{context}}` будет заменён на цитаты документов, найденные в qdrant
|
||||
1. Шаблон `{{sources}}` будет заменён на цитаты документов, найденные в qdrant
|
||||
2. Шаблон `{{query}}` будет заменён на запрос пользователя
|
||||
3. Если этих двух шаблонов не будет в промпте, результаты будут непредсказуемыми
|
||||
4. Каждая цитата в списке цитат формируется в формате:
|
||||
|
||||
106
rag/rag.py
106
rag/rag.py
@@ -19,6 +19,7 @@ class LocalRAGSystem:
|
||||
self.embed_model = embed_model
|
||||
self.chat_model = chat_model
|
||||
self.emb_model = SentenceTransformer(embed_model)
|
||||
self.prompt = ""
|
||||
|
||||
def get_embedding(self, text: str):
|
||||
return self.emb_model.encode(text, show_progress_bar=False).tolist()
|
||||
@@ -37,31 +38,32 @@ class LocalRAGSystem:
|
||||
results = resp.json().get("result", [])
|
||||
return results
|
||||
|
||||
def generate_answer(self, query: str, context_docs: list):
|
||||
query = query.strip()
|
||||
context = f""
|
||||
sources = f"\nИсточники:\n"
|
||||
def generate_answer(self, prompt: str):
|
||||
url = f"{self.ollama_url}/api/generate"
|
||||
body = {
|
||||
"model": self.chat_model,
|
||||
"prompt": prompt,
|
||||
"stream": False
|
||||
}
|
||||
resp = requests.post(url, json=body, timeout=600)
|
||||
if resp.status_code != 200:
|
||||
return f"Ошибка генерации ответа: {resp.status_code} {resp.text}"
|
||||
return resp.json().get("response", "").strip()
|
||||
|
||||
def prepare_sources(self, context_docs: list):
|
||||
sources = ""
|
||||
for idx, doc in enumerate(context_docs, start=1):
|
||||
text = doc['payload'].get("text", "").strip()
|
||||
filename = doc['payload'].get("filename", None)
|
||||
url = doc['payload'].get("url", None)
|
||||
if filename:
|
||||
title = filename
|
||||
else:
|
||||
snippet = text[:40].replace("\n", " ").strip()
|
||||
if len(text) > 40:
|
||||
snippet += "..."
|
||||
title = snippet if snippet else "Empty text"
|
||||
if url is None:
|
||||
url = ""
|
||||
context = f"{context}\n--- Source [{idx}] ---\n{text}\n"
|
||||
sources = f"{sources}\n{idx}. {title}\n {url}"
|
||||
sources = f"{sources}\n--- Source [{idx}] ---\n{text}\n"
|
||||
return sources.strip()
|
||||
|
||||
def prepare_prompt(self, query: str, context_docs: list):
|
||||
sources = self.prepare_sources(context_docs)
|
||||
if os.path.exists('sys_prompt.txt'):
|
||||
with open('sys_prompt.txt', 'r') as fp:
|
||||
prompt = fp.read().replace("{{context}}", context).replace("{{query}}", query)
|
||||
return fp.read().replace("{{sources}}", sources).replace("{{query}}", query)
|
||||
else:
|
||||
prompt = f"""
|
||||
return f"""
|
||||
Please provide an answer based solely on the provided sources.
|
||||
It is prohibited to generate an answer based on your pretrained data.
|
||||
If uncertain, ask the user for clarification.
|
||||
@@ -91,17 +93,14 @@ class LocalRAGSystem:
|
||||
Your answer:
|
||||
"""
|
||||
|
||||
url = f"{self.ollama_url}/api/generate"
|
||||
body = {
|
||||
"model": self.chat_model,
|
||||
"prompt": prompt,
|
||||
"stream": False
|
||||
}
|
||||
resp = requests.post(url, json=body, timeout=600)
|
||||
if resp.status_code != 200:
|
||||
return f"Ошибка генерации ответа: {resp.status_code} {resp.text}"
|
||||
return resp.json().get("response", "").strip() + f"\n{sources}"
|
||||
|
||||
def print_sources(context_docs: list):
|
||||
for idx, doc in enumerate(context_docs, start=1):
|
||||
filename = doc['payload'].get("filename", None)
|
||||
url = doc['payload'].get("url", None)
|
||||
title = filename
|
||||
if url is None:
|
||||
url = "(нет веб-ссылки)"
|
||||
print(f"{idx}. {title}\n {url}")
|
||||
|
||||
def main():
|
||||
import sys
|
||||
@@ -129,14 +128,7 @@ def main():
|
||||
print(f"Модель чата: {args.chat_model}")
|
||||
print(f"Документов для поиска: {args.topk}")
|
||||
if os.path.exists('sys_prompt.txt'):
|
||||
print("Будет загружен системный промпт из sys_prompt.txt!")
|
||||
|
||||
if args.interactive:
|
||||
print("\nИНТЕРАКТИВНЫЙ РЕЖИМ")
|
||||
print("Можете вводить запрос (или 'exit' для выхода)\n")
|
||||
question = input(">>> ").strip()
|
||||
else:
|
||||
question = args.query.strip()
|
||||
print("Будет использоваться sys_prompt.txt!")
|
||||
|
||||
print("\nПервая инициализация моделей...")
|
||||
rag = LocalRAGSystem(
|
||||
@@ -146,45 +138,57 @@ def main():
|
||||
embed_model=args.emb_model,
|
||||
chat_model=args.chat_model
|
||||
)
|
||||
|
||||
print(f"Модели загружены. Если ответ плохой, переформулируйте запрос, укажите --chat-model или улучшите исходные данные RAG")
|
||||
|
||||
if args.interactive:
|
||||
print("\nИНТЕРАКТИВНЫЙ РЕЖИМ")
|
||||
print("Можете вводить запрос (или 'exit' для выхода)\n")
|
||||
|
||||
if args.query:
|
||||
query = args.query.strip()
|
||||
print(f">>> {query}")
|
||||
else:
|
||||
query = input(">>> ").strip()
|
||||
|
||||
while True:
|
||||
try:
|
||||
if not question or question == "":
|
||||
question = input(">>> ").strip()
|
||||
if not query or query == "":
|
||||
query = input(">>> ").strip()
|
||||
|
||||
if not question or question == "":
|
||||
if not query or query == "":
|
||||
continue
|
||||
|
||||
if question.lower() == "exit":
|
||||
if query.lower() == "exit":
|
||||
print("\n*** Завершение работы")
|
||||
break
|
||||
|
||||
print("\nПоиск релевантных документов...")
|
||||
results = rag.search_qdrant(question, top_k=args.topk)
|
||||
if not results:
|
||||
context_docs = rag.search_qdrant(query, top_k=args.topk)
|
||||
if not context_docs:
|
||||
print("Релевантные документы не найдены.")
|
||||
if args.interactive:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
print(f"Найдено {len(results)} релевантных документов")
|
||||
print(f"Найдено {len(context_docs)} релевантных документов:")
|
||||
print_sources(context_docs)
|
||||
|
||||
prompt = rag.prepare_prompt(query=query, context_docs=context_docs)
|
||||
if args.show_prompt:
|
||||
print("\nПолный системный промпт:\n")
|
||||
print(rag.prompt)
|
||||
print("\nПолный системный промпт: --------------------------\n")
|
||||
print(f"{prompt}\n---------------------------------------------------\n")
|
||||
|
||||
print("Генерация ответа...")
|
||||
answer = rag.generate_answer(question, results)
|
||||
print(f"\n<<< {answer}\n---------------------------------------------------\n")
|
||||
question = None
|
||||
answer = rag.generate_answer(prompt)
|
||||
print(f"\n<<< {answer}\n===================================================\n")
|
||||
query = None
|
||||
except KeyboardInterrupt:
|
||||
print("\n*** Завершение работы")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Ошибка: {e}")
|
||||
break
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user