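"""Vectorize Markdown documents into a Qdrant collection for RAG.

Reads .md files from an input directory, strips optional document-level
metadata markers, splits the text by Markdown headers and then by size,
embeds each chunk with a SentenceTransformer model, and upserts the
resulting points into Qdrant.
"""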
import argparse
import hashlib
import os

from langchain.text_splitter import (
    MarkdownHeaderTextSplitter,
    RecursiveCharacterTextSplitter,
)
from qdrant_client import QdrantClient
from qdrant_client.http import models
from sentence_transformers import SentenceTransformer

DEFAULT_INPUT_DIR = "data"
DEFAULT_CHUNK_SIZE = 500
DEFAULT_CHUNK_OVERLAP = 100
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
DEFAULT_QDRANT_HOST = "localhost"
DEFAULT_QDRANT_PORT = 6333
DEFAULT_QDRANT_COLLECTION = "rag"
BATCH_SIZE = 100

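
# Documents may carry optional metadata markers, each on a line of its own:
#   @@<url>@@        source URL
#   ^^<version>^^    document version
#   %%<author>%%     author
#   ==<date>==       date
# Marker lines are stripped from the text and stored as document metadata.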
def load_markdown_files(input_dir):
    documents = []
    for filename in os.listdir(input_dir):
        if filename.endswith(".md"):
            path = os.path.join(input_dir, filename)
            with open(path, "r", encoding="utf-8") as f:
                content = f.read()
            lines = content.splitlines()
            url = None
            version = None
            author = None
            date = None

            # Scan the document for metadata marker lines, strip each one
            # from the text, and remember its value.
            i = 0
            while i < len(lines):
                line = lines[i].strip()
                if line.startswith("@@") and line.endswith("@@") and len(line) > 4:
                    url = line[2:-2].strip()
                    lines.pop(i)
                elif line.startswith("^^") and line.endswith("^^") and len(line) > 4:
                    version = line[2:-2].strip()
                    lines.pop(i)
                elif line.startswith("%%") and line.endswith("%%") and len(line) > 4:
                    author = line[2:-2].strip()
                    lines.pop(i)
                elif line.startswith("==") and line.endswith("==") and len(line) > 4:
                    date = line[2:-2].strip()
                    lines.pop(i)
                else:
                    i += 1

            doc_metadata = {"id": filename, "text": "\n".join(lines)}
            if url:
                doc_metadata["url"] = url
            if version:
                doc_metadata["version"] = version
            if author:
                doc_metadata["author"] = author
            if date:
                doc_metadata["date"] = date
            documents.append(doc_metadata)
    return documents

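
# Chunking is two-stage: MarkdownHeaderTextSplitter first splits each document
# on #/##/### headings (keeping the heading text as metadata), then
# RecursiveCharacterTextSplitter cuts each section into overlapping,
# size-bounded chunks.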
def chunk_text(texts, chunk_size, chunk_overlap):
    markdown_splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=[
            ("#", "Header 1"),
            ("##", "Header 2"),
            ("###", "Header 3"),
        ],
        strip_headers=False,
        return_each_line=False,
    )
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        add_start_index=True,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],
    )

    chunks = []
    for doc in texts:
        md_header_splits = markdown_splitter.split_text(doc["text"])

        # Number chunks per document, not per header section, so that
        # chunk IDs stay unique across the whole document.
        chunk_index = 0
        for md_split in md_header_splits:
            # Run RecursiveCharacterTextSplitter on each markdown section
            split_docs = text_splitter.split_documents([md_split])

            for chunk in split_docs:
                chunk_id = f"{doc['id']}_chunk{chunk_index}"
                chunk_index += 1
                chunk_dict = {"id": chunk_id, "text": chunk.page_content}

                # Carry over document-level metadata
                for key in ["url", "version", "author", "date"]:
                    if key in doc and doc[key] is not None:
                        chunk_dict[key] = doc[key]

                # Add header metadata produced by MarkdownHeaderTextSplitter
                for key, value in chunk.metadata.items():
                    chunk_dict[key] = value

                chunks.append(chunk_dict)
    return chunks

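
# Note: embed_and_upload drops and recreates the target collection on every
# run, so a re-ingest fully replaces any previously uploaded points.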
def embed_and_upload(chunks, embedding_model_name, qdrant_host="localhost", qdrant_port=6333, qdrant_collection="rag"):
    print(f"Initializing model {embedding_model_name}")
    embedder = SentenceTransformer(embedding_model_name)

    print(f"Connecting to Qdrant ({qdrant_host}:{qdrant_port})")
    client = QdrantClient(host=qdrant_host, port=qdrant_port)

    if client.collection_exists(qdrant_collection):
        client.delete_collection(qdrant_collection)

    client.create_collection(
        collection_name=qdrant_collection,
        vectors_config=models.VectorParams(
            size=embedder.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )

    points = []
    total_chunks = len(chunks)
    for idx, chunk in enumerate(chunks, start=1):
        # Qdrant point IDs must be unsigned integers (or UUIDs), so derive a
        # stable numeric ID from the chunk ID
        id_hash = int(hashlib.sha256(chunk["id"].encode("utf-8")).hexdigest(), 16) % (10**16)
        vector = embedder.encode(chunk["text"]).tolist()

        points.append(models.PointStruct(
            id=id_hash,
            vector=vector,
            payload={
                "text": chunk["text"],
                # Chunk IDs look like "<name>.md_chunk<N>"; keep "<name>"
                "filename": chunk["id"].rsplit(".md_chunk", 1)[0],
                "url": chunk.get("url"),
                "version": chunk.get("version"),
                "author": chunk.get("author"),
                "date": chunk.get("date"),
            },
        ))
        print(f"[{idx}/{total_chunks}] Prepared chunk: {chunk['id']} -> ID: {id_hash}")

    for i in range(0, total_chunks, BATCH_SIZE):
        batch = points[i : i + BATCH_SIZE]
        client.upsert(collection_name=qdrant_collection, points=batch)
        print(f"Uploaded batch {(i // BATCH_SIZE) + 1} with {len(batch)} points, total uploaded: {min(i + BATCH_SIZE, total_chunks)}/{total_chunks}")

    print(f"Finished uploading all {total_chunks} chunks to collection '{qdrant_collection}'.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Data vectorization script for Qdrant")
    parser.add_argument("--input-dir", type=str, default=DEFAULT_INPUT_DIR, help="Directory with Markdown files to read")
    parser.add_argument("--chunk-size", type=int, default=DEFAULT_CHUNK_SIZE, help="Chunk size")
    parser.add_argument("--chunk-overlap", type=int, default=DEFAULT_CHUNK_OVERLAP, help="Chunk overlap size")
    parser.add_argument("--embedding-model", type=str, default=DEFAULT_EMBEDDING_MODEL, help="Embedding model")
    parser.add_argument("--qdrant-host", type=str, default=DEFAULT_QDRANT_HOST, help="Qdrant host address")
    parser.add_argument("--qdrant-port", type=int, default=DEFAULT_QDRANT_PORT, help="Qdrant port number")
    parser.add_argument("--qdrant-collection", type=str, default=DEFAULT_QDRANT_COLLECTION, help="Name of the collection to store documents in")
    args = parser.parse_args()

    documents = load_markdown_files(args.input_dir)
    print(f"Documents found: {len(documents)}")

    print("Preparing chunks...")
    chunks = chunk_text(documents, args.chunk_size, args.chunk_overlap)
    print(f"Chunks created: {len(chunks)} ({args.chunk_size}/{args.chunk_overlap})")

    embed_and_upload(
        chunks,
        args.embedding_model,
        args.qdrant_host,
        args.qdrant_port,
        args.qdrant_collection,
    )
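
# Example invocation (a sketch; the filename "vectorize.py" is an assumption,
# substitute whatever name this script is saved under):
#
#   python vectorize.py --input-dir data --chunk-size 500 --chunk-overlap 100 \
#       --qdrant-host localhost --qdrant-port 6333 --qdrant-collection rag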