89 lines
3.1 KiB
Python
89 lines
3.1 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
from typing import Any
|
|
|
|
try:
|
|
from redis import asyncio as aioredis
|
|
except ModuleNotFoundError:
|
|
aioredis = None
|
|
|
|
from app.utils.ollama_client import chat_json, summarize_dialog_text
|
|
|
|
|
|
class DialogMemoryService:
    """Best-effort Redis-backed memory for dialogs.

    For each dialog two keys are kept, both expiring after ``ttl_seconds``:

    * ``dialog:<id>:messages`` - JSON array of ``{"role", "content"}`` objects;
    * ``dialog:<id>:summary``  - plain-text summary of the dialog.

    Every public method degrades silently when Redis cannot be used (client
    library not installed, or the server does not answer a ping): reads
    return empty results and writes become no-ops.
    """

    def __init__(self) -> None:
        """Read connection settings from the environment.

        No connection is opened here; the client is created on first use.
        """
        redis_host = os.getenv("REDIS_HOST", "localhost")
        redis_port = os.getenv("REDIS_PORT", "6379")
        # REDIS_URL, when set, wins over the REDIS_HOST / REDIS_PORT pair.
        self.redis_url = os.getenv("REDIS_URL", f"redis://{redis_host}:{redis_port}")
        self.ttl_seconds = int(os.getenv("DIALOG_TTL_SECONDS", "86400"))
        # Lazily created client, reused across calls (see _get_redis).
        self._redis: Any = None

    async def get_context(self, dialog_id: str) -> tuple[list[dict[str, Any]], str | None]:
        """Return the stored messages and summary for *dialog_id*.

        Returns:
            ``(messages, summary)``. ``messages`` is ``[]`` and ``summary`` is
            ``None`` when nothing is stored or Redis is unavailable.
        """
        redis = await self._get_redis()
        if redis is None:
            return [], None

        messages_raw = await redis.get(self._messages_key(dialog_id))
        summary = await redis.get(self._summary_key(dialog_id))
        return self._decode_messages(messages_raw), summary

    async def append_and_summarize(self, dialog_id: str, role: str, content: str) -> str | None:
        """Append one message to the dialog and refresh its summary.

        Returns:
            The summary that was stored, or ``None`` when Redis is unavailable
            (in which case nothing is written).
        """
        redis = await self._get_redis()
        if redis is None:
            return None

        messages_key = self._messages_key(dialog_id)
        summary_key = self._summary_key(dialog_id)

        current_messages = self._decode_messages(await redis.get(messages_key))
        current_messages.append({"role": role, "content": content})
        await redis.set(
            messages_key,
            json.dumps(current_messages, ensure_ascii=False),
            ex=self.ttl_seconds,
        )

        # Summarization is best-effort: any failure falls back to a local
        # summary so the message that was just stored is never lost.
        try:
            summary = await summarize_dialog_text(current_messages)
        except Exception:
            summary = None
        if summary is None:
            summary = self._fallback_summary(current_messages)

        await redis.set(summary_key, summary, ex=self.ttl_seconds)
        return summary

    async def reset(self, dialog_id: str) -> None:
        """Delete the stored messages and summary of *dialog_id*, if any."""
        redis = await self._get_redis()
        if redis is None:
            return
        await redis.delete(self._messages_key(dialog_id), self._summary_key(dialog_id))

    async def _get_redis(self):
        """Return a Redis client that answered a ping, or ``None``.

        The client is created once and cached on the instance instead of
        being rebuilt on every operation, so repeated calls no longer leave
        a trail of never-closed clients behind. It is still pinged on each
        call so that an unreachable server makes callers degrade to "no
        memory" rather than raise.
        """
        client = self._redis
        if client is None:
            if aioredis is None:
                # redis package is not installed.
                return None
            try:
                client = aioredis.from_url(
                    self.redis_url, encoding="utf8", decode_responses=True
                )
            except Exception:
                return None
            # Publish before the first await so concurrent callers share it.
            self._redis = client

        try:
            await client.ping()
        except Exception:
            # Forget the client (unless someone already replaced it) so the
            # next call starts from a fresh one.
            if self._redis is client:
                self._redis = None
            return None
        return client

    def _messages_key(self, dialog_id: str) -> str:
        """Return the Redis key holding the JSON message list of a dialog."""
        return f"dialog:{dialog_id}:messages"

    def _summary_key(self, dialog_id: str) -> str:
        """Return the Redis key holding the summary text of a dialog."""
        return f"dialog:{dialog_id}:summary"

    def _decode_messages(self, payload: str | None) -> list[dict[str, Any]]:
        """Parse a stored message list, tolerating missing or corrupt data.

        Returns ``[]`` for an empty payload, invalid JSON, or JSON that is not
        an array; array elements that are not objects are dropped.
        """
        if not payload:
            return []
        try:
            parsed = json.loads(payload)
        except json.JSONDecodeError:
            return []
        if not isinstance(parsed, list):
            return []
        return [item for item in parsed if isinstance(item, dict)]

    def _fallback_summary(self, messages: list[dict[str, Any]]) -> str:
        """Build a summary from the last four messages, one content per line.

        Messages whose content is missing, ``None`` or empty are skipped, so
        a null content never shows up as the literal text ``"None"``.
        """
        chunks: list[str] = []
        for item in messages[-4:]:
            content = item.get("content")
            if content is None:
                continue
            text = str(content)
            if text:
                chunks.append(text)
        return "\n".join(chunks)