upload
This commit is contained in:
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from redis import asyncio as aioredis
|
||||
except ModuleNotFoundError:
|
||||
aioredis = None
|
||||
|
||||
from app.utils.ollama_client import chat_json, summarize_dialog_text
|
||||
|
||||
|
||||
class DialogMemoryService:
|
||||
def __init__(self) -> None:
|
||||
redis_host = os.getenv("REDIS_HOST", "localhost")
|
||||
redis_port = os.getenv("REDIS_PORT", "6379")
|
||||
self.redis_url = os.getenv("REDIS_URL", f"redis://{redis_host}:{redis_port}")
|
||||
self.ttl_seconds = int(os.getenv("DIALOG_TTL_SECONDS", "86400"))
|
||||
|
||||
async def get_context(self, dialog_id: str) -> tuple[list[dict[str, Any]], str | None]:
|
||||
redis = await self._get_redis()
|
||||
if redis is None:
|
||||
return [], None
|
||||
|
||||
messages_raw = await redis.get(self._messages_key(dialog_id))
|
||||
summary = await redis.get(self._summary_key(dialog_id))
|
||||
messages = self._decode_messages(messages_raw)
|
||||
return messages, summary
|
||||
|
||||
async def append_and_summarize(self, dialog_id: str, role: str, content: str) -> str | None:
|
||||
redis = await self._get_redis()
|
||||
if redis is None:
|
||||
return None
|
||||
|
||||
messages_key = self._messages_key(dialog_id)
|
||||
summary_key = self._summary_key(dialog_id)
|
||||
|
||||
current_messages = self._decode_messages(await redis.get(messages_key))
|
||||
current_messages.append({"role": role, "content": content})
|
||||
await redis.set(messages_key, json.dumps(current_messages, ensure_ascii=False), ex=self.ttl_seconds)
|
||||
|
||||
try:
|
||||
summary = await summarize_dialog_text(current_messages)
|
||||
except Exception:
|
||||
summary = None
|
||||
if summary is None:
|
||||
summary = self._fallback_summary(current_messages)
|
||||
await redis.set(summary_key, summary, ex=self.ttl_seconds)
|
||||
return summary
|
||||
|
||||
async def reset(self, dialog_id: str) -> None:
|
||||
redis = await self._get_redis()
|
||||
if redis is None:
|
||||
return
|
||||
await redis.delete(self._messages_key(dialog_id), self._summary_key(dialog_id))
|
||||
|
||||
async def _get_redis(self):
|
||||
if aioredis is None:
|
||||
return None
|
||||
try:
|
||||
redis = aioredis.from_url(self.redis_url, encoding="utf8", decode_responses=True)
|
||||
await redis.ping()
|
||||
return redis
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _messages_key(self, dialog_id: str) -> str:
|
||||
return f"dialog:{dialog_id}:messages"
|
||||
|
||||
def _summary_key(self, dialog_id: str) -> str:
|
||||
return f"dialog:{dialog_id}:summary"
|
||||
|
||||
def _decode_messages(self, payload: str | None) -> list[dict[str, Any]]:
|
||||
if not payload:
|
||||
return []
|
||||
try:
|
||||
parsed = json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
if not isinstance(parsed, list):
|
||||
return []
|
||||
return [item for item in parsed if isinstance(item, dict)]
|
||||
|
||||
def _fallback_summary(self, messages: list[dict[str, Any]]) -> str:
|
||||
chunks = [str(item.get("content", "")) for item in messages[-4:]]
|
||||
return "\n".join(chunk for chunk in chunks if chunk)
|
||||
Reference in New Issue
Block a user