diff --git a/rag/README.md b/rag/README.md index 1720a9a..9c76a82 100644 --- a/rag/README.md +++ b/rag/README.md @@ -32,6 +32,8 @@ RAG (Retrieval-Augmented Generation) — это архитектура, кото - Дает возможность проверить источник информации в сгенерированном ответе - Может работать с проприетарными или конфиденциальными данными без дообучения модели +Прочесть подробнее можно здесь: https://habr.com/ru/articles/904032/ + ## Структура проекта ``` @@ -233,6 +235,7 @@ python3 rag.py --help **Цель:** изучить современные технологии. **Задачи:** + 1. облегчить поиск информации о проекте среди почти 2000 тысяч документов в корпоративной Confluence, относящихся к нему; 2. обеспечить минимум телодвижений для развёртывания RAG с нуля внутри команды. diff --git a/rag/vectorize.py b/rag/vectorize.py index 5102a0f..98975fe 100644 --- a/rag/vectorize.py +++ b/rag/vectorize.py @@ -5,6 +5,14 @@ from qdrant_client import QdrantClient from qdrant_client.http import models from langchain.text_splitter import RecursiveCharacterTextSplitter +DEFAULT_INPUT_DIR="data" +DEFAULT_CHUNK_SIZE=500 +DEFAULT_CHUNK_OVERLAP=100 +DEFAULT_EMBEDDING_MODEL="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" +DEFAULT_QDRANT_HOST="localhost" +DEFAULT_QDRANT_PORT=6333 +DEFAULT_QDRANT_COLLECTION="rag" +BATCH_SIZE = 100 def load_markdown_files(input_dir): documents = [] @@ -20,15 +28,12 @@ def load_markdown_files(input_dir): date = None if lines: - # Проверка первой строки на URL - if lines[0].strip().startswith("@@") and lines[0].strip().endswith("@@") and len(lines[0].strip()) > 4: - url = lines[0].strip()[2:-2].strip() - lines = lines[1:] - - # Проверка оставшихся строк на метаданные i = 0 while i < len(lines): line = lines[i].strip() + if line.strip().startswith("@@") and line.strip().endswith("@@") and len(line.strip()) > 4: + url = line.strip()[2:-2].strip() + lines = lines[1:] if line.startswith("^^") and line.endswith("^^") and len(line) > 4: version = line[2:-2].strip() lines.pop(i) @@ -42,11 +47,14 @@ def load_markdown_files(input_dir): i += 1 doc_metadata = {"id": filename, "text": "\n".join(lines)} - if url: doc_metadata["url"] = url - if version: doc_metadata["version"] = version - if author: doc_metadata["author"] = author - if date: doc_metadata["date"] = date - + if url: + doc_metadata["url"] = url + if version: + doc_metadata["version"] = version + if author: + doc_metadata["author"] = author + if date: + doc_metadata["date"] = date documents.append(doc_metadata) return documents @@ -109,23 +117,22 @@ def embed_and_upload(chunks, embedding_model_name, qdrant_host="localhost", qdra )) print(f"[{idx}/{total_chunks}] Подготовлен чанк: {chunk['id']} -> ID: {id_hash}") - batch_size = 100 - for i in range(0, total_chunks, batch_size): - batch = points[i : i + batch_size] + for i in range(0, total_chunks, BATCH_SIZE): + batch = points[i : i + BATCH_SIZE] client.upsert(collection_name=qdrant_collection, points=batch) - print(f"Записан батч {(i // batch_size) + 1}, содержащий {len(batch)} точек, всего записано: {min(i + batch_size, total_chunks)}/{total_chunks}") + print(f"Записан батч {(i // BATCH_SIZE) + 1}, содержащий {len(batch)} точек, всего записано: {min(i + BATCH_SIZE, total_chunks)}/{total_chunks}") print(f"Завершена запись всех {total_chunks} чанков в коллекцию '{qdrant_collection}'.") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Скрипт векторизаци данных для Qdrant") - parser.add_argument("--input-dir", type=str, default="data", help="Директория с Markdown-файлами для чтения") - parser.add_argument("--chunk-size", type=int, default=500, help="Размер чанка") - parser.add_argument("--chunk-overlap", type=int, default=100, help="Размер перекрытия") - parser.add_argument("--embedding-model", type=str, default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", help="Модель эмбеддинга") - parser.add_argument("--qdrant-host", type=str, default="localhost", help="Адрес хоста Qdrant") - parser.add_argument("--qdrant-port", type=int, default=6333, help="Номер порта Qdrant") - parser.add_argument("--qdrant-collection", type=str, default="rag", help="Название коллекции для сохранения документов") + parser.add_argument("--input-dir", type=str, default=DEFAULT_INPUT_DIR, help="Директория с Markdown-файлами для чтения") + parser.add_argument("--chunk-size", type=int, default=DEFAULT_CHUNK_SIZE, help="Размер чанка") + parser.add_argument("--chunk-overlap", type=int, default=DEFAULT_CHUNK_OVERLAP, help="Размер перекрытия") + parser.add_argument("--embedding-model", type=str, default=DEFAULT_EMBEDDING_MODEL, help="Модель эмбеддинга") + parser.add_argument("--qdrant-host", type=str, default=DEFAULT_QDRANT_HOST, help="Адрес хоста Qdrant") + parser.add_argument("--qdrant-port", type=int, default=DEFAULT_QDRANT_PORT, help="Номер порта Qdrant") + parser.add_argument("--qdrant-collection", type=str, default=DEFAULT_QDRANT_COLLECTION, help="Название коллекции для сохранения документов") args = parser.parse_args() documents = load_markdown_files(args.input_dir) @@ -135,4 +142,10 @@ if __name__ == "__main__": chunks = chunk_text(documents, args.chunk_size, args.chunk_overlap) print(f"Создано чанков: {len(chunks)} ({args.chunk_size}/{args.chunk_overlap})") - embed_and_upload(chunks, args.embedding_model, args.qdrant_host, args.qdrant_port, args.qdrant_collection) + embed_and_upload( + chunks, + args.embedding_model, + args.qdrant_host, + args.qdrant_port, + args.qdrant_collection + )