Исправлено поведение rag.py --show-prompt
This commit is contained in:
@@ -188,7 +188,7 @@ python3 rag.py --help
|
|||||||
|
|
||||||
Если хочется уточнить роль генеративной модели, можно создать файл `sys_prompt.txt` и прописать туда всё необходимое, учитывая следующие правила:
|
Если хочется уточнить роль генеративной модели, можно создать файл `sys_prompt.txt` и прописать туда всё необходимое, учитывая следующие правила:
|
||||||
|
|
||||||
1. Шаблон `{{context}}` будет заменён на цитаты документов, найденные в qdrant
|
1. Шаблон `{{sources}}` будет заменён на цитаты документов, найденные в qdrant
|
||||||
2. Шаблон `{{query}}` будет заменён на запрос пользователя
|
2. Шаблон `{{query}}` будет заменён на запрос пользователя
|
||||||
3. Если этих двух шаблонов не будет в промпте, результаты будут непредсказуемыми
|
3. Если этих двух шаблонов не будет в промпте, результаты будут непредсказуемыми
|
||||||
4. Каждая цитата в списке цитат формируется в формате:
|
4. Каждая цитата в списке цитат формируется в формате:
|
||||||
|
|||||||
106
rag/rag.py
106
rag/rag.py
@@ -19,6 +19,7 @@ class LocalRAGSystem:
|
|||||||
self.embed_model = embed_model
|
self.embed_model = embed_model
|
||||||
self.chat_model = chat_model
|
self.chat_model = chat_model
|
||||||
self.emb_model = SentenceTransformer(embed_model)
|
self.emb_model = SentenceTransformer(embed_model)
|
||||||
|
self.prompt = ""
|
||||||
|
|
||||||
def get_embedding(self, text: str):
|
def get_embedding(self, text: str):
|
||||||
return self.emb_model.encode(text, show_progress_bar=False).tolist()
|
return self.emb_model.encode(text, show_progress_bar=False).tolist()
|
||||||
@@ -37,31 +38,32 @@ class LocalRAGSystem:
|
|||||||
results = resp.json().get("result", [])
|
results = resp.json().get("result", [])
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def generate_answer(self, query: str, context_docs: list):
|
def generate_answer(self, prompt: str):
|
||||||
query = query.strip()
|
url = f"{self.ollama_url}/api/generate"
|
||||||
context = f""
|
body = {
|
||||||
sources = f"\nИсточники:\n"
|
"model": self.chat_model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
resp = requests.post(url, json=body, timeout=600)
|
||||||
|
if resp.status_code != 200:
|
||||||
|
return f"Ошибка генерации ответа: {resp.status_code} {resp.text}"
|
||||||
|
return resp.json().get("response", "").strip()
|
||||||
|
|
||||||
|
def prepare_sources(self, context_docs: list):
|
||||||
|
sources = ""
|
||||||
for idx, doc in enumerate(context_docs, start=1):
|
for idx, doc in enumerate(context_docs, start=1):
|
||||||
text = doc['payload'].get("text", "").strip()
|
text = doc['payload'].get("text", "").strip()
|
||||||
filename = doc['payload'].get("filename", None)
|
sources = f"{sources}\n--- Source [{idx}] ---\n{text}\n"
|
||||||
url = doc['payload'].get("url", None)
|
return sources.strip()
|
||||||
if filename:
|
|
||||||
title = filename
|
|
||||||
else:
|
|
||||||
snippet = text[:40].replace("\n", " ").strip()
|
|
||||||
if len(text) > 40:
|
|
||||||
snippet += "..."
|
|
||||||
title = snippet if snippet else "Empty text"
|
|
||||||
if url is None:
|
|
||||||
url = ""
|
|
||||||
context = f"{context}\n--- Source [{idx}] ---\n{text}\n"
|
|
||||||
sources = f"{sources}\n{idx}. {title}\n {url}"
|
|
||||||
|
|
||||||
|
def prepare_prompt(self, query: str, context_docs: list):
|
||||||
|
sources = self.prepare_sources(context_docs)
|
||||||
if os.path.exists('sys_prompt.txt'):
|
if os.path.exists('sys_prompt.txt'):
|
||||||
with open('sys_prompt.txt', 'r') as fp:
|
with open('sys_prompt.txt', 'r') as fp:
|
||||||
prompt = fp.read().replace("{{context}}", context).replace("{{query}}", query)
|
return fp.read().replace("{{sources}}", sources).replace("{{query}}", query)
|
||||||
else:
|
else:
|
||||||
prompt = f"""
|
return f"""
|
||||||
Please provide an answer based solely on the provided sources.
|
Please provide an answer based solely on the provided sources.
|
||||||
It is prohibited to generate an answer based on your pretrained data.
|
It is prohibited to generate an answer based on your pretrained data.
|
||||||
If uncertain, ask the user for clarification.
|
If uncertain, ask the user for clarification.
|
||||||
@@ -91,17 +93,14 @@ class LocalRAGSystem:
|
|||||||
Your answer:
|
Your answer:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
url = f"{self.ollama_url}/api/generate"
|
def print_sources(context_docs: list):
|
||||||
body = {
|
for idx, doc in enumerate(context_docs, start=1):
|
||||||
"model": self.chat_model,
|
filename = doc['payload'].get("filename", None)
|
||||||
"prompt": prompt,
|
url = doc['payload'].get("url", None)
|
||||||
"stream": False
|
title = filename
|
||||||
}
|
if url is None:
|
||||||
resp = requests.post(url, json=body, timeout=600)
|
url = "(нет веб-ссылки)"
|
||||||
if resp.status_code != 200:
|
print(f"{idx}. {title}\n {url}")
|
||||||
return f"Ошибка генерации ответа: {resp.status_code} {resp.text}"
|
|
||||||
return resp.json().get("response", "").strip() + f"\n{sources}"
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
import sys
|
import sys
|
||||||
@@ -129,14 +128,7 @@ def main():
|
|||||||
print(f"Модель чата: {args.chat_model}")
|
print(f"Модель чата: {args.chat_model}")
|
||||||
print(f"Документов для поиска: {args.topk}")
|
print(f"Документов для поиска: {args.topk}")
|
||||||
if os.path.exists('sys_prompt.txt'):
|
if os.path.exists('sys_prompt.txt'):
|
||||||
print("Будет загружен системный промпт из sys_prompt.txt!")
|
print("Будет использоваться sys_prompt.txt!")
|
||||||
|
|
||||||
if args.interactive:
|
|
||||||
print("\nИНТЕРАКТИВНЫЙ РЕЖИМ")
|
|
||||||
print("Можете вводить запрос (или 'exit' для выхода)\n")
|
|
||||||
question = input(">>> ").strip()
|
|
||||||
else:
|
|
||||||
question = args.query.strip()
|
|
||||||
|
|
||||||
print("\nПервая инициализация моделей...")
|
print("\nПервая инициализация моделей...")
|
||||||
rag = LocalRAGSystem(
|
rag = LocalRAGSystem(
|
||||||
@@ -146,45 +138,57 @@ def main():
|
|||||||
embed_model=args.emb_model,
|
embed_model=args.emb_model,
|
||||||
chat_model=args.chat_model
|
chat_model=args.chat_model
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Модели загружены. Если ответ плохой, переформулируйте запрос, укажите --chat-model или улучшите исходные данные RAG")
|
print(f"Модели загружены. Если ответ плохой, переформулируйте запрос, укажите --chat-model или улучшите исходные данные RAG")
|
||||||
|
|
||||||
|
if args.interactive:
|
||||||
|
print("\nИНТЕРАКТИВНЫЙ РЕЖИМ")
|
||||||
|
print("Можете вводить запрос (или 'exit' для выхода)\n")
|
||||||
|
|
||||||
|
if args.query:
|
||||||
|
query = args.query.strip()
|
||||||
|
print(f">>> {query}")
|
||||||
|
else:
|
||||||
|
query = input(">>> ").strip()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
if not question or question == "":
|
if not query or query == "":
|
||||||
question = input(">>> ").strip()
|
query = input(">>> ").strip()
|
||||||
|
|
||||||
if not question or question == "":
|
if not query or query == "":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if question.lower() == "exit":
|
if query.lower() == "exit":
|
||||||
print("\n*** Завершение работы")
|
print("\n*** Завершение работы")
|
||||||
break
|
break
|
||||||
|
|
||||||
print("\nПоиск релевантных документов...")
|
print("\nПоиск релевантных документов...")
|
||||||
results = rag.search_qdrant(question, top_k=args.topk)
|
context_docs = rag.search_qdrant(query, top_k=args.topk)
|
||||||
if not results:
|
if not context_docs:
|
||||||
print("Релевантные документы не найдены.")
|
print("Релевантные документы не найдены.")
|
||||||
if args.interactive:
|
if args.interactive:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
print(f"Найдено {len(results)} релевантных документов")
|
print(f"Найдено {len(context_docs)} релевантных документов:")
|
||||||
|
print_sources(context_docs)
|
||||||
|
|
||||||
|
prompt = rag.prepare_prompt(query=query, context_docs=context_docs)
|
||||||
if args.show_prompt:
|
if args.show_prompt:
|
||||||
print("\nПолный системный промпт:\n")
|
print("\nПолный системный промпт: --------------------------\n")
|
||||||
print(rag.prompt)
|
print(f"{prompt}\n---------------------------------------------------\n")
|
||||||
|
|
||||||
print("Генерация ответа...")
|
print("Генерация ответа...")
|
||||||
answer = rag.generate_answer(question, results)
|
answer = rag.generate_answer(prompt)
|
||||||
print(f"\n<<< {answer}\n---------------------------------------------------\n")
|
print(f"\n<<< {answer}\n===================================================\n")
|
||||||
question = None
|
query = None
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\n*** Завершение работы")
|
print("\n*** Завершение работы")
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Ошибка: {e}")
|
print(f"Ошибка: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user