from __future__ import annotations

import re
from typing import Any, NamedTuple
from uuid import UUID

from sqlalchemy import and_, or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models import Action, Capability
from app.models.capability import CapabilityType


class SelectedCapability(NamedTuple):
    """A capability candidate paired with its lexical relevance score."""

    capability: Capability
    score: float
    confidence_tier: str = "high"


class SemanticSelectionService:
    """Selects user capabilities relevant to a free-text query.

    Ranking is purely lexical: query tokens are expanded through suffix
    normalization and a bilingual (RU/EN) alias table, then matched against
    tokens extracted from each capability's name, description, LLM-payload
    context and recipe inputs. No embeddings or external services involved.
    """

    # Score thresholds used to label the whole result set (see
    # _resolve_confidence_tier).
    HIGH_CONFIDENCE_THRESHOLD = 0.45
    MEDIUM_CONFIDENCE_THRESHOLD = 0.30
    # Minimum gap between the top two scores for the ranking to be trusted.
    LOW_MARGIN_THRESHOLD = 0.05

    # Domain vocabulary: overlap here earns a CRM-specific bonus.
    CRM_TOKENS = {
        "crm",
        "segment",
        "segments",
        "audience",
        "campaign",
        "campaigns",
        "mailing",
        "newsletter",
        "lead",
        "leads",
        "retention",
        "cohort",
        "churn",
        "conversion",
        "promo",
        "offer",
        "offers",
        "email",
        "emails",
        "push",
        "sale",
        "sales",
        "сегмент",
        "сегменты",
        "аудитория",
        "кампания",
        "кампании",
        "рассылка",
        "лид",
        "лиды",
        "ретеншн",
        "конверсия",
        "оффер",
        "офферы",
        "пуш",
        "продажи",
        "клиент",
        "клиенты",
    }

    # Low-information verbs/nouns; capabilities made mostly of these are
    # penalized, and matches on them never count as "entity" overlap.
    GENERIC_TOKENS = {
        "get",
        "list",
        "create",
        "update",
        "delete",
        "call",
        "data",
        "info",
        "items",
        "resource",
        "resources",
        "service",
        "api",
        "handle",
        "handler",
        "manage",
        "process",
        "method",
        "action",
        "fetch",
        "general",
        "common",
        "получить",
        "список",
        "создать",
        "обновить",
        "удалить",
        "данные",
        "инфо",
        "ресурс",
        "сервис",
        "метод",
        "действие",
        "общее",
    }

    # Tokens removed entirely during tokenization (RU/EN function words plus
    # pipeline-building boilerplate that appears in almost every query).
    _STOPWORDS = {
        "and",
        "the",
        "for",
        "with",
        "from",
        "into",
        "that",
        "this",
        "что",
        "это",
        "как",
        "для",
        "или",
        "при",
        "про",
        "надо",
        "нужно",
        "хочу",
        "build",
        "pipeline",
        "workflow",
        "scenario",
        "automation",
        "пайплайн",
        "сценарий",
        "автоматизация",
        "построй",
        "собери",
    }

    # Prefix-keyed alias table bridging RU<->EN synonyms. A token matches an
    # entry when it equals the key or starts with it (see _expand_tokens).
    _ALIAS_EXPANSIONS = {
        "польз": {"user", "users", "client", "clients", "пользователь", "пользователи"},
        "клиент": {"client", "clients", "user", "users", "клиент", "клиенты"},
        "юзер": {"user", "users", "пользователь", "пользователи"},
        "получ": {"get", "fetch", "list", "retrieve", "получить", "список"},
        "спис": {"list", "get", "fetch", "список", "получить"},
        "созд": {"create", "add", "post", "создать"},
        "обнов": {"update", "patch", "put", "обновить"},
        "удал": {"delete", "remove", "del", "удалить"},
        "рассыл": {"mailing", "newsletter", "broadcast", "email", "рассылка"},
        "сегмент": {"segment", "segments", "сегмент", "сегменты"},
        "лид": {"lead", "leads", "лид", "лиды"},
        "отчет": {"report", "analytics", "отчет", "отчёт"},
        "отчёт": {"report", "analytics", "отчет", "отчёт"},
        "user": {"пользователь", "пользователи", "user", "users"},
        "users": {"пользователь", "пользователи", "user", "users"},
        "get": {"получить", "список", "get", "fetch", "list"},
        "fetch": {"получить", "список", "get", "fetch", "list"},
        "list": {"получить", "список", "get", "fetch", "list"},
    }

    # NOTE: "ё"/"Ё" sit OUTSIDE the Unicode а-я/А-Я ranges, so they must be
    # listed explicitly; without them words like "отчёт" were split at the
    # "ё" and could never reach the "отчёт" alias entries above.
    _TOKEN_PATTERN = re.compile(r"[a-zA-Zа-яА-ЯёЁ0-9]+")

    async def select_capabilities(
        self,
        session: AsyncSession,
        user_query: str,
        owner_user_id: UUID | None = None,
        limit: int = 10,
    ) -> list[SelectedCapability]:
        """Return up to ``limit`` executable capabilities ranked against the query.

        All returned items share a single confidence tier derived from the top
        score and its margin over the runner-up. If lexical matching finds
        nothing, falls back to the first ``limit`` executable candidates with a
        token "low" score so downstream generation can still proceed.
        """
        query_tokens = self._tokenize(user_query)
        if not query_tokens:
            return []
        query = select(Capability).order_by(Capability.created_at.asc())
        if owner_user_id is not None:
            # User-scoped with legacy compatibility:
            # some old capabilities may have user_id=NULL while their source action has owner.
            query = query.outerjoin(Action, Capability.action_id == Action.id).where(
                or_(
                    Capability.user_id == owner_user_id,
                    and_(
                        Capability.user_id.is_(None),
                        Action.user_id == owner_user_id,
                    ),
                )
            )
        # Hard cap the candidate pool; scoring is O(candidates x tokens).
        query = query.limit(200)
        result = await session.execute(query)
        capabilities = list(result.scalars().all())
        executable_capabilities = [
            capability
            for capability in capabilities
            if self._is_executable_capability(capability)
        ]
        candidates = executable_capabilities
        if not candidates:
            return []
        query_tokens_expanded = self._expand_tokens(query_tokens)
        ranked: list[SelectedCapability] = []
        for capability in candidates:
            score = self._score_capability(query_tokens, query_tokens_expanded, capability)
            if score <= 0:
                continue
            ranked.append(SelectedCapability(capability=capability, score=score))
        ranked.sort(key=lambda item: item.score, reverse=True)
        if not ranked:
            if candidates:
                # Fallback: keep generation moving even when lexical matching is weak.
                return [
                    SelectedCapability(
                        capability=capability,
                        score=0.01,
                        confidence_tier="low",
                    )
                    for capability in candidates[:limit]
                ]
            return []
        top_score = ranked[0].score
        second_score = ranked[1].score if len(ranked) > 1 else 0.0
        margin = top_score - second_score
        confidence_tier = self._resolve_confidence_tier(top_score, margin)
        return [
            SelectedCapability(
                capability=item.capability,
                score=item.score,
                confidence_tier=confidence_tier,
            )
            for item in ranked[:limit]
        ]

    def _score_capability(
        self,
        query_tokens: set[str],
        query_tokens_expanded: set[str],
        capability: Capability,
    ) -> float:
        """Score one capability against the (expanded) query tokens.

        The base is the best of three overlap ratios (all tokens, name-only
        weighted up, context-only weighted down), plus bonuses for exact
        coverage, context matches, non-generic "entity" matches and shared CRM
        vocabulary, minus a penalty for capabilities made mostly of generic
        tokens. Returns 0.0 when there is no overlap at all.

        ``query_tokens`` is kept for signature stability; scoring currently
        uses only the expanded set.
        """
        name = str(getattr(capability, "name", "") or "")
        description = str(getattr(capability, "description", "") or "")
        name_tokens = self._tokenize(name)
        description_tokens = self._tokenize(description)
        context_tokens = self._extract_context_tokens(capability)
        recipe_tokens = self._extract_recipe_tokens(capability)
        combined_tokens = name_tokens | description_tokens | context_tokens | recipe_tokens
        if not combined_tokens:
            return 0.0
        combined_tokens_expanded = self._expand_tokens(combined_tokens)
        overlap = query_tokens_expanded & combined_tokens_expanded
        if not overlap:
            return 0.0
        overlap_ratio = len(overlap) / len(query_tokens_expanded)
        name_tokens_expanded = self._expand_tokens(name_tokens)
        name_ratio = len(query_tokens_expanded & name_tokens_expanded) / len(query_tokens_expanded)
        # Full coverage of the query by this capability earns a flat bonus.
        exact_bonus = 0.22 if query_tokens_expanded <= combined_tokens_expanded else 0.0
        context_ratio = 0.0
        context_bonus = 0.0
        if context_tokens:
            context_tokens_expanded = self._expand_tokens(context_tokens)
            context_overlap = query_tokens_expanded & context_tokens_expanded
            context_ratio = len(context_overlap) / len(query_tokens_expanded)
            context_bonus = min(0.16, len(context_overlap) * 0.03)
        # "Entity" matches are overlap tokens that are NOT generic CRUD-ish
        # words — they signal a match on the actual subject of the query.
        generic_expanded = self._expand_tokens(self.GENERIC_TOKENS)
        entity_overlap = overlap - generic_expanded
        entity_bonus = min(0.18, len(entity_overlap) * 0.06) if entity_overlap else 0.0
        query_crm_tokens = query_tokens_expanded & self.CRM_TOKENS
        capability_crm_tokens = combined_tokens_expanded & self.CRM_TOKENS
        crm_bonus = 0.0
        if query_crm_tokens and capability_crm_tokens:
            crm_overlap = len(query_crm_tokens & capability_crm_tokens)
            crm_bonus = 0.12 + min(0.14, crm_overlap * 0.04)
        generic_penalty = self._generic_capability_penalty(combined_tokens)
        return (
            max(overlap_ratio, name_ratio * 1.12, context_ratio * 0.95)
            + exact_bonus
            + context_bonus
            + entity_bonus
            + crm_bonus
            - generic_penalty
        )

    def _extract_context_tokens(self, capability: Capability) -> set[str]:
        """Tokenize textual context stored in the capability's ``llm_payload``.

        Only a fixed allow-list of payload keys is walked (depth-limited), and
        at most 120 collected chunks are tokenized.
        """
        llm_payload = getattr(capability, "llm_payload", None)
        if not isinstance(llm_payload, dict):
            return set()
        chunks: list[str] = []
        for key in (
            "action_context_brief",
            "openapi_hints",
            "action_context",
            "recipe_summary",
            "composite_context",
        ):
            value = llm_payload.get(key)
            if value is None:
                continue
            self._collect_text_chunks(value=value, chunks=chunks, depth=0, max_depth=4)
        tokens: set[str] = set()
        for chunk in chunks[:120]:
            tokens.update(self._tokenize(chunk))
        return tokens

    def _extract_recipe_tokens(self, capability: Capability) -> set[str]:
        """Tokenize string keys/values from the inputs of recipe steps.

        Caps at the first 30 steps; malformed steps/inputs are skipped.
        """
        recipe = getattr(capability, "recipe", None)
        if not isinstance(recipe, dict):
            return set()
        steps = recipe.get("steps")
        if not isinstance(steps, list):
            return set()
        chunks: list[str] = []
        for raw_step in steps[:30]:
            if not isinstance(raw_step, dict):
                continue
            inputs = raw_step.get("inputs")
            if not isinstance(inputs, dict):
                continue
            for key, value in inputs.items():
                if isinstance(key, str):
                    chunks.append(key)
                if isinstance(value, str):
                    chunks.append(value)
        tokens: set[str] = set()
        for chunk in chunks:
            tokens.update(self._tokenize(chunk))
        return tokens

    def _collect_text_chunks(
        self,
        *,
        value: object,
        chunks: list[str],
        depth: int,
        max_depth: int,
    ) -> None:
        """Recursively collect text snippets from a nested payload into ``chunks``.

        Strings are appended as-is; dicts are filtered through a preferred-key
        allow-list (the key itself is also appended); lists are walked up to 30
        items. Recursion stops at ``max_depth`` or once ~120 chunks are present.
        """
        if depth > max_depth or len(chunks) >= 120:
            return
        if isinstance(value, str):
            stripped = value.strip()
            if stripped:
                chunks.append(stripped)
            return
        if isinstance(value, dict):
            preferred_keys = {
                "operation_id",
                "method",
                "path",
                "base_url",
                "summary",
                "description",
                "tags",
                "source_filename",
                "required_inputs",
                "request_content_types",
                "response_content_types",
                "response_status_codes",
                "security_requirements",
                "parameter_names_by_location",
                "path_segments",
                "input_signals",
                "output_signals",
            }
            for key, item in value.items():
                if not isinstance(key, str):
                    continue
                if key not in preferred_keys:
                    continue
                chunks.append(key)
                self._collect_text_chunks(
                    value=item,
                    chunks=chunks,
                    depth=depth + 1,
                    max_depth=max_depth,
                )
            return
        if isinstance(value, list):
            for item in value[:30]:
                self._collect_text_chunks(
                    value=item,
                    chunks=chunks,
                    depth=depth + 1,
                    max_depth=max_depth,
                )

    def _resolve_confidence_tier(self, top_score: float, margin: float) -> str:
        """Map the top score and its lead over second place to a tier label.

        A thin margin means the ranking could not separate candidates, so the
        result is "low" regardless of the absolute top score.
        """
        if margin < self.LOW_MARGIN_THRESHOLD:
            return "low"
        if top_score >= self.HIGH_CONFIDENCE_THRESHOLD:
            return "high"
        if top_score >= self.MEDIUM_CONFIDENCE_THRESHOLD:
            return "medium"
        return "low"

    def _generic_capability_penalty(self, tokens: set[str]) -> float:
        """Penalty growing with the share of generic tokens in a capability."""
        if not tokens:
            return 0.0
        generic_share = len(tokens & self.GENERIC_TOKENS) / len(tokens)
        if generic_share >= 0.65:
            return 0.14
        if generic_share >= 0.5:
            return 0.09
        if generic_share >= 0.35:
            return 0.04
        return 0.0

    def _tokenize(self, value: str) -> set[str]:
        """Lowercase and split ``value`` into alphanumeric RU/EN tokens.

        Tokens shorter than 3 characters and stopwords are dropped.
        """
        tokens = set(self._TOKEN_PATTERN.findall(value.lower()))
        return {
            token
            for token in tokens
            if len(token) >= 3 and token not in self._STOPWORDS
        }

    def _is_executable_capability(self, capability: Capability) -> bool:
        """True when the capability can actually be executed.

        Atomic capabilities need a linked action; composites need a valid
        version-1 recipe with at least one step. Unknown types are rejected.
        """
        cap_type = self._capability_type_value(capability)
        if cap_type == CapabilityType.ATOMIC.value:
            return getattr(capability, "action_id", None) is not None
        if cap_type == CapabilityType.COMPOSITE.value:
            return self._recipe_is_executable(getattr(capability, "recipe", None))
        return False

    def _recipe_is_executable(self, recipe: Any) -> bool:
        """True for a version-1 recipe dict with a non-empty steps list."""
        if not isinstance(recipe, dict):
            return False
        if recipe.get("version") != 1:
            return False
        steps = recipe.get("steps")
        return isinstance(steps, list) and bool(steps)

    def _capability_type_value(self, capability: Capability) -> str:
        """Normalize ``capability.type`` (enum, str, or enum-like) to a string.

        Missing/unrecognized values default to the ATOMIC type.
        """
        raw = getattr(capability, "type", None)
        if isinstance(raw, CapabilityType):
            return raw.value
        if isinstance(raw, str):
            return raw
        if hasattr(raw, "value"):
            return str(raw.value)
        return CapabilityType.ATOMIC.value

    def _expand_tokens(self, tokens: set[str]) -> set[str]:
        """Expand tokens with suffix-stripped variants and RU/EN aliases.

        Each token contributes itself, its normalized variants, and the alias
        set of every alias key that the token (or a variant) equals or starts
        with — e.g. "пользователи" pulls in "user"/"users".
        """
        expanded: set[str] = set()
        for token in tokens:
            expanded.add(token)
            normalized_variants = self._normalized_variants(token)
            expanded.update(normalized_variants)
            for variant in normalized_variants | {token}:
                for key, aliases in self._ALIAS_EXPANSIONS.items():
                    if variant == key or variant.startswith(key):
                        expanded.update(aliases)
        return expanded

    def _normalized_variants(self, token: str) -> set[str]:
        """Produce crude stem variants of ``token`` (always includes it).

        Russian: strips common case/adjective endings from tokens of length
        >= 5 when a stem of length > 2 remains. English: "ies" -> "y" and a
        trailing "s" drop.
        """
        variants = {token}
        if len(token) >= 5:
            for suffix in (
                "иями",
                "ями",
                "ами",
                "ов",
                "ев",
                "ей",
                "ам",
                "ям",
                "ах",
                "ях",
                "ые",
                "ий",
                "ый",
                "ая",
                "ое",
                "ой",
                "а",
                "я",
                "ы",
                "и",
                "у",
                "ю",
                "е",
                "о",
            ):
                if token.endswith(suffix) and len(token) > len(suffix) + 2:
                    variants.add(token[: -len(suffix)])
        if token.endswith("ies") and len(token) > 4:
            variants.add(token[:-3] + "y")
        if token.endswith("s") and len(token) > 3:
            variants.add(token[:-1])
        return variants