ollama/rag/vectorize.py
AnthonyAxenov f3672e6ffd Lots of small tweaks
- renamed input_md => data
- info about the date, version, and author of conf page changes is now added to the index
- this info is now shown in the sources
- statistics for the last answer are now printed
- the qdrant collection name can now be specified
- minor wording tweaks
2025-08-29 08:54:43 +08:00

139 lines
6.3 KiB
Python


import os
import argparse
import hashlib

from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.http import models
from langchain.text_splitter import RecursiveCharacterTextSplitter

def load_markdown_files(input_dir):
    documents = []
    for filename in os.listdir(input_dir):
        if filename.endswith(".md"):
            path = os.path.join(input_dir, filename)
            with open(path, "r", encoding="utf-8") as f:
                content = f.read()
            lines = content.splitlines()
            url = None
            version = None
            author = None
            date = None
            if lines:
                # Check whether the first line is a URL marker
                if lines[0].strip().startswith("@@") and lines[0].strip().endswith("@@") and len(lines[0].strip()) > 4:
                    url = lines[0].strip()[2:-2].strip()
                    lines = lines[1:]
                # Check the remaining lines for metadata markers
                i = 0
                while i < len(lines):
                    line = lines[i].strip()
                    if line.startswith("^^") and line.endswith("^^") and len(line) > 4:
                        version = line[2:-2].strip()
                        lines.pop(i)
                    elif line.startswith("%%") and line.endswith("%%") and len(line) > 4:
                        author = line[2:-2].strip()
                        lines.pop(i)
                    elif line.startswith("==") and line.endswith("==") and len(line) > 4:
                        date = line[2:-2].strip()
                        lines.pop(i)
                    else:
                        i += 1
            doc_metadata = {"id": filename, "text": "\n".join(lines)}
            if url: doc_metadata["url"] = url
            if version: doc_metadata["version"] = version
            if author: doc_metadata["author"] = author
            if date: doc_metadata["date"] = date
            documents.append(doc_metadata)
    return documents
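
# Example input file (illustrative values, not from the repo), showing the marker
# lines that load_markdown_files() strips into metadata: a @@...@@ URL on the first
# line, then ^^...^^ (version), %%...%% (author) and ==...== (date) anywhere below.
#
#   @@https://confluence.example.com/pages/12345@@
#   ^^42^^
#   %%A. Author%%
#   ==2025-08-29==
#   # Page title
#   Body text that ends up in the index...
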
def chunk_text(texts, chunk_size, chunk_overlap):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=["\n\n", "\n", " ", ""]
    )
    chunks = []
    for doc in texts:
        doc_chunks = splitter.split_text(doc["text"])
        for i, chunk in enumerate(doc_chunks):
            chunk_id = f"{doc['id']}_chunk{i}"
            chunk_dict = {"id": chunk_id, "text": chunk}
            # Carry over all available metadata to each chunk
            for key in ["url", "version", "author", "date"]:
                if key in doc and doc[key] is not None:
                    chunk_dict[key] = doc[key]
            chunks.append(chunk_dict)
    return chunks
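
# Each chunk produced above is a flat dict (illustrative values):
#   {"id": "page.md_chunk0", "text": "...", "url": "https://...",
#    "version": "42", "author": "A. Author", "date": "2025-08-29"}
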
def embed_and_upload(chunks, embedding_model_name, qdrant_host="localhost", qdrant_port=6333, qdrant_collection="rag"):
    print(f"Initializing model {embedding_model_name}")
    embedder = SentenceTransformer(embedding_model_name)
    print(f"Connecting to qdrant ({qdrant_host}:{qdrant_port})")
    client = QdrantClient(host=qdrant_host, port=qdrant_port)
    # Recreate the collection from scratch on every run
    if client.collection_exists(qdrant_collection):
        client.delete_collection(qdrant_collection)
    client.create_collection(
        collection_name=qdrant_collection,
        vectors_config=models.VectorParams(size=embedder.get_sentence_embedding_dimension(), distance=models.Distance.COSINE),
    )
    points = []
    total_chunks = len(chunks)
    for idx, chunk in enumerate(chunks, start=1):
        # Qdrant point IDs must be positive integers, so derive one from the chunk id
        id_hash = int(hashlib.sha256(chunk["id"].encode("utf-8")).hexdigest(), 16) % (10**16)
        vector = embedder.encode(chunk["text"]).tolist()
        points.append(models.PointStruct(
            id=id_hash,
            vector=vector,
            payload={
                "text": chunk["text"],
                "filename": chunk["id"].rsplit(".md_chunk", 1)[0],
                "url": chunk.get("url", None),
                "version": chunk.get("version", None),
                "author": chunk.get("author", None),
                "date": chunk.get("date", None)
            }
        ))
        print(f"[{idx}/{total_chunks}] Prepared chunk: {chunk['id']} -> ID: {id_hash}")
    batch_size = 100
    for i in range(0, total_chunks, batch_size):
        batch = points[i : i + batch_size]
        client.upsert(collection_name=qdrant_collection, points=batch)
        print(f"Wrote batch {(i // batch_size) + 1} with {len(batch)} points, total written: {min(i + batch_size, total_chunks)}/{total_chunks}")
    print(f"Finished writing all {total_chunks} chunks to collection '{qdrant_collection}'.")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Data vectorization script for Qdrant")
    parser.add_argument("--input-dir", type=str, default="data", help="Directory with Markdown files to read")
    parser.add_argument("--chunk-size", type=int, default=500, help="Chunk size")
    parser.add_argument("--chunk-overlap", type=int, default=100, help="Chunk overlap size")
    parser.add_argument("--embedding-model", type=str, default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", help="Embedding model")
    parser.add_argument("--qdrant-host", type=str, default="localhost", help="Qdrant host address")
    parser.add_argument("--qdrant-port", type=int, default=6333, help="Qdrant port number")
    parser.add_argument("--qdrant-collection", type=str, default="rag", help="Name of the collection to store documents in")
    args = parser.parse_args()
    documents = load_markdown_files(args.input_dir)
    print(f"Documents found: {len(documents)}")
    print("Preparing chunks...")
    chunks = chunk_text(documents, args.chunk_size, args.chunk_overlap)
    print(f"Chunks created: {len(chunks)} ({args.chunk_size}/{args.chunk_overlap})")
    embed_and_upload(chunks, args.embedding_model, args.qdrant_host, args.qdrant_port, args.qdrant_collection)
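
For completeness, here is a minimal retrieval sketch against the collection this script fills. It is not part of the repo: the model name, host, and collection are just the script's defaults, the query is a hypothetical question, and it reads back the payload fields written above (text, filename, url, version, author, date) via qdrant_client's search API.

from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient

embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
client = QdrantClient(host="localhost", port=6333)

query = "How do I configure the service?"  # hypothetical question
hits = client.search(
    collection_name="rag",
    query_vector=embedder.encode(query).tolist(),
    limit=5,
    with_payload=True,
)
for hit in hits:
    # Each payload carries the chunk text plus the page metadata indexed above
    meta = hit.payload
    print(f"{hit.score:.3f} {meta['filename']} v{meta.get('version')} "
          f"by {meta.get('author')} on {meta.get('date')}: {meta.get('url')}")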