Доработка rag, удаление скриптов моделей, актуализация README
This commit is contained in:
17
rag/CHECKLIST.md
Normal file
17
rag/CHECKLIST.md
Normal file
@@ -0,0 +1,17 @@
|
||||
# Чек-лист по построению RAG
|
||||
|
||||
* [ ] Определиться с форматом входных данных
|
||||
* [ ] Очистить входные данные, обеспечив метаданными
|
||||
* [ ] Подобрать модель эмбеддинга
|
||||
* [ ] Подобрать размер чанка и перекрытия для эмбеддинга
|
||||
* [ ] Подобрать место хранения (векторная СУБД)
|
||||
* [ ] Подобрать модель ранжирования
|
||||
* [ ] Подобрать модель генерации
|
||||
* [ ] Подобрать для неё системный промпт (для встраивания найденных чанков, грамотного их цитирования)
|
||||
* [ ] Подобрать параметры:
|
||||
* [ ] top_k (количество чанков для поиска при эмбеддинге)
|
||||
* [ ] top_n (остаток найденных чанков после ранжирования)
|
||||
* [ ] temperature (степень фантазии)
|
||||
* [ ] top_p (???)
|
||||
* [ ] другие?
|
||||
* [ ]
|
||||
174
rag/README.md
174
rag/README.md
@@ -7,10 +7,10 @@
|
||||
```bash
|
||||
cd ..; ./up; cd -
|
||||
python3 -m venv .venv
|
||||
source ./venv/bin/activate
|
||||
pip install beautifulsoup4 markdownify sentence-transformers qdrant-client langchain transformers hashlib
|
||||
source .venv/bin/activate
|
||||
pip install beautifulsoup4 markdownify sentence-transformers qdrant-client langchain transformers
|
||||
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||
./download.sh 123456789
|
||||
./download.sh 123456789 # <<== pageId страницы в Confluence
|
||||
python3 convert.py
|
||||
python3 vectorize.py
|
||||
python3 rag.py --interactive
|
||||
@@ -153,10 +153,6 @@ python3 vectorize.py
|
||||
- молниеносный поиск по индексу чанков (частям документов);
|
||||
- корректное насыщение контекста для генеративной модели.
|
||||
|
||||
Впоследствии embedding-модель будет встраивать эти данные в диалог с генеративной моделью.
|
||||
Каждый запрос сначала будет обрабатывать именно она, находя подходящие по векторам документы, и подставлять их в контекст генеративной модели.
|
||||
Последняя будет всего лишь генерировать ответ, опираясь на предоставленные из документов данные, ссылаясь на них в ответе.
|
||||
|
||||
Для получения справки по скрипту выполни команду:
|
||||
|
||||
```
|
||||
@@ -192,41 +188,167 @@ python3 rag.py --help
|
||||
|
||||
### Кастомный системный промпт
|
||||
|
||||
Если хочется уточнить роль генеративной модели, можно создать файл `sys_prompt.txt` и прописать туда всё необходимое, учитывая следующие правила:
|
||||
Если хочется уточнить роль генеративной модели, можно создать текстовый файл и прописать туда всё необходимое, учитывая следующие правила:
|
||||
|
||||
1. Шаблон `{{sources}}` будет заменён на цитаты документов, найденные в qdrant
|
||||
1. Шаблон `{{sources}}` будет заменён на цитаты документов, найденных в qdrant
|
||||
2. Шаблон `{{query}}` будет заменён на запрос пользователя
|
||||
3. Если этих двух шаблонов не будет в промпте, результаты будут непредсказуемыми
|
||||
4. Каждая цитата в списке цитат формируется в формате:
|
||||
```
|
||||
--- Source X ---
|
||||
```xml
|
||||
<source id="Z">
|
||||
Lorem ipsum dolor sit amet
|
||||
<пустая строка>
|
||||
</source>
|
||||
```
|
||||
5. Если в этой директории нет файла `sys_prompt.txt`, то будет применён промпт по умолчанию (см. функцию `generate_answer()`).
|
||||
5. При вызове `rag.py` указать путь к файлу промпта, используя аргумент `--sys-prompt $путь_к_файлу`
|
||||
6. Если указанного файла не существует, то будет применён промпт по умолчанию.
|
||||
|
||||
Посмотреть полный промпт можно указав аргумент `--show_prompt` при вызове `rag.py`.
|
||||
|
||||
### Неплохие лёгкие модели
|
||||
### Неплохие модели для экспериментов
|
||||
|
||||
Для эмбеддинга:
|
||||
Обозначения:
|
||||
* ☑️ — по умолчанию
|
||||
* 🧠 — размышляющая
|
||||
* 🏋️ — требуются ресурсы
|
||||
|
||||
- `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2` (по умолчанию, хорошо адаптирована под русский язык)
|
||||
- `nomad-embed-text` (популярная)
|
||||
#### Эмбеддинг
|
||||
|
||||
- [`sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2`](https://hf.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2) ☑️
|
||||
- [`nomad-embed-text`](https://ollama.com/library/nomad-embed-text)
|
||||
- ...
|
||||
|
||||
Для генерации ответов:
|
||||
#### Ранжирование
|
||||
|
||||
- `qwen2.5:3b` (по умолчанию)
|
||||
- `qwen3:8b`
|
||||
- `gemma3n:e2b`
|
||||
- `phi4-mini:3.8b`
|
||||
- `qwen2.5:1.5b`
|
||||
- [`cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1`](https://hf.co/cross-encoder/ms-marco-MMarco-mMiniLMv2-L12-V1) ☑️
|
||||
- `cross-encoder/ms-marco-MiniLM-L-6-v2`
|
||||
- `cross-encoder/ms-marco-TinyBERT-L-2-v2`
|
||||
- ...
|
||||
|
||||
> [!NOTE]
|
||||
> Чем меньше млрд параметров (b, billion), тем меньше вероятности получить корректный ответ на не-английском языке.
|
||||
> Такие модели работают быстро, но качество ответов низкое.
|
||||
> Чем больше параметров, тем лучше и медленее ответы.
|
||||
> Другие можно найти здесь: https://github.com/AlexeyMalafeev/ruformers
|
||||
|
||||
#### Генеративные
|
||||
|
||||
Перечислен список: по убыванию качества ответов и размера модели, по возрастанию скорости ответов на обычном домашнем ПК.
|
||||
|
||||
- [`deepseek-r1:8b`](https://ollama.com/library/deepseek-r1) 🏋️🧠
|
||||
- [`qwen3:8b`](https://ollama.com/library/qwen3) 🏋️🧠
|
||||
- [`dolphin3:8b`](https://ollama.com/library/dolphin3)🏋️
|
||||
- [`cogito:8b`](https://ollama.com/library/cogito)🏋️
|
||||
- [`openchat:7b`](https://ollama.com/library/openchat) 🏋️☑️
|
||||
- [`phi4-mini:3.8b`](https://ollama.com/library/phi4-mini)
|
||||
- [`gemma3:4b`](https://ollama.com/library/gemma3)
|
||||
- [`gemma3n:e4b`](https://ollama.com/library/gemma3n)
|
||||
- [`gemma3n:e2b`](https://ollama.com/library/gemma3n)
|
||||
|
||||
<details>
|
||||
<summary>Полный список лёгких и средних моделей, которые можно использовать не только в RAG</summary>
|
||||
|
||||
```
|
||||
codegemma:2b
|
||||
codegemma:7b
|
||||
codellama:7b
|
||||
codellama:13b
|
||||
codellama:34b
|
||||
codeqwen:1.5b
|
||||
codeqwen:7b
|
||||
codestral:22b
|
||||
deepcoder:1.5b
|
||||
deepcoder:14b
|
||||
deepseek-coder:1.3b
|
||||
deepseek-coder:6.7b
|
||||
deepseek-coder:33b
|
||||
deepseek-coder-v2:16b
|
||||
deepseek-r1:1.5b
|
||||
deepseek-r1:7b
|
||||
deepseek-r1:8b
|
||||
deepseek-r1:14b
|
||||
deepseek-r1:32b
|
||||
devstral:24b
|
||||
dolphin3:8b
|
||||
gemma:2b
|
||||
gemma:7b
|
||||
gemma3:1b
|
||||
gemma3:4b
|
||||
gemma3:12b
|
||||
gemma3:27b
|
||||
gemma3:270m
|
||||
gemma3n:e2b
|
||||
gemma3n:e4b
|
||||
gpt-oss:20b
|
||||
granite-code:3b
|
||||
granite-code:8b
|
||||
granite-code:20b
|
||||
granite-code:34b
|
||||
llama2:7b
|
||||
llama2:13b
|
||||
llama3:8b
|
||||
llama3.1:8b
|
||||
llama3.2:1b
|
||||
llama3.2:3b
|
||||
llava-llama3:8b
|
||||
magistral:24b
|
||||
mistral:7b
|
||||
mistral-nemo:12b
|
||||
mistral-small:22b
|
||||
mistral-small:24b
|
||||
mixtral:8x7b
|
||||
mxbai-embed-large:latest
|
||||
nomic-embed-text:latest
|
||||
openthinker:7b
|
||||
openthinker:32b
|
||||
phi:2.7b
|
||||
phi3:3.8b
|
||||
phi3:14b
|
||||
phi3:instruct
|
||||
phi3:medium
|
||||
phi3:mini
|
||||
phi3.5:3.8b
|
||||
phi4:14b
|
||||
phi4-mini-reasoning:3.8b
|
||||
phi4-mini:3.8b
|
||||
phi4-reasoning:14b
|
||||
qwen:0.5b
|
||||
qwen:1.8b
|
||||
qwen:4b
|
||||
qwen:7b
|
||||
qwen:14b
|
||||
qwen:32b
|
||||
qwen2:0.5b
|
||||
qwen2:1.5b
|
||||
qwen2:7b
|
||||
qwen2.5:0.5b
|
||||
qwen2.5:1.5b
|
||||
qwen2.5:3b
|
||||
qwen2.5:7b
|
||||
qwen2.5:14b
|
||||
qwen2.5:32b
|
||||
qwen2.5-coder:0.5b
|
||||
qwen2.5-coder:1.5b
|
||||
qwen2.5-coder:3b
|
||||
qwen2.5-coder:7b
|
||||
qwen2.5-coder:14b
|
||||
qwen2.5-coder:32b
|
||||
qwen3:0.6b
|
||||
qwen3:1.7b
|
||||
qwen3:4b
|
||||
qwen3:8b
|
||||
qwen3:14b
|
||||
qwen3:30b
|
||||
qwen3:32b
|
||||
qwen3-coder:30b
|
||||
qwq:32b
|
||||
smollm2:1.7m
|
||||
smollm2:135m
|
||||
smollm2:360m
|
||||
stable-code:3b
|
||||
stable-code:instruct
|
||||
starcoder2:3b
|
||||
starcoder2:7b
|
||||
starcoder2:15b
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Дисклеймер
|
||||
|
||||
|
||||
704
rag/rag.py
704
rag/rag.py
@@ -2,24 +2,49 @@ import os
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import sys
|
||||
from qdrant_client import QdrantClient
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||
|
||||
DEFAULT_CHAT_MODEL = "phi4-mini:3.8b"
|
||||
DEFAULT_EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
||||
DEFAULT_RANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
|
||||
# DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
# DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
|
||||
DEFAULT_MD_FOLDER = "data"
|
||||
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
DEFAULT_QDRANT_HOST = "localhost"
|
||||
DEFAULT_QDRANT_PORT = 6333
|
||||
DEFAULT_QDRANT_COLLECTION = "rag"
|
||||
DEFAULT_TOP_K = 30
|
||||
DEFAULT_USE_RANK = False
|
||||
DEFAULT_TOP_N = 8
|
||||
DEFAULT_VERBOSE = False
|
||||
DEFAULT_SHOW_STATS = False
|
||||
DEFAULT_STREAM = False
|
||||
DEFAULT_INTERACTIVE = False
|
||||
DEFAULT_SHOW_PROMPT = False
|
||||
|
||||
class RagSystem:
|
||||
def __init__(self,
|
||||
md_folder: str = "data",
|
||||
ollama_url: str = "http://localhost:11434",
|
||||
qdrant_host: str = "localhost",
|
||||
qdrant_port: int = 6333,
|
||||
embed_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
|
||||
chat_model: str = "phi4-mini:3.8b"):
|
||||
self.md_folder = md_folder
|
||||
ollama_url: str = DEFAULT_OLLAMA_URL,
|
||||
qdrant_host: str = DEFAULT_QDRANT_HOST,
|
||||
qdrant_port: int = DEFAULT_QDRANT_PORT,
|
||||
embed_model: str = DEFAULT_EMBED_MODEL,
|
||||
rank_model: str = DEFAULT_RANK_MODEL,
|
||||
use_rank: bool = DEFAULT_USE_RANK,
|
||||
chat_model: str = DEFAULT_CHAT_MODEL):
|
||||
self.ollama_url = ollama_url
|
||||
self.qdrant_host = qdrant_host
|
||||
self.qdrant_port = qdrant_port
|
||||
self.chat_model = chat_model
|
||||
self.emb_model = SentenceTransformer(embed_model)
|
||||
self.prompt = ""
|
||||
self.qdrant = QdrantClient(host=args.qdrant_host, port=args.qdrant_port)
|
||||
self.use_rank = use_rank
|
||||
if self.use_rank:
|
||||
self.rank_model = CrossEncoder(rank_model)
|
||||
self.conversation_history = []
|
||||
|
||||
self.load_chat_model()
|
||||
|
||||
def load_chat_model(self):
|
||||
@@ -27,111 +52,73 @@ class RagSystem:
|
||||
body = {"model": self.chat_model}
|
||||
requests.post(url, json=body, timeout=600)
|
||||
|
||||
def search_qdrant(self, query: str, top_k: int = 6, qdrant_collection="rag"):
|
||||
def search_qdrant(self, query: str, doc_count: int = DEFAULT_TOP_K, collection_name = DEFAULT_QDRANT_COLLECTION):
|
||||
query_vec = self.emb_model.encode(query, show_progress_bar=False).tolist()
|
||||
url = f"http://{self.qdrant_host}:{self.qdrant_port}/collections/{qdrant_collection}/points/search"
|
||||
payload = {
|
||||
"vector": query_vec,
|
||||
"top": top_k,
|
||||
"with_payload": True,
|
||||
# "score_threshold": 0.6
|
||||
}
|
||||
resp = requests.post(url, json=payload)
|
||||
if resp.status_code != 200:
|
||||
raise RuntimeError(f"> Ошибка qdrant: {resp.status_code} {resp.text}")
|
||||
results = resp.json().get("result", [])
|
||||
return results
|
||||
results = self.qdrant.query_points(
|
||||
collection_name=collection_name,
|
||||
query=query_vec,
|
||||
limit=doc_count,
|
||||
# score_threshold=0.5,
|
||||
)
|
||||
docs = []
|
||||
for point in results.points:
|
||||
docs.append({
|
||||
"payload": point.payload,
|
||||
"score": point.score,
|
||||
})
|
||||
return docs
|
||||
|
||||
def prepare_sources(self, context_docs: list):
|
||||
sources = ""
|
||||
for idx, doc in enumerate(context_docs, start=1):
|
||||
text = doc['payload'].get("text", "").strip()
|
||||
sources = f"{sources}\n<source id=\"{idx}\">\n{text}\n</source>\n"
|
||||
return sources
|
||||
def rank_documents(self, query: str, documents: list, top_n: int = DEFAULT_TOP_N):
|
||||
if not self.use_rank:
|
||||
return documents
|
||||
|
||||
def prepare_prompt(self, query: str, context_docs: list):
|
||||
sources = self.prepare_sources(context_docs)
|
||||
if os.path.exists('sys_prompt.txt'):
|
||||
with open('sys_prompt.txt', 'r') as fp:
|
||||
prompt_template = fp.read()
|
||||
return prompt_template.replace("{{sources}}", sources).replace("{{query}}", query)
|
||||
else:
|
||||
return f"""### Your role
|
||||
You are a helpful assistant that can answer questions based on the provided sources.
|
||||
pairs = [[query, doc["payload"]["text"]] for doc in documents]
|
||||
scores = self.rank_model.predict(pairs)
|
||||
|
||||
### Your user
|
||||
User is a human who is asking a question related to the provided sources.
|
||||
for i, doc in enumerate(documents):
|
||||
doc["rank_score"] = float(scores[i])
|
||||
|
||||
### Your task
|
||||
Please provide an answer based solely on the provided sources and the conversation history.
|
||||
documents.sort(key=lambda x: x['rank_score'], reverse=True)
|
||||
return documents[:top_n]
|
||||
|
||||
### Rules
|
||||
- You **MUST** respond in the SAME language as the user's query.
|
||||
- If uncertain, you **MUST** the user for clarification.
|
||||
- If there are no sources in context, you **MUST** clearly state that.
|
||||
- If none of the sources are helpful, you **MUST** clearly state that.
|
||||
- If you are unsure about the answer, you **MUST** clearly state that.
|
||||
- If the context is unreadable or of poor quality, you **MUST** inform the user and provide the best possible answer.
|
||||
- When referencing information from a source, you **MUST** cite the appropriate source(s) using their corresponding numbers.
|
||||
- **Only include inline citations using [id] (e.g., [1], [2]) when the <source> tag includes an id attribute.**
|
||||
- You NEVER MUST NOT add <source> or any XML/HTML tags in your response.
|
||||
- You NEVER MUST NOT cite if the <source> tag does not contain an id attribute.
|
||||
- Every answer MAY include at least one source citation.
|
||||
- Only cite a source when you are explicitly referencing it.
|
||||
- You may also cite multiple sources if they are all relevant to the question.
|
||||
- Ensure citations are concise and directly related to the information provided.
|
||||
- You CAN format your responses using Markdown.
|
||||
|
||||
### Example of sources list:
|
||||
|
||||
```
|
||||
<source id="1">The sky is red in the evening and blue in the morning.</source>
|
||||
<source id="2">Water is wet when the sky is red.</source>
|
||||
<query>When is water wet?</query>
|
||||
```
|
||||
Response:
|
||||
```
|
||||
Water will be wet when the sky is red [2], which occurs in the evening [1].
|
||||
```
|
||||
|
||||
### Now let's start!
|
||||
|
||||
```
|
||||
{sources}
|
||||
<query>{query}</query>
|
||||
```
|
||||
|
||||
Respond."""
|
||||
|
||||
def generate_answer(self, prompt: str):
|
||||
def generate_answer(self, sys_prompt: str, user_prompt: str):
|
||||
url = f"{self.ollama_url}/api/generate"
|
||||
body = {
|
||||
"model": self.chat_model,
|
||||
"prompt": prompt,
|
||||
"messages": self.conversation_history,
|
||||
"system": sys_prompt,
|
||||
"prompt": user_prompt,
|
||||
#"context": self.conversation_history,
|
||||
"stream": False,
|
||||
# "options": {
|
||||
# "temperature": 0.4,
|
||||
# "top_p": 0.1,
|
||||
# },
|
||||
"options": {
|
||||
"temperature": 0.5,
|
||||
# "top_p": 0.2,
|
||||
},
|
||||
}
|
||||
self.response = requests.post(url, json=body, timeout=900)
|
||||
if self.response.status_code != 200:
|
||||
return f"Ошибка генерации ответа: {self.response.status_code} {self.response.text}"
|
||||
return self.response.json().get("response", "").strip()
|
||||
|
||||
def generate_answer_stream(self, prompt: str):
|
||||
response = requests.post(url, json=body, timeout=900)
|
||||
if response.status_code != 200:
|
||||
return f"Ошибка генерации ответа: {response.status_code} {response.text}"
|
||||
self.response = response.json()
|
||||
return self.response["response"]
|
||||
|
||||
def generate_answer_stream(self, sys_prompt: str, user_prompt: str):
|
||||
url = f"{self.ollama_url}/api/generate"
|
||||
body = {
|
||||
"model": self.chat_model,
|
||||
"prompt": prompt,
|
||||
"messages": self.conversation_history,
|
||||
"stream": True
|
||||
"system": sys_prompt,
|
||||
"prompt": user_prompt,
|
||||
#"context": self.conversation_history,
|
||||
"stream": True,
|
||||
"options": {
|
||||
"temperature": 0.1,
|
||||
"top_p": 0.2,
|
||||
},
|
||||
}
|
||||
resp = requests.post(url, json=body, stream=True, timeout=900)
|
||||
if resp.status_code != 200:
|
||||
raise RuntimeError(f"Ошибка генерации ответа: {resp.status_code} {resp.text}")
|
||||
full_answer = ""
|
||||
|
||||
answer = ""
|
||||
for chunk in resp.iter_lines():
|
||||
if chunk:
|
||||
try:
|
||||
@@ -139,39 +126,42 @@ Respond."""
|
||||
data = json.loads(decoded_chunk)
|
||||
if "response" in data:
|
||||
yield data["response"]
|
||||
full_answer += data["response"]
|
||||
elif "error" in data:
|
||||
print(f"Stream error: {data['error']}")
|
||||
answer += data["response"]
|
||||
if "done" in data and data["done"] is True:
|
||||
self.response = data
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
print(f"Could not decode JSON from chunk: {chunk.decode('utf-8')}")
|
||||
elif "error" in data:
|
||||
answer += f" | Ошибка стриминга ответа: {data['error']}"
|
||||
break
|
||||
except json.JSONDecodeError as e:
|
||||
answer += f" | Ошибка конвертации чанка: {chunk.decode('utf-8')} - {e}"
|
||||
except Exception as e:
|
||||
print(f"Error processing chunk: {e}")
|
||||
answer += f" | Ошибка обработки чанка: {e}"
|
||||
|
||||
def get_prompt_eval_count(self):
|
||||
if not self.response:
|
||||
if not self.response["prompt_eval_count"]:
|
||||
return 0
|
||||
return self.response.json().get("prompt_eval_count", 0)
|
||||
return self.response["prompt_eval_count"]
|
||||
|
||||
def get_prompt_eval_duration(self):
|
||||
if not self.response:
|
||||
if not self.response["prompt_eval_duration"]:
|
||||
return 0
|
||||
return self.response.json().get("prompt_eval_duration", 0) / (10 ** 9)
|
||||
return self.response["prompt_eval_duration"] / (10 ** 9)
|
||||
|
||||
def get_eval_count(self):
|
||||
if not self.response:
|
||||
if not self.response["eval_count"]:
|
||||
return 0
|
||||
return self.response.json().get("eval_count", 0)
|
||||
return self.response["eval_count"]
|
||||
|
||||
def get_eval_duration(self):
|
||||
if not self.response:
|
||||
if not self.response["eval_duration"]:
|
||||
return 0
|
||||
return self.response.json().get("eval_duration", 0) / (10 ** 9)
|
||||
return self.response["eval_duration"] / (10 ** 9)
|
||||
|
||||
def get_total_duration(self):
|
||||
if not self.response:
|
||||
if not self.response["total_duration"]:
|
||||
return 0
|
||||
return self.response.json().get("total_duration", 0) / (10 ** 9)
|
||||
return self.response["total_duration"] / (10 ** 9)
|
||||
|
||||
def get_tps(self):
|
||||
eval_count = self.get_eval_count()
|
||||
@@ -180,202 +170,318 @@ Respond."""
|
||||
return 0
|
||||
return eval_count / eval_duration
|
||||
|
||||
def print_sources(context_docs: list):
|
||||
print("\n\nИсточники:")
|
||||
for idx, doc in enumerate(context_docs, start=1):
|
||||
title = doc['payload'].get("filename", None)
|
||||
url = doc['payload'].get("url", None)
|
||||
date = doc['payload'].get("date", None)
|
||||
version = doc['payload'].get("version", None)
|
||||
author = doc['payload'].get("author", None)
|
||||
class App:
|
||||
def __init__(
|
||||
self,
|
||||
args: list = []
|
||||
):
|
||||
if not args.query and not args.interactive:
|
||||
print("Ошибка: укажите запрос (--query) и/или используйте интерактивный режим (--interactive)")
|
||||
sys.exit(1)
|
||||
|
||||
if url is None:
|
||||
url = "(нет веб-ссылки)"
|
||||
if date is None:
|
||||
date = "(неизвестно)"
|
||||
if version is None:
|
||||
version = "0"
|
||||
if author is None:
|
||||
author = "(неизвестен)"
|
||||
self.args = args
|
||||
self.print_v(text=f"Включить интерактивный режим диалога: {args.interactive}")
|
||||
self.print_v(text=f"Включить потоковый вывод: {args.stream}")
|
||||
if self.is_custom_sys_prompt():
|
||||
self.print_v(text=f"Системный промпт: {args.sys_prompt}")
|
||||
else:
|
||||
self.print_v(text=f"Системный промпт: по умолчанию")
|
||||
self.print_v(text=f"Показать сист. промпт перед запросом: {args.show_prompt}")
|
||||
self.print_v(text=f"Выводить служебные сообщения: {args.verbose}")
|
||||
self.print_v(text=f"Выводить статистику об ответе: {args.show_stats}")
|
||||
self.print_v(text=f"Адрес хоста Qdrant: {args.qdrant_host}")
|
||||
self.print_v(text=f"Номер порта Qdrant: {args.qdrant_port}")
|
||||
self.print_v(text=f"Название коллекции для поиска документов: {args.qdrant_collection}")
|
||||
self.print_v(text=f"Ollama API URL: {args.ollama_url}")
|
||||
self.print_v(text=f"Модель генерации Ollama: {args.chat_model}")
|
||||
self.print_v(text=f"Модель эмбеддинга: {args.emb_model}")
|
||||
self.print_v(text=f"Количество документов для поиска: {args.topk}")
|
||||
self.print_v(text=f"Включить ранжирование: {args.use_rank}")
|
||||
self.print_v(text=f"Модель ранжирования: {args.rank_model}")
|
||||
self.print_v(text=f"Количество документов после ранжирования: {args.topn}")
|
||||
self.init_rag()
|
||||
|
||||
print(f"{idx}. {title}")
|
||||
print(f" {url} (v{version} {author})")
|
||||
print(f" актуальность на {date}")
|
||||
def print_v(self, text: str = "\n"):
|
||||
if self.args.verbose:
|
||||
print(f"{text}")
|
||||
|
||||
def print_v(text: str, is_verbose: bool):
|
||||
if is_verbose:
|
||||
print(text)
|
||||
def init_rag(self):
|
||||
self.print_v(text="\nИнициализация моделей...")
|
||||
self.rag = RagSystem(
|
||||
ollama_url = self.args.ollama_url,
|
||||
qdrant_host = self.args.qdrant_host,
|
||||
qdrant_port = self.args.qdrant_port,
|
||||
embed_model = self.args.emb_model,
|
||||
rank_model = self.args.rank_model,
|
||||
use_rank = self.args.use_rank,
|
||||
chat_model = self.args.chat_model
|
||||
)
|
||||
self.print_v(text=f"Модели загружены. Если ответ плохой, переформулируйте запрос, укажите --chat-model или улучшите исходные данные RAG")
|
||||
|
||||
def print_stats(rag: RagSystem):
|
||||
print("\n\nСтатистика:")
|
||||
print(f"* Time: {rag.get_total_duration()}s")
|
||||
print(f"* TPS: {rag.get_tps()}")
|
||||
print(f"* PEC: {rag.get_prompt_eval_count()}")
|
||||
print(f"* PED: {rag.get_prompt_eval_duration()}s")
|
||||
print(f"* EC: {rag.get_eval_count()}")
|
||||
print(f"* ED: {rag.get_eval_duration()}s\n")
|
||||
def init_query(self):
|
||||
self.query = None
|
||||
if args.interactive:
|
||||
self.print_v(text="\nИНТЕРАКТИВНЫЙ РЕЖИМ")
|
||||
self.print_v(text="Можете вводить запрос (или 'exit' для выхода)\n")
|
||||
|
||||
def main():
|
||||
import sys
|
||||
if self.args.query:
|
||||
self.query = self.args.query.strip()
|
||||
print(f">>> {self.query}")
|
||||
elif args.interactive:
|
||||
self.query = input(">>> ").strip()
|
||||
|
||||
def process_help(self):
|
||||
print("<<< Команды итерактивного режима:")
|
||||
print("save -- сохранить диалог в файл")
|
||||
print("exit -- выход\n")
|
||||
self.query = None
|
||||
self.args.query = None
|
||||
|
||||
def process_save(self):
|
||||
import datetime
|
||||
timestamp = int(time.time())
|
||||
dt = datetime.datetime.fromtimestamp(timestamp).strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
filename = f"chats/chat-{timestamp}-{self.args.chat_model}.md"
|
||||
|
||||
markdown_content = f"# История диалога от {dt}\n\n"
|
||||
markdown_content += f"## Параметры диалога\n"
|
||||
markdown_content += f"```\nargs = {self.args}\n```\n"
|
||||
markdown_content += f"```\nemb_model = {self.rag.emb_model}\n```\n"
|
||||
markdown_content += f"```\nrank_model = {self.rag.rank_model}\n```\n"
|
||||
|
||||
for entry in self.rag.conversation_history:
|
||||
if entry['role'] == 'user':
|
||||
markdown_content += f"## Пользователь\n\n"
|
||||
elif entry['role'] == 'assistant':
|
||||
markdown_content += f"## Модель\n\n"
|
||||
docs = self.rag.prepare_ctx_sources(entry['docs']).replace("```", "")
|
||||
markdown_content += f"```\n{docs}\n```\n\n"
|
||||
markdown_content += f"{entry['content']}\n\n"
|
||||
|
||||
os.makedirs('chats', exist_ok=True)
|
||||
with open(filename, 'w') as fp:
|
||||
fp.write(markdown_content)
|
||||
|
||||
print(f"<<< Диалог сохранён в файл: {filename}\n")
|
||||
self.query = None
|
||||
|
||||
def find_docs(self, query: str, top_k: int, collection_name: str):
|
||||
self.print_v(text="\nПоиск документов...")
|
||||
context_docs = self.rag.search_qdrant(query, top_k, collection_name)
|
||||
self.print_v(text=f"Найдено {len(context_docs)} документов")
|
||||
return context_docs
|
||||
|
||||
def rank_docs(self, docs: list = [], top_n = DEFAULT_TOP_N):
|
||||
self.print_v(text="\nРанжирование документов...")
|
||||
ranked_docs = self.rag.rank_documents(self.query, docs, top_n)
|
||||
self.print_v(text=f"После ранжирования осталось {len(ranked_docs)} документов")
|
||||
return ranked_docs
|
||||
|
||||
def prepare_ctx_sources(self, docs: list):
|
||||
sources = ""
|
||||
for idx, doc in enumerate(docs, start=1):
|
||||
text = doc['payload'].get("text", "").strip()
|
||||
sources = f"{sources}\n<source id=\"{idx}\">\n{text}\n</source>\n"
|
||||
return sources
|
||||
|
||||
def prepare_cli_sources(self, docs: list):
|
||||
sources = "\nИсточники:\n"
|
||||
for idx, doc in enumerate(docs, start=1):
|
||||
title = doc['payload'].get("filename", None)
|
||||
url = doc['payload'].get("url", None)
|
||||
date = doc['payload'].get("date", None)
|
||||
version = doc['payload'].get("version", None)
|
||||
author = doc['payload'].get("author", None)
|
||||
|
||||
if url is None:
|
||||
url = "(нет веб-ссылки)"
|
||||
if date is None:
|
||||
date = "(неизвестно)"
|
||||
if version is None:
|
||||
version = "0"
|
||||
if author is None:
|
||||
author = "(неизвестен)"
|
||||
|
||||
sources += f"{idx}. {title}\n"
|
||||
sources += f" {url}\n"
|
||||
sources += f" Версия {version} от {author}, актуальная на {date}\n"
|
||||
if doc['rank_score']:
|
||||
sources += f" score = {doc['score']} | rank_score = {doc['rank_score']}\n"
|
||||
else:
|
||||
sources += f" score = {doc['score']}\n"
|
||||
return sources
|
||||
|
||||
def prepare_sys_prompt(self, query: str, docs: list):
|
||||
if self.is_custom_sys_prompt():
|
||||
with open(self.args.sys_prompt, 'r') as fp:
|
||||
prompt_tpl = fp.read()
|
||||
else:
|
||||
prompt_tpl = """You are a helpful assistant that can answer questions based on the provided context.
|
||||
Your user is the person asking the source-related question.
|
||||
Your job is to answer the question based on the context alone.
|
||||
If the context doesn't provide much information, answer "I don't know."
|
||||
Adhere to this in all languages.
|
||||
|
||||
Context:
|
||||
|
||||
-----------------------------------------
|
||||
{{sources}}
|
||||
-----------------------------------------
|
||||
"""
|
||||
|
||||
sources = self.prepare_ctx_sources(docs)
|
||||
return prompt_tpl.replace("{{sources}}", sources).replace("{{query}}", query)
|
||||
|
||||
def show_prompt(self, sys_prompt: str):
|
||||
print("\n================ Системный промпт ==================")
|
||||
print(f"{sys_prompt}\n============ Конец системного промпта ==============\n")
|
||||
|
||||
def process_query(self, sys_prompt: str, user_prompt: str, streaming: bool = DEFAULT_STREAM):
|
||||
answer = ""
|
||||
# try:
|
||||
if streaming:
|
||||
self.print_v(text="\nГенерация потокового ответа (^C для остановки)...\n")
|
||||
print(f"<<< ", end='', flush=True)
|
||||
for token in self.rag.generate_answer_stream(sys_prompt, user_prompt):
|
||||
answer += token
|
||||
print(token, end='', flush=True)
|
||||
else:
|
||||
self.print_v(text="\nГенерация ответа (^C для остановки)...\n")
|
||||
answer = self.rag.generate_answer(sys_prompt, user_prompt)
|
||||
print(f"<<< {answer}\n")
|
||||
# except RuntimeError as e:
|
||||
# answer = str(e)
|
||||
|
||||
print(f"\n===================================================")
|
||||
return answer
|
||||
|
||||
def is_custom_sys_prompt(self):
|
||||
return self.args.sys_prompt and os.path.exists(self.args.sys_prompt)
|
||||
|
||||
def print_stats(self):
|
||||
print(f"* Time: {self.rag.get_total_duration()}s")
|
||||
print(f"* TPS: {self.rag.get_tps()}")
|
||||
print(f"* PEC: {self.rag.get_prompt_eval_count()}")
|
||||
print(f"* PED: {self.rag.get_prompt_eval_duration()}s")
|
||||
print(f"* EC: {self.rag.get_eval_count()}")
|
||||
print(f"* ED: {self.rag.get_eval_duration()}s\n")
|
||||
self.query = None
|
||||
self.args.query = None
|
||||
|
||||
def process(self):
|
||||
while True:
|
||||
try:
|
||||
self.init_query()
|
||||
|
||||
if not self.query or self.query == "":
|
||||
continue
|
||||
|
||||
if self.query.lower() == "help":
|
||||
self.process_help()
|
||||
continue
|
||||
|
||||
if self.query.strip().lower() == "save":
|
||||
self.process_save()
|
||||
continue
|
||||
|
||||
if self.query.strip().lower() == "stats":
|
||||
print("\n<<< Статистика:")
|
||||
self.print_stats()
|
||||
continue
|
||||
|
||||
if self.query.strip().lower() == "exit":
|
||||
self.print_v(text="\n*** Завершение работы")
|
||||
sys.exit(0)
|
||||
|
||||
context_docs = self.find_docs(self.query, self.args.topk, self.args.qdrant_collection)
|
||||
if not context_docs:
|
||||
if args.interactive:
|
||||
print("<<< Релевантные документы не найдены")
|
||||
self.query = None
|
||||
self.args.query = None
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
ranked_docs = self.rank_docs(context_docs, self.args.topn)
|
||||
if not ranked_docs:
|
||||
if args.interactive:
|
||||
print("<<< Релевантные документы были отсеяны полностью")
|
||||
self.query = None
|
||||
self.args.query = None
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
sys_prompt = self.prepare_sys_prompt(self.query, ranked_docs)
|
||||
if self.args.show_prompt:
|
||||
self.show_prompt(sys_prompt)
|
||||
|
||||
try:
|
||||
answer = self.process_query(sys_prompt, self.query, self.args.stream)
|
||||
except KeyboardInterrupt:
|
||||
print("\n*** Генерация ответа прервана")
|
||||
self.query = None
|
||||
self.args.query = None
|
||||
print(self.prepare_cli_sources(ranked_docs))
|
||||
if self.args.show_stats:
|
||||
print("\nСтатистика:")
|
||||
self.print_stats()
|
||||
continue
|
||||
|
||||
print(self.prepare_cli_sources(ranked_docs))
|
||||
|
||||
if self.args.show_stats:
|
||||
print("\nСтатистика:")
|
||||
self.print_stats()
|
||||
|
||||
self.rag.conversation_history.append({
|
||||
"role": "user",
|
||||
"content": self.query,
|
||||
})
|
||||
|
||||
self.rag.conversation_history.append({
|
||||
"role": "assistant",
|
||||
"docs": ranked_docs,
|
||||
"content": answer,
|
||||
})
|
||||
|
||||
if args.interactive:
|
||||
self.query = None
|
||||
self.args.query = None
|
||||
else:
|
||||
break
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n*** Завершение работы")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"Ошибка: {e}")
|
||||
break
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="RAG-система с использованием Ollama и Qdrant")
|
||||
parser.add_argument("--query", type=str, help="Запрос к RAG")
|
||||
parser.add_argument("--interactive", default=False, action=argparse.BooleanOptionalAction, help="Перейти в интерактивный режим диалога")
|
||||
parser.add_argument("--show-prompt", default=False, action=argparse.BooleanOptionalAction, help="Показать полный промпт перед обработкой запроса")
|
||||
parser.add_argument("--qdrant-host", default="localhost", help="Qdrant host")
|
||||
parser.add_argument("--qdrant-port", type=int, default=6333, help="Qdrant port")
|
||||
parser.add_argument("--qdrant-collection", type=str, default="rag", help="Название коллекции для поиска документов")
|
||||
parser.add_argument("--ollama-url", default="http://localhost:11434", help="Ollama API URL")
|
||||
parser.add_argument("--emb-model", default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", help="Модель эмбеддинга")
|
||||
parser.add_argument("--chat-model", default="phi4-mini:3.8b", help="Модель генерации Ollama")
|
||||
parser.add_argument("--topk", type=int, default=6, help="Количество документов для поиска")
|
||||
parser.add_argument("--verbose", default=False, action=argparse.BooleanOptionalAction, help="Выводить промежуточные служебные сообщения")
|
||||
parser.add_argument("--show-stats", default=False, action=argparse.BooleanOptionalAction, help="Выводить статистику об ответе (не работает с --stream)")
|
||||
parser.add_argument("--stream", default=False, action=argparse.BooleanOptionalAction, help="Выводить статистику об ответе")
|
||||
parser.add_argument("--interactive", default=DEFAULT_INTERACTIVE, action=argparse.BooleanOptionalAction, help="Включить интерактивный режим диалога")
|
||||
parser.add_argument("--stream", default=DEFAULT_STREAM, action=argparse.BooleanOptionalAction, help="Включить потоковый вывод")
|
||||
parser.add_argument("--sys-prompt", type=str, help="Путь к файлу шаблона системного промпта")
|
||||
parser.add_argument("--show-prompt", default=DEFAULT_SHOW_PROMPT, action=argparse.BooleanOptionalAction, help="Показать сист. промпт перед запросом")
|
||||
parser.add_argument("--verbose", default=DEFAULT_VERBOSE, action=argparse.BooleanOptionalAction, help="Выводить служебные сообщения")
|
||||
parser.add_argument("--show-stats", default=DEFAULT_SHOW_STATS, action=argparse.BooleanOptionalAction, help="Выводить статистику об ответе (не работает с --stream)")
|
||||
parser.add_argument("--qdrant-host", default=DEFAULT_QDRANT_HOST, help="Адрес хоста Qdrant")
|
||||
parser.add_argument("--qdrant-port", type=int, default=DEFAULT_QDRANT_PORT, help="Номер порта Qdrant")
|
||||
parser.add_argument("--qdrant-collection", type=str, default=DEFAULT_QDRANT_COLLECTION, help="Название коллекции для поиска документов")
|
||||
parser.add_argument("--ollama-url", default=DEFAULT_OLLAMA_URL, help="Ollama API URL")
|
||||
parser.add_argument("--chat-model", default=DEFAULT_CHAT_MODEL, help="Модель генерации Ollama")
|
||||
parser.add_argument("--emb-model", default=DEFAULT_EMBED_MODEL, help="Модель эмбеддинга")
|
||||
parser.add_argument("--topk", type=int, default=DEFAULT_TOP_K, help="Количество документов для поиска")
|
||||
parser.add_argument("--use-rank", default=DEFAULT_USE_RANK, action=argparse.BooleanOptionalAction, help="Включить ранжирование")
|
||||
parser.add_argument("--rank-model", type=str, default=DEFAULT_RANK_MODEL, help="Модель ранжирования")
|
||||
parser.add_argument("--topn", type=int, default=DEFAULT_TOP_N, help="Количество документов после ранжирования")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.query and not args.interactive:
|
||||
print("Ошибка: укажите запрос (--query) и/или используйте интерактивный режим (--interactive)")
|
||||
sys.exit(1)
|
||||
|
||||
print_v(f"Адрес ollama: {args.ollama_url}", args.verbose)
|
||||
print_v(f"Адрес qdrant: {args.qdrant_host}:{args.qdrant_port}", args.verbose)
|
||||
print_v(f"Модель эмбеддинга: {args.emb_model}", args.verbose)
|
||||
print_v(f"Модель чата: {args.chat_model}", args.verbose)
|
||||
print_v(f"Документов для поиска: {args.topk}", args.verbose)
|
||||
print_v(f"Коллекция для поиска: {args.qdrant_collection}", args.verbose)
|
||||
if os.path.exists('sys_prompt.txt'):
|
||||
print_v("Будет использоваться sys_prompt.txt!", args.verbose)
|
||||
|
||||
print_v("\nПервая инициализация моделей...", args.verbose)
|
||||
rag = RagSystem(
|
||||
ollama_url=args.ollama_url,
|
||||
qdrant_host=args.qdrant_host,
|
||||
qdrant_port=args.qdrant_port,
|
||||
embed_model=args.emb_model,
|
||||
chat_model=args.chat_model
|
||||
)
|
||||
print_v(f"Модели загружены. Если ответ плохой, переформулируйте запрос, укажите --chat-model или улучшите исходные данные RAG", args.verbose)
|
||||
|
||||
query = None
|
||||
if args.interactive:
|
||||
print_v("\nИНТЕРАКТИВНЫЙ РЕЖИМ", args.verbose)
|
||||
print_v("Можете вводить запрос (или 'exit' для выхода)\n", args.verbose)
|
||||
|
||||
if args.query:
|
||||
query = args.query.strip()
|
||||
print(f">>> {query}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
if not query or query == "":
|
||||
query = input(">>> ").strip()
|
||||
|
||||
if not query or query == "":
|
||||
continue
|
||||
|
||||
if query.lower() == "help":
|
||||
print("<<< Команды итерактивного режима:")
|
||||
print("save -- сохранить диалог в файл")
|
||||
print("stats -- вывести статистику последнего ответа")
|
||||
print("exit -- выход\n")
|
||||
query = None
|
||||
continue
|
||||
|
||||
if query.strip().lower() == "save":
|
||||
import datetime
|
||||
timestamp = int(time.time())
|
||||
dt = datetime.datetime.fromtimestamp(timestamp).strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
filename = f"chats/chat-{timestamp}.md"
|
||||
|
||||
markdown_content = f"# История диалога от {dt}\n\n"
|
||||
markdown_content += f"## Параметры диалога\n"
|
||||
markdown_content += f"```\nargs = {args}\n```\n"
|
||||
markdown_content += f"```\nemb_model = {rag.emb_model}\n```\n"
|
||||
|
||||
for entry in rag.conversation_history:
|
||||
if entry['role'] == 'user':
|
||||
markdown_content += f"## Пользователь\n\n"
|
||||
elif entry['role'] == 'assistant':
|
||||
markdown_content += f"## Модель\n\n"
|
||||
docs = rag.prepare_sources(entry['docs']).replace("```", "")
|
||||
markdown_content += f"```\n{docs}\n```\n\n"
|
||||
markdown_content += f"{entry['content']}\n\n"
|
||||
|
||||
os.makedirs('chats', exist_ok=True)
|
||||
with open(filename, 'w') as fp:
|
||||
fp.write(markdown_content)
|
||||
|
||||
print(f"<<< Диалог сохранён в файл: {filename}\n")
|
||||
query = None
|
||||
continue
|
||||
|
||||
if query.strip().lower() == "exit":
|
||||
print_v("\n*** Завершение работы", args.verbose)
|
||||
break
|
||||
|
||||
print_v("\nПоиск релевантных документов...", args.verbose)
|
||||
context_docs = rag.search_qdrant(query, top_k=args.topk, qdrant_collection=args.qdrant_collection)
|
||||
if not context_docs:
|
||||
print("<<< Релевантные документы не найдены")
|
||||
if args.interactive:
|
||||
query = None
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
print_v(f"Найдено {len(context_docs)} релевантных документов", args.verbose)
|
||||
# print_sources(context_docs)
|
||||
|
||||
prompt = rag.prepare_prompt(query=query, context_docs=context_docs)
|
||||
if args.show_prompt:
|
||||
print("\nПолный системный промпт: --------------------------")
|
||||
print(f"{prompt}\n---------------------------------------------------")
|
||||
|
||||
print_v("\nГенерация ответа...\n", args.verbose)
|
||||
|
||||
if args.stream:
|
||||
answer = "\n<<< "
|
||||
print(answer, end='', flush=True)
|
||||
try:
|
||||
for message_part in rag.generate_answer_stream(prompt):
|
||||
answer += message_part
|
||||
print(message_part, end='', flush=True)
|
||||
except RuntimeError as e:
|
||||
answer = str(e)
|
||||
print(f"\n{answer}\n===================================================\n")
|
||||
else:
|
||||
answer = rag.generate_answer(prompt)
|
||||
print(f"<<< {answer}\n")
|
||||
|
||||
print_sources(context_docs)
|
||||
if args.show_stats and not args.stream:
|
||||
print_stats(rag)
|
||||
|
||||
rag.conversation_history.append({
|
||||
"role": "user",
|
||||
"content": query,
|
||||
})
|
||||
|
||||
rag.conversation_history.append({
|
||||
"role": "assistant",
|
||||
"docs": context_docs,
|
||||
"content": answer,
|
||||
})
|
||||
|
||||
if args.interactive:
|
||||
query = None
|
||||
else:
|
||||
break
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n*** Завершение работы")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"Ошибка: {e}")
|
||||
break
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
app = App(args)
|
||||
app.process()
|
||||
|
||||
@@ -4,6 +4,7 @@ from sentence_transformers import SentenceTransformer
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http import models
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.text_splitter import MarkdownHeaderTextSplitter
|
||||
|
||||
DEFAULT_INPUT_DIR="data"
|
||||
DEFAULT_CHUNK_SIZE=500
|
||||
@@ -59,24 +60,45 @@ def load_markdown_files(input_dir):
|
||||
return documents
|
||||
|
||||
def chunk_text(texts, chunk_size, chunk_overlap):
|
||||
splitter = RecursiveCharacterTextSplitter(
|
||||
markdown_splitter = MarkdownHeaderTextSplitter(
|
||||
headers_to_split_on=[
|
||||
("#", "Header 1"),
|
||||
("##", "Header 2"),
|
||||
("###", "Header 3"),
|
||||
],
|
||||
strip_headers=False,
|
||||
return_each_line=False,
|
||||
)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
add_start_index=True,
|
||||
length_function=len,
|
||||
separators=["\n\n", "\n", " ", ""]
|
||||
)
|
||||
|
||||
chunks = []
|
||||
for doc in texts:
|
||||
doc_chunks = splitter.split_text(doc["text"])
|
||||
for i, chunk in enumerate(doc_chunks):
|
||||
chunk_id = f"{doc['id']}_chunk{i}"
|
||||
chunk_dict = {"id": chunk_id, "text": chunk}
|
||||
md_header_splits = markdown_splitter.split_text(doc["text"])
|
||||
|
||||
# Перенос всех доступных метаданных
|
||||
for key in ["url", "version", "author", "date"]:
|
||||
if key in doc and doc[key] is not None:
|
||||
chunk_dict[key] = doc[key]
|
||||
chunks.append(chunk_dict)
|
||||
for md_split in md_header_splits:
|
||||
# RecursiveCharacterTextSplitter for each markdown split
|
||||
split_docs = text_splitter.split_documents([md_split])
|
||||
|
||||
for i, chunk in enumerate(split_docs):
|
||||
chunk_id = f"{doc['id']}_chunk{i}"
|
||||
chunk_dict = {"id": chunk_id, "text": chunk.page_content}
|
||||
|
||||
# Перенос всех доступных метаданных, включая метаданные из MarkdownHeaderTextSplitter
|
||||
for key in ["url", "version", "author", "date"]:
|
||||
if key in doc and doc[key] is not None:
|
||||
chunk_dict[key] = doc[key]
|
||||
|
||||
# Добавление метаданных из MarkdownHeaderTextSplitter
|
||||
for key, value in chunk.metadata.items():
|
||||
chunk_dict[key] = value
|
||||
|
||||
chunks.append(chunk_dict)
|
||||
return chunks
|
||||
|
||||
def embed_and_upload(chunks, embedding_model_name, qdrant_host="localhost", qdrant_port=6333, qdrant_collection="rag"):
|
||||
@@ -149,3 +171,4 @@ if __name__ == "__main__":
|
||||
args.qdrant_port,
|
||||
args.qdrant_collection
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user