Автозагрузка модели ollama при отсутствии
This commit is contained in:
33
rag/rag.py
33
rag/rag.py
@@ -8,9 +8,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||
|
||||
DEFAULT_CHAT_MODEL = "openchat:7b"
|
||||
DEFAULT_EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
|
||||
# DEFAULT_RANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
|
||||
DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
# DEFAULT_RANK_MODEL = "cross-encoder/ms-marco-TinyBERT-L-2-v2"
|
||||
DEFAULT_MD_FOLDER = "data"
|
||||
DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
||||
DEFAULT_QDRANT_HOST = "localhost"
|
||||
@@ -45,12 +43,28 @@ class RagSystem:
|
||||
if self.use_rank:
|
||||
self.rank_model = CrossEncoder(rank_model)
|
||||
self.conversation_history = []
|
||||
self.load_chat_model()
|
||||
|
||||
def check_chat_model(self):
|
||||
response = requests.get(f"{self.ollama_url}/api/tags")
|
||||
if response.status_code != 200:
|
||||
return False
|
||||
for model in response.json().get("models", []):
|
||||
if model["name"] == self.chat_model:
|
||||
return True
|
||||
return False
|
||||
|
||||
def install_chat_model(self, model: str = DEFAULT_CHAT_MODEL):
|
||||
try:
|
||||
response = requests.post(f"{self.ollama_url}/api/pull", json={"model": model})
|
||||
if response.status_code == 200:
|
||||
print(f"Модель {self.chat_model} установлена успешно")
|
||||
else:
|
||||
print(f"Ошибка установки модели: {response.text}")
|
||||
except Exception as e:
|
||||
print(f"Ошибка проверки модели: {str(e)}")
|
||||
|
||||
def load_chat_model(self):
|
||||
url = f"{self.ollama_url}/api/generate"
|
||||
body = {"model": self.chat_model}
|
||||
requests.post(url, json=body, timeout=600)
|
||||
requests.post(f"{self.ollama_url}/api/generate", json={"model": self.chat_model}, timeout=600)
|
||||
|
||||
def search_qdrant(self, query: str, doc_count: int = DEFAULT_TOP_K, collection_name = DEFAULT_QDRANT_COLLECTION):
|
||||
query_vec = self.emb_model.encode(query, show_progress_bar=False).tolist()
|
||||
@@ -219,6 +233,10 @@ class App:
|
||||
use_rank = self.args.use_rank,
|
||||
chat_model = self.args.chat_model
|
||||
)
|
||||
if not self.rag.check_chat_model():
|
||||
print(f"Установка модели {self.args.chat_model} ...")
|
||||
self.rag.install_chat_model(self.args.chat_model)
|
||||
self.rag.load_chat_model()
|
||||
self.print_v(text=f"Модели загружены. Если ответ плохой, переформулируйте запрос, укажите --chat-model или улучшите исходные данные RAG")
|
||||
|
||||
def init_query(self):
|
||||
@@ -234,8 +252,9 @@ class App:
|
||||
self.query = input(">>> ").strip()
|
||||
|
||||
def process_help(self):
|
||||
print("<<< Команды итерактивного режима:")
|
||||
print("<<< Команды интерактивного режима:")
|
||||
print("save -- сохранить диалог в файл")
|
||||
print("stats -- статистика последнего ответа")
|
||||
print("exit -- выход\n")
|
||||
self.query = None
|
||||
self.args.query = None
|
||||
|
||||
Reference in New Issue
Block a user