diff --git a/rag/rag.py b/rag/rag.py index 09f35f9..b8510eb 100644 --- a/rag/rag.py +++ b/rag/rag.py @@ -8,9 +8,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder DEFAULT_CHAT_MODEL = "openchat:7b" DEFAULT_EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" -# DEFAULT_RANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1" DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2" -# DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-TinyBERT-L-2-v2" DEFAULT_MD_FOLDER = "data" DEFAULT_OLLAMA_URL = "http://localhost:11434" DEFAULT_QDRANT_HOST = "localhost" @@ -45,12 +43,28 @@ class RagSystem: if self.use_rank: self.rank_model = CrossEncoder(rank_model) self.conversation_history = [] - self.load_chat_model() + + def check_chat_model(self): + response = requests.get(f"{self.ollama_url}/api/tags") + if response.status_code != 200: + return False + for model in response.json().get("models", []): + if model["name"] == self.chat_model: + return True + return False + + def install_chat_model(self, model: str = DEFAULT_CHAT_MODEL): + try: + response = requests.post(f"{self.ollama_url}/api/pull", json={"model": model}) + if response.status_code == 200: + print(f"Модель {self.chat_model} установлена успешно") + else: + print(f"Ошибка установки модели: {response.text}") + except Exception as e: + print(f"Ошибка проверки модели: {str(e)}") def load_chat_model(self): - url = f"{self.ollama_url}/api/generate" - body = {"model": self.chat_model} - requests.post(url, json=body, timeout=600) + requests.post(f"{self.ollama_url}/api/generate", json={"model": self.chat_model}, timeout=600) def search_qdrant(self, query: str, doc_count: int = DEFAULT_TOP_K, collection_name = DEFAULT_QDRANT_COLLECTION): query_vec = self.emb_model.encode(query, show_progress_bar=False).tolist() @@ -219,6 +233,10 @@ class App: use_rank = self.args.use_rank, chat_model = self.args.chat_model ) + if not self.rag.check_chat_model(): + print(f"Установка модели {self.args.chat_model} ...") + self.rag.install_chat_model(self.args.chat_model) + self.rag.load_chat_model() self.print_v(text=f"Модели загружены. Если ответ плохой, переформулируйте запрос, укажите --chat-model или улучшите исходные данные RAG") def init_query(self): @@ -234,8 +252,9 @@ class App: self.query = input(">>> ").strip() def process_help(self): - print("<<< Команды итерактивного режима:") + print("<<< Команды интерактивного режима:") print("save -- сохранить диалог в файл") + print("stats -- статистика последнего ответа") print("exit -- выход\n") self.query = None self.args.query = None