ollama/rag/vectorize.py

import os
import argparse
import hashlib

from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.http import models
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.text_splitter import MarkdownHeaderTextSplitter

DEFAULT_INPUT_DIR = "data"
DEFAULT_CHUNK_SIZE = 500
DEFAULT_CHUNK_OVERLAP = 100
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
DEFAULT_QDRANT_HOST = "localhost"
DEFAULT_QDRANT_PORT = 6333
DEFAULT_QDRANT_COLLECTION = "rag"
BATCH_SIZE = 100


def load_markdown_files(input_dir):
    """Read all .md files from input_dir, pulling metadata marker lines out of the text."""
    documents = []
    for filename in os.listdir(input_dir):
        if filename.endswith(".md"):
            path = os.path.join(input_dir, filename)
            with open(path, "r", encoding="utf-8") as f:
                content = f.read()
            lines = content.splitlines()
            url = None
            version = None
            author = None
            date = None
            # Strip marker lines (@@url@@, ^^version^^, %%author%%, ==date==)
            # from the document body and remember their values.
            i = 0
            while i < len(lines):
                line = lines[i].strip()
                if line.startswith("@@") and line.endswith("@@") and len(line) > 4:
                    url = line[2:-2].strip()
                    lines.pop(i)
                elif line.startswith("^^") and line.endswith("^^") and len(line) > 4:
                    version = line[2:-2].strip()
                    lines.pop(i)
                elif line.startswith("%%") and line.endswith("%%") and len(line) > 4:
                    author = line[2:-2].strip()
                    lines.pop(i)
                elif line.startswith("==") and line.endswith("==") and len(line) > 4:
                    date = line[2:-2].strip()
                    lines.pop(i)
                else:
                    i += 1
            doc_metadata = {"id": filename, "text": "\n".join(lines)}
            if url:
                doc_metadata["url"] = url
            if version:
                doc_metadata["version"] = version
            if author:
                doc_metadata["author"] = author
            if date:
                doc_metadata["date"] = date
            documents.append(doc_metadata)
    return documents
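
# A minimal sketch of the input format the loader above expects. The marker
# syntax comes from the parsing code; the concrete values (URL, version,
# author, date) are illustrative assumptions, not taken from the repository.
#
#   @@https://example.com/source-page@@
#   ^^1.2.0^^
#   %%Jane Doe%%
#   ==2024-05-01==
#   # Document title
#   Body text ...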

def chunk_text(texts, chunk_size, chunk_overlap):
    """Split documents by Markdown headers, then into overlapping fixed-size chunks."""
    markdown_splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=[
            ("#", "Header 1"),
            ("##", "Header 2"),
            ("###", "Header 3"),
        ],
        strip_headers=False,
        return_each_line=False,
    )
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        add_start_index=True,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],
    )
    chunks = []
    for doc in texts:
        md_header_splits = markdown_splitter.split_text(doc["text"])
        chunk_index = 0  # number chunks per document so IDs stay unique across header splits
        for md_split in md_header_splits:
            # RecursiveCharacterTextSplitter for each markdown split
            split_docs = text_splitter.split_documents([md_split])
            for chunk in split_docs:
                chunk_id = f"{doc['id']}_chunk{chunk_index}"
                chunk_index += 1
                chunk_dict = {"id": chunk_id, "text": chunk.page_content}
                # Carry over all available document metadata
                for key in ["url", "version", "author", "date"]:
                    if key in doc and doc[key] is not None:
                        chunk_dict[key] = doc[key]
                # Add metadata produced by MarkdownHeaderTextSplitter (header titles)
                for key, value in chunk.metadata.items():
                    chunk_dict[key] = value
                chunks.append(chunk_dict)
    return chunks
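
# For illustration only (hypothetical values): a resulting chunk dict looks roughly like
#   {"id": "guide.md_chunk3", "text": "...", "url": "https://example.com",
#    "version": "1.2.0", "author": "Jane Doe", "date": "2024-05-01",
#    "Header 1": "Document title", "Header 2": "Section"}
# where the "Header N" keys come from MarkdownHeaderTextSplitter.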

def embed_and_upload(chunks, embedding_model_name, qdrant_host="localhost", qdrant_port=6333, qdrant_collection="rag"):
    """Embed chunk texts and upload them to a freshly (re)created Qdrant collection."""
    print(f"Initializing model {embedding_model_name}")
    embedder = SentenceTransformer(embedding_model_name)
    print(f"Connecting to Qdrant ({qdrant_host}:{qdrant_port})")
    client = QdrantClient(host=qdrant_host, port=qdrant_port)
    # Recreate the collection from scratch on every run.
    if client.collection_exists(qdrant_collection):
        client.delete_collection(qdrant_collection)
    client.create_collection(
        collection_name=qdrant_collection,
        vectors_config=models.VectorParams(
            size=embedder.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )
    points = []
    total_chunks = len(chunks)
    for idx, chunk in enumerate(chunks, start=1):
        # Qdrant point IDs must be unsigned integers (or UUIDs), so derive one from the chunk ID
        id_hash = int(hashlib.sha256(chunk["id"].encode("utf-8")).hexdigest(), 16) % (10**16)
        vector = embedder.encode(chunk["text"]).tolist()
        points.append(models.PointStruct(
            id=id_hash,
            vector=vector,
            payload={
                "text": chunk["text"],
                "filename": chunk["id"].rsplit(".md_chunk", 1)[0],
                "url": chunk.get("url", None),
                "version": chunk.get("version", None),
                "author": chunk.get("author", None),
                "date": chunk.get("date", None),
            },
        ))
        print(f"[{idx}/{total_chunks}] Prepared chunk: {chunk['id']} -> ID: {id_hash}")
    # Upsert in batches to keep individual requests small.
    for i in range(0, total_chunks, BATCH_SIZE):
        batch = points[i : i + BATCH_SIZE]
        client.upsert(collection_name=qdrant_collection, points=batch)
        print(f"Uploaded batch {(i // BATCH_SIZE) + 1} with {len(batch)} points, total written: {min(i + BATCH_SIZE, total_chunks)}/{total_chunks}")
    print(f"Finished writing all {total_chunks} chunks to collection '{qdrant_collection}'.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Data vectorization script for Qdrant")
    parser.add_argument("--input-dir", type=str, default=DEFAULT_INPUT_DIR, help="Directory containing Markdown files to read")
    parser.add_argument("--chunk-size", type=int, default=DEFAULT_CHUNK_SIZE, help="Chunk size")
    parser.add_argument("--chunk-overlap", type=int, default=DEFAULT_CHUNK_OVERLAP, help="Chunk overlap size")
    parser.add_argument("--embedding-model", type=str, default=DEFAULT_EMBEDDING_MODEL, help="Embedding model")
    parser.add_argument("--qdrant-host", type=str, default=DEFAULT_QDRANT_HOST, help="Qdrant host address")
    parser.add_argument("--qdrant-port", type=int, default=DEFAULT_QDRANT_PORT, help="Qdrant port number")
    parser.add_argument("--qdrant-collection", type=str, default=DEFAULT_QDRANT_COLLECTION, help="Name of the collection to store documents in")
    args = parser.parse_args()

    documents = load_markdown_files(args.input_dir)
    print(f"Documents found: {len(documents)}")
    print("Preparing chunks...")
    chunks = chunk_text(documents, args.chunk_size, args.chunk_overlap)
    print(f"Chunks created: {len(chunks)} ({args.chunk_size}/{args.chunk_overlap})")
    embed_and_upload(
        chunks,
        args.embedding_model,
        args.qdrant_host,
        args.qdrant_port,
        args.qdrant_collection,
    )
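
# Example invocation. The flags match the argparse definitions above; the values
# simply repeat the defaults and the local Qdrant host is an assumption, so adjust
# them to your own setup:
#   python vectorize.py --input-dir data --chunk-size 500 --chunk-overlap 100 \
#       --qdrant-host localhost --qdrant-port 6333 --qdrant-collection rag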