1
0
Files
ollama/rag/vectorize.py
AnthonyAxenov a9328b4681 Почти полная переработка всего rag
- включение qdrant в контур
- использование нормальной эмб-модели
- векторизация текста
- README и туча мелочей
2025-08-25 01:55:46 +08:00

107 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import argparse
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.http import models
from langchain.text_splitter import RecursiveCharacterTextSplitter
def load_markdown_files(input_dir):
    """Read every ``.md`` file in *input_dir* into a document dict.

    Each document has the shape ``{"id": filename, "text": body, "url": url}``.
    If a file's first line looks like ``@@<url>@@`` (with something between
    the markers), it is treated as a source-URL header: the URL is extracted
    and that line is removed from the document body; otherwise ``url`` is None.
    """
    docs = []
    for name in os.listdir(input_dir):
        if not name.endswith(".md"):
            continue
        with open(os.path.join(input_dir, name), "r", encoding="utf-8") as fh:
            body = fh.read()
        rows = body.splitlines()
        source_url = None
        if rows:
            header = rows[0].strip()
            # len > 4 ensures there is actual content between the @@ markers.
            if header.startswith("@@") and header.endswith("@@") and len(header) > 4:
                source_url = header[2:-2].strip()
                body = "\n".join(rows[1:])  # drop the URL marker line
        docs.append({"id": name, "text": body, "url": source_url})
    return docs
def chunk_text(texts, chunk_size, chunk_overlap):
    """Split each document's text into overlapping character chunks.

    Returns a flat list of dicts ``{"id": "<doc-id>_chunk<i>", "text": ...}``;
    a ``"url"`` key is copied over only when the source document carries a
    non-None URL.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],
    )
    result = []
    for doc in texts:
        for idx, piece in enumerate(splitter.split_text(doc["text"])):
            entry = {"id": f"{doc['id']}_chunk{idx}", "text": piece}
            if doc.get("url") is not None:
                entry["url"] = doc["url"]
            result.append(entry)
    return result
def embed_and_upload(chunks, embedding_model_name, qdrant_host="localhost", qdrant_port=6333):
    """Embed chunk texts and upsert them into the Qdrant ``rag_collection``.

    The collection is dropped and recreated on every run, so the upload fully
    replaces any previous contents.

    :param chunks: list of dicts with ``"id"`` and ``"text"`` keys and an
        optional ``"url"`` key (as produced by :func:`chunk_text`)
    :param embedding_model_name: SentenceTransformer model name/path
    :param qdrant_host: Qdrant server host
    :param qdrant_port: Qdrant server port
    """
    import hashlib

    # BUG FIX: the original printed the global ``args.embedding_model`` here,
    # which raises NameError when this function is called outside the CLI
    # entry point (and could disagree with the actual parameter value).
    print(f"Инициализация модели {embedding_model_name}")
    embedder = SentenceTransformer(embedding_model_name)
    print(f"Подключение к qdrant ({qdrant_host}:{qdrant_port})")
    client = QdrantClient(host=qdrant_host, port=qdrant_port)
    collection_name = "rag_collection"
    # Recreate the collection from scratch so stale vectors never linger.
    if client.collection_exists(collection_name):
        client.delete_collection(collection_name)
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=embedder.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )
    points = []
    total_chunks = len(chunks)
    for idx, chunk in enumerate(chunks, start=1):
        # Qdrant point IDs must be unsigned integers (or UUIDs); derive a
        # stable numeric ID from the chunk's string ID so re-runs produce
        # the same IDs for the same chunks.
        id_hash = int(hashlib.sha256(chunk["id"].encode("utf-8")).hexdigest(), 16) % (10**16)
        vector = embedder.encode(chunk["text"]).tolist()
        points.append(models.PointStruct(
            id=id_hash,
            vector=vector,
            payload={
                "text": chunk["text"],
                # Strip the "_chunk<i>" suffix back off to recover the filename.
                "filename": chunk["id"].rsplit(".md_chunk", 1)[0],
                "url": chunk.get("url"),
            },
        ))
        print(f"[{idx}/{total_chunks}] Подготовлен чанк: {chunk['id']} -> ID: {id_hash}")
    # Upload in batches to keep individual upsert requests small.
    batch_size = 100
    for i in range(0, total_chunks, batch_size):
        batch = points[i : i + batch_size]
        client.upsert(collection_name=collection_name, points=batch)
        print(f"Записан батч {(i // batch_size) + 1}, содержащий {len(batch)} точек, всего записано: {min(i + batch_size, total_chunks)}/{total_chunks}")
    print(f"Завершена запись всех {total_chunks} чанков в коллекцию '{collection_name}'.")
if __name__ == "__main__":
    # CLI entry point: read markdown files, chunk them, embed and push to Qdrant.
    print("Инициализация...")
    parser = argparse.ArgumentParser(description="Скрипт векторизаци данных для Qdrant")
    parser.add_argument("--input_dir", type=str, default="input_md", help="Директория с Markdown-файлами для чтения")
    parser.add_argument("--chunk_size", type=int, default=500, help="Размер чанка")
    parser.add_argument("--chunk_overlap", type=int, default=100, help="Размер перекрытия")
    parser.add_argument("--embedding_model", type=str, default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", help="Модель эмбеддинга")
    parser.add_argument("--qdrant_host", type=str, default="localhost", help="Адрес хоста Qdrant")
    parser.add_argument("--qdrant_port", type=int, default=6333, help="Номер порта Qdrant")
    args = parser.parse_args()

    documents = load_markdown_files(args.input_dir)
    print(f"Найдено документов: {len(documents)}")
    print("Подготовка чанков...")
    chunks = chunk_text(documents, args.chunk_size, args.chunk_overlap)
    print(f"Создано чанков: {len(chunks)} ({args.chunk_size}/{args.chunk_overlap})")
    embed_and_upload(chunks, args.embedding_model, args.qdrant_host, args.qdrant_port)