WIP

compose.yml (42 changed lines)

@@ -9,25 +9,25 @@ services:
       - "${OLLAMA_PORT:-11434}:11434"
     restart: "no"

-  ai-qdrant:
-    container_name: ai-qdrant
-    image: qdrant/qdrant
-    env_file: .env
-    ports:
-      - "${QDRANT_PORT:-6333}:6333"
-    volumes:
-      - ./.data/qdrant/storage:/qdrant/storage
-    restart: "no"
-    profiles: ["rag"]
+  # ai-qdrant:
+  #   container_name: ai-qdrant
+  #   image: qdrant/qdrant
+  #   env_file: .env
+  #   ports:
+  #     - "${QDRANT_PORT:-6333}:6333"
+  #   volumes:
+  #     - ./.data/qdrant/storage:/qdrant/storage
+  #   restart: "no"
+  #   profiles: ["rag"]

-  ai-webui:
-    container_name: ai-webui
-    image: ghcr.io/open-webui/open-webui:main
-    env_file: .env
-    volumes:
-      - ./.data/webui:/app/backend/data
-    ports:
-      - "${OWEBUI_PORT:-9999}:8080"
-    extra_hosts:
-      - "host.docker.internal:host-gateway"
-    restart: "no"
+  # ai-webui:
+  #   container_name: ai-webui
+  #   image: ghcr.io/open-webui/open-webui:main
+  #   env_file: .env
+  #   volumes:
+  #     - ./.data/webui:/app/backend/data
+  #   ports:
+  #     - "${OWEBUI_PORT:-9999}:8080"
+  #   extra_hosts:
+  #     - "host.docker.internal:host-gateway"
+  #   restart: "no"
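
The hunk above comments out the ai-qdrant and ai-webui services, so only the Ollama container keeps a published port. If the rag/ scripts still need Qdrant, it now has to be started some other way; a minimal connectivity sketch from Python, assuming Qdrant is reachable on localhost:6333 as in the commented-out `${QDRANT_PORT:-6333}:6333` mapping:

```python
# Minimal Qdrant connectivity check, assuming the server still listens on
# localhost:6333 as in the commented-out port mapping above.
from qdrant_client import QdrantClient

qdrant = QdrantClient(host="localhost", port=6333)
print(qdrant.get_collections())  # prints the collections visible to the client
```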

@@ -8,7 +8,7 @@
 cd ..; ./up; cd -
 python3 -m venv .venv
 source .venv/bin/activate
-pip install beautifulsoup4 markdownify sentence-transformers qdrant-client langchain transformers
+pip install beautifulsoup4 markdownify sentence-transformers qdrant-client langchain transformers ollama
 pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
 ./download.sh 123456789 # <<== pageId of the Confluence page
 python3 convert.py
@@ -66,7 +66,7 @@ rag/
 ```bash
 python3 -m venv .venv
 source ./venv/bin/activate
-pip install beautifulsoup4 markdownify sentence-transformers qdrant-client langchain transformers
+pip install beautifulsoup4 markdownify sentence-transformers qdrant-client langchain transformers ollama
 pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
 ```

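Both install snippets now add the `ollama` Python package to the pip dependencies. A quick sanity check that the package can reach the local server; the URL is an assumption based on the compose default `${OLLAMA_PORT:-11434}`, and the `.models` / `.model` fields are how recent ollama-python releases expose the tag list:

```python
# Sanity check for the newly added "ollama" dependency.
# Host URL is an assumption (compose default OLLAMA_PORT 11434).
import ollama

client = ollama.Client(host="http://localhost:11434")
for m in client.list().models:
    print(m.model)  # e.g. "openchat:7b" once the chat model has been pulled
```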

rag/rag.py (156 changed lines)

@@ -1,10 +1,9 @@
 import os
-import requests
-import json
 import time
 import sys
 from qdrant_client import QdrantClient
 from sentence_transformers import SentenceTransformer, CrossEncoder
+import ollama

 DEFAULT_CHAT_MODEL = "openchat:7b"
 DEFAULT_EMBED_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
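
The constants keep the same chat and embedding models; only the HTTP plumbing to Ollama changes in this commit. For reference, a small sketch of the embedding step that `search_qdrant` performs with `DEFAULT_EMBED_MODEL` (the query text here is arbitrary):

```python
# Standalone sketch of the query-embedding step used by RagSystem.search_qdrant.
from sentence_transformers import SentenceTransformer

emb_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
query_vec = emb_model.encode("How do I configure the service?", show_progress_bar=False).tolist()
print(len(query_vec))  # 384-dimensional vector, the size stored and queried in Qdrant
```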

@@ -38,33 +37,26 @@ class RagSystem:
         self.qdrant_port = qdrant_port
         self.chat_model = chat_model
         self.emb_model = SentenceTransformer(embed_model)
-        self.qdrant = QdrantClient(host=args.qdrant_host, port=args.qdrant_port)
+        self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
         self.use_rank = use_rank
         if self.use_rank:
             self.rank_model = CrossEncoder(rank_model)
         self.conversation_history = []
+        self.ollama = ollama.Client(host=ollama_url)

     def check_chat_model(self):
-        response = requests.get(f"{self.ollama_url}/api/tags")
-        if response.status_code != 200:
-            return False
-        for model in response.json().get("models", []):
-            if model["name"] == self.chat_model:
-                return True
-        return False
+        models = self.ollama.list().models
+        return any(model.model == self.chat_model for model in models)

     def install_chat_model(self, model: str = DEFAULT_CHAT_MODEL):
         try:
-            response = requests.post(f"{self.ollama_url}/api/pull", json={"model": model})
-            if response.status_code == 200:
-                print(f"Model {self.chat_model} installed successfully")
-            else:
-                print(f"Model install error: {response.text}")
+            result = self.ollama.pull(model)
+            print(f"Model {model} installed successfully")
         except Exception as e:
-            print(f"Model check error: {str(e)}")
+            print(f"Model install error: {str(e)}")

     def load_chat_model(self):
-        requests.post(f"{self.ollama_url}/api/generate", json={"model": self.chat_model}, timeout=600)
+        self.ollama.generate(model=self.chat_model, keep_alive=True)

     def search_qdrant(self, query: str, doc_count: int = DEFAULT_TOP_K, collection_name = DEFAULT_QDRANT_COLLECTION):
         query_vec = self.emb_model.encode(query, show_progress_bar=False).tolist()
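
The methods above swap raw `requests` calls to `/api/tags`, `/api/pull` and `/api/generate` for the `ollama` client. A hedged, standalone illustration of that client surface (host URL and model tag are assumptions; note that the documented `keep_alive` values are durations such as `"10m"` or `-1` rather than booleans):

```python
# Sketch of the ollama client calls behind check_chat_model, install_chat_model
# and load_chat_model. Host and model tag are assumptions.
import ollama

client = ollama.Client(host="http://localhost:11434")
model = "openchat:7b"

# Equivalent of GET /api/tags: list locally available models.
installed = any(m.model == model for m in client.list().models)

# Equivalent of POST /api/pull: download the model if it is missing.
if not installed:
    for progress in client.pull(model, stream=True):
        print(progress.status)

# Generating with an empty prompt preloads the model; keep_alive takes a duration.
client.generate(model=model, keep_alive="10m")
```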

@@ -100,85 +92,71 @@ class RagSystem:
         return ranked_docs[:top_n]

     def generate_answer(self, sys_prompt: str, user_prompt: str):
-        url = f"{self.ollama_url}/api/generate"
-        body = {
-            "model": self.chat_model,
-            "system": sys_prompt,
-            "prompt": user_prompt,
-            "stream": False,
-            "options": {
-                "temperature": 0.5,
-                # "top_p": 0.2,
-            },
-        }
-
-        response = requests.post(url, json=body, timeout=900)
-        if response.status_code != 200:
-            return f"Answer generation error: {response.status_code} {response.text}"
-        self.response = response.json()
-        return self.response["response"]
+        try:
+            response = self.ollama.generate(
+                model=self.chat_model,
+                prompt=sys_prompt + "\n" + user_prompt,
+                options={
+                    "temperature": 0.5,
+                },
+                stream=False,
+            )
+            self.last_response = response
+            return response["response"]
+        except Exception as e:
+            print(f"Answer generation error: {str(e)}")
+            return str(e)

     def generate_answer_stream(self, sys_prompt: str, user_prompt: str):
-        url = f"{self.ollama_url}/api/generate"
-        body = {
-            "model": self.chat_model,
-            "system": sys_prompt,
-            "prompt": user_prompt,
-            "stream": True,
-            "options": {
-                "temperature": 0.5,
-                # "top_p": 0.2,
-            },
-        }
-        resp = requests.post(url, json=body, stream=True, timeout=900)
-        if resp.status_code != 200:
-            raise RuntimeError(f"Answer generation error: {resp.status_code} {resp.text}")
-
-        answer = ""
-        self.response = None
-        for chunk in resp.iter_lines():
-            if chunk:
-                try:
-                    decoded_chunk = chunk.decode('utf-8')
-                    data = json.loads(decoded_chunk)
-                    if "response" in data:
-                        yield data["response"]
-                        answer += data["response"]
-                        if "done" in data and data["done"] is True:
-                            self.response = data
-                            break
-                    elif "error" in data:
-                        answer += f" | Response streaming error: {data['error']}"
-                        break
-                except json.JSONDecodeError as e:
-                    answer += f" | Chunk decode error: {chunk.decode('utf-8')} - {e}"
-                except Exception as e:
-                    answer += f" | Chunk processing error: {e}"
+        try:
+            generator = self.ollama.generate(
+                model=self.chat_model,
+                prompt=sys_prompt + "\n" + user_prompt,
+                options={
+                    "temperature": 0.5,
+                },
+                stream=True,
+            )
+            answer = ""
+            for response in generator:
+                if response["response"]:
+                    yield response["response"]
+                    answer += response["response"]
+                if response["done"]:
+                    self.last_response = response
+                    break
+            return answer
+        except Exception as e:
+            print(f"Streaming error: {str(e)}")
+            return str(e)

     def get_prompt_eval_count(self):
-        if not self.response:
+        if not hasattr(self, "last_response"):
             return 0
-        return self.response["prompt_eval_count"]
+        return self.last_response.prompt_eval_count or 0

     def get_prompt_eval_duration(self):
-        if not self.response:
+        if not hasattr(self, "last_response"):
             return 0
-        return self.response["prompt_eval_duration"] / (10 ** 9)
+        return self.last_response.prompt_eval_duration / (10 ** 9)

     def get_eval_count(self):
-        if not self.response:
+        if not hasattr(self, "last_response"):
             return 0
-        return self.response["eval_count"]
+        return self.last_response.eval_count or 0

     def get_eval_duration(self):
-        if not self.response:
+        if not hasattr(self, "last_response"):
             return 0
-        return self.response["eval_duration"] / (10 ** 9)
+        return self.last_response.eval_duration / (10 ** 9)

     def get_total_duration(self):
-        if not self.response:
+        if not hasattr(self, "last_response"):
             return 0
-        return self.response["total_duration"] / (10 ** 9)
+        return self.last_response.total_duration / (10 ** 9)

     def get_tps(self):
         eval_count = self.get_eval_count()
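
Both the streaming path and the timing getters depend on the chunk shape the client yields: every chunk carries a `response` text fragment, and the final chunk has `done` set plus the `eval_*` counters that `get_tps` reads. A hedged end-to-end sketch (host and model tag are assumptions):

```python
# Streamed generation plus the final-chunk metrics used by get_eval_count,
# get_eval_duration and get_tps. Host and model tag are assumptions.
import ollama

client = ollama.Client(host="http://localhost:11434")
last = None
for chunk in client.generate(
    model="openchat:7b",
    prompt="Answer in one sentence: what is RAG?",
    options={"temperature": 0.5},
    stream=True,
):
    print(chunk["response"], end="", flush=True)
    if chunk["done"]:
        last = chunk

if last is not None:
    tps = last["eval_count"] / (last["eval_duration"] / 10**9)
    print(f"\n{tps:.1f} tokens/s")
```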

@@ -360,19 +338,23 @@ Context:

     def process_query(self, sys_prompt: str, user_prompt: str, streaming: bool = DEFAULT_STREAM):
         answer = ""
-        # try:
         if streaming:
             self.print_v(text="\nGenerating a streamed answer (^C to stop)...\n")
             print(f"<<< ", end='', flush=True)
-            for token in self.rag.generate_answer_stream(sys_prompt, user_prompt):
-                answer += token
-                print(token, end='', flush=True)
+            try:
+                for token in self.rag.generate_answer_stream(sys_prompt, user_prompt):
+                    answer += token
+                    print(token, end='', flush=True)
+            except KeyboardInterrupt:
+                print("\n*** Answer generation interrupted")
+                return answer
         else:
             self.print_v(text="\nGenerating an answer (^C to stop)...\n")
-            answer = self.rag.generate_answer(sys_prompt, user_prompt)
-            print(f"<<< {answer}\n")
-        # except RuntimeError as e:
-        #     answer = str(e)
+            try:
+                answer = self.rag.generate_answer(sys_prompt, user_prompt)
+            except KeyboardInterrupt:
+                print("\n*** Answer generation interrupted")
+                return ""

         print(f"\n===================================================")
         return answer
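
`process_query` now wraps both generation paths in `try/except KeyboardInterrupt`, so ^C stops generation without killing the CLI and the tokens received so far are kept. The same pattern in isolation (model tag and prompt are placeholders):

```python
# Minimal sketch of the interrupt handling process_query adopts:
# ^C stops the stream, the partial answer survives.
import ollama

client = ollama.Client(host="http://localhost:11434")
answer = ""
try:
    for chunk in client.generate(model="openchat:7b", prompt="Tell me a long story.", stream=True):
        answer += chunk["response"]
        print(chunk["response"], end="", flush=True)
except KeyboardInterrupt:
    print("\n*** Answer generation interrupted")
print(f"\nCollected {len(answer)} characters")
```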