upload
This commit is contained in:
@@ -0,0 +1,491 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, NamedTuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import Action, Capability
|
||||
from app.models.capability import CapabilityType
|
||||
|
||||
|
||||
class SelectedCapability(NamedTuple):
    """A capability chosen for a user query, together with ranking metadata."""

    capability: Capability  # the matched Capability row
    score: float  # lexical-overlap ranking score (higher is better)
    confidence_tier: str = "high"  # one of "high" / "medium" / "low"
class SemanticSelectionService:
    """Heuristic lexical matcher that ranks stored capabilities against a user query.

    Scoring is token-overlap based (English + Russian), with alias expansion,
    a crude suffix stemmer, CRM-domain boosts and a penalty for capabilities
    described only in generic CRUD words.
    """

    # Score thresholds used by _resolve_confidence_tier.
    HIGH_CONFIDENCE_THRESHOLD = 0.45
    MEDIUM_CONFIDENCE_THRESHOLD = 0.30
    # Minimum lead of the top candidate over the runner-up; below this the
    # ranking is considered ambiguous and the tier drops to "low".
    LOW_MARGIN_THRESHOLD = 0.05

    # English/Russian tokens signalling CRM-domain intent; overlap on both
    # sides (query and capability) earns a bonus in _score_capability.
    CRM_TOKENS = {
        "crm", "segment", "segments", "audience", "campaign", "campaigns",
        "mailing", "newsletter", "lead", "leads", "retention", "cohort",
        "churn", "conversion", "promo", "offer", "offers", "email", "emails",
        "push", "sale", "sales",
        "сегмент", "сегменты", "аудитория", "кампания", "кампании",
        "рассылка", "лид", "лиды", "ретеншн", "конверсия", "оффер",
        "офферы", "пуш", "продажи", "клиент", "клиенты",
    }

    # Low-signal CRUD/service vocabulary; a capability described mostly in
    # these words is penalized (see _generic_capability_penalty).
    GENERIC_TOKENS = {
        "get", "list", "create", "update", "delete", "call", "data", "info",
        "items", "resource", "resources", "service", "api", "handle",
        "handler", "manage", "process", "method", "action", "fetch",
        "general", "common",
        "получить", "список", "создать", "обновить", "удалить", "данные",
        "инфо", "ресурс", "сервис", "метод", "действие", "общее",
    }

    # Tokens dropped during tokenization (function words plus
    # pipeline/automation filler that appears in almost every query).
    _STOPWORDS = {
        "and", "the", "for", "with", "from", "into", "that", "this",
        "что", "это", "как", "для", "или", "при", "про", "надо", "нужно",
        "хочу",
        "build", "pipeline", "workflow", "scenario", "automation",
        "пайплайн", "сценарий", "автоматизация", "построй", "собери",
    }

    # Prefix -> cross-language synonym sets; a token equal to or starting
    # with a key pulls in all of its aliases (see _expand_tokens).
    _ALIAS_EXPANSIONS = {
        "польз": {"user", "users", "client", "clients", "пользователь", "пользователи"},
        "клиент": {"client", "clients", "user", "users", "клиент", "клиенты"},
        "юзер": {"user", "users", "пользователь", "пользователи"},
        "получ": {"get", "fetch", "list", "retrieve", "получить", "список"},
        "спис": {"list", "get", "fetch", "список", "получить"},
        "созд": {"create", "add", "post", "создать"},
        "обнов": {"update", "patch", "put", "обновить"},
        "удал": {"delete", "remove", "del", "удалить"},
        "рассыл": {"mailing", "newsletter", "broadcast", "email", "рассылка"},
        "сегмент": {"segment", "segments", "сегмент", "сегменты"},
        "лид": {"lead", "leads", "лид", "лиды"},
        "отчет": {"report", "analytics", "отчет", "отчёт"},
        "отчёт": {"report", "analytics", "отчет", "отчёт"},
        "user": {"пользователь", "пользователи", "user", "users"},
        "users": {"пользователь", "пользователи", "user", "users"},
        "get": {"получить", "список", "get", "fetch", "list"},
        "fetch": {"получить", "список", "get", "fetch", "list"},
        "list": {"получить", "список", "get", "fetch", "list"},
    }
async def select_capabilities(
    self,
    session: AsyncSession,
    user_query: str,
    owner_user_id: UUID | None = None,
    limit: int = 10,
) -> list[SelectedCapability]:
    """Rank stored, executable capabilities against a free-text user query.

    Args:
        session: async DB session used to load candidate capabilities.
        user_query: raw user text; tokenized and alias-expanded for matching.
        owner_user_id: when given, restricts candidates to this owner
            (including legacy rows whose owner lives on the source Action).
        limit: maximum number of capabilities returned.

    Returns:
        Up to ``limit`` SelectedCapability items, best score first. Empty
        when the query yields no usable tokens or nothing executable exists.
        When nothing overlaps lexically, the oldest executable capabilities
        are returned with a nominal low-confidence score instead.
    """
    query_tokens = self._tokenize(user_query)
    if not query_tokens:
        return []

    # Oldest-first keeps the no-match fallback below deterministic.
    query = select(Capability).order_by(Capability.created_at.asc())
    if owner_user_id is not None:
        # User-scoped with legacy compatibility:
        # some old capabilities may have user_id=NULL while their source action has owner.
        query = query.outerjoin(Action, Capability.action_id == Action.id).where(
            or_(
                Capability.user_id == owner_user_id,
                and_(
                    Capability.user_id.is_(None),
                    Action.user_id == owner_user_id,
                ),
            )
        )
    # Cap the candidate pool; scoring happens in Python, not SQL.
    query = query.limit(200)
    result = await session.execute(query)
    capabilities = list(result.scalars().all())

    candidates = [
        capability
        for capability in capabilities
        if self._is_executable_capability(capability)
    ]
    if not candidates:
        return []

    query_tokens_expanded = self._expand_tokens(query_tokens)
    ranked: list[SelectedCapability] = []
    for capability in candidates:
        score = self._score_capability(query_tokens, query_tokens_expanded, capability)
        if score <= 0:
            continue
        ranked.append(SelectedCapability(capability=capability, score=score))

    ranked.sort(key=lambda item: item.score, reverse=True)
    if not ranked:
        # Fallback: keep generation moving even when lexical matching is weak.
        # (``candidates`` is guaranteed non-empty here — checked above — so the
        # former dead ``return []`` branch was removed.)
        return [
            SelectedCapability(
                capability=capability,
                score=0.01,
                confidence_tier="low",
            )
            for capability in candidates[:limit]
        ]

    # Confidence reflects both the absolute top score and its lead over
    # the runner-up; an ambiguous ranking is downgraded.
    top_score = ranked[0].score
    second_score = ranked[1].score if len(ranked) > 1 else 0.0
    margin = top_score - second_score
    confidence_tier = self._resolve_confidence_tier(top_score, margin)

    return [
        SelectedCapability(
            capability=item.capability,
            score=item.score,
            confidence_tier=confidence_tier,
        )
        for item in ranked[:limit]
    ]
def _score_capability(
    self,
    query_tokens: set[str],
    query_tokens_expanded: set[str],
    capability: Capability,
) -> float:
    """Compute a lexical relevance score for one capability.

    Combines token overlap across the capability's name, description,
    llm_payload context and recipe inputs, plus hand-tuned bonuses
    (full coverage, context, non-generic entities, CRM domain) minus a
    penalty for generic-only descriptions. Returns 0.0 when there is
    nothing to match.

    NOTE(review): ``query_tokens`` (the unexpanded set) is currently
    unused here — only the expanded set participates in scoring.
    """
    name = str(getattr(capability, "name", "") or "")
    description = str(getattr(capability, "description", "") or "")
    name_tokens = self._tokenize(name)
    description_tokens = self._tokenize(description)
    context_tokens = self._extract_context_tokens(capability)
    recipe_tokens = self._extract_recipe_tokens(capability)
    combined_tokens = name_tokens | description_tokens | context_tokens | recipe_tokens
    if not combined_tokens:
        return 0.0

    combined_tokens_expanded = self._expand_tokens(combined_tokens)
    overlap = query_tokens_expanded & combined_tokens_expanded
    if not overlap:
        return 0.0

    # Base ratios: how much of the (expanded) query the capability covers.
    overlap_ratio = len(overlap) / len(query_tokens_expanded)
    name_tokens_expanded = self._expand_tokens(name_tokens)
    name_ratio = len(query_tokens_expanded & name_tokens_expanded) / len(query_tokens_expanded)
    # Bonus when the capability text covers the entire expanded query.
    exact_bonus = 0.22 if query_tokens_expanded <= combined_tokens_expanded else 0.0
    context_ratio = 0.0
    context_bonus = 0.0
    if context_tokens:
        context_tokens_expanded = self._expand_tokens(context_tokens)
        context_overlap = query_tokens_expanded & context_tokens_expanded
        context_ratio = len(context_overlap) / len(query_tokens_expanded)
        context_bonus = min(0.16, len(context_overlap) * 0.03)

    # Overlap on non-generic (entity) tokens is worth extra.
    generic_expanded = self._expand_tokens(self.GENERIC_TOKENS)
    entity_overlap = overlap - generic_expanded
    entity_bonus = min(0.18, len(entity_overlap) * 0.06) if entity_overlap else 0.0

    # Domain boost when both the query and the capability use CRM vocabulary.
    query_crm_tokens = query_tokens_expanded & self.CRM_TOKENS
    capability_crm_tokens = combined_tokens_expanded & self.CRM_TOKENS
    crm_bonus = 0.0
    if query_crm_tokens and capability_crm_tokens:
        crm_overlap = len(query_crm_tokens & capability_crm_tokens)
        crm_bonus = 0.12 + min(0.14, crm_overlap * 0.04)

    # Capabilities described only in generic CRUD words get pushed down.
    generic_penalty = self._generic_capability_penalty(combined_tokens)

    # Weights are hand-tuned: name matches slightly outrank plain overlap,
    # context matches slightly underrank it.
    return (
        max(overlap_ratio, name_ratio * 1.12, context_ratio * 0.95)
        + exact_bonus
        + context_bonus
        + entity_bonus
        + crm_bonus
        - generic_penalty
    )
def _extract_context_tokens(self, capability: Capability) -> set[str]:
|
||||
llm_payload = getattr(capability, "llm_payload", None)
|
||||
if not isinstance(llm_payload, dict):
|
||||
return set()
|
||||
|
||||
chunks: list[str] = []
|
||||
for key in (
|
||||
"action_context_brief",
|
||||
"openapi_hints",
|
||||
"action_context",
|
||||
"recipe_summary",
|
||||
"composite_context",
|
||||
):
|
||||
value = llm_payload.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
self._collect_text_chunks(value=value, chunks=chunks, depth=0, max_depth=4)
|
||||
|
||||
tokens: set[str] = set()
|
||||
for chunk in chunks[:120]:
|
||||
tokens.update(self._tokenize(chunk))
|
||||
return tokens
|
||||
|
||||
def _extract_recipe_tokens(self, capability: Capability) -> set[str]:
|
||||
recipe = getattr(capability, "recipe", None)
|
||||
if not isinstance(recipe, dict):
|
||||
return set()
|
||||
steps = recipe.get("steps")
|
||||
if not isinstance(steps, list):
|
||||
return set()
|
||||
|
||||
chunks: list[str] = []
|
||||
for raw_step in steps[:30]:
|
||||
if not isinstance(raw_step, dict):
|
||||
continue
|
||||
inputs = raw_step.get("inputs")
|
||||
if not isinstance(inputs, dict):
|
||||
continue
|
||||
for key, value in inputs.items():
|
||||
if isinstance(key, str):
|
||||
chunks.append(key)
|
||||
if isinstance(value, str):
|
||||
chunks.append(value)
|
||||
|
||||
tokens: set[str] = set()
|
||||
for chunk in chunks:
|
||||
tokens.update(self._tokenize(chunk))
|
||||
return tokens
|
||||
|
||||
def _collect_text_chunks(
|
||||
self,
|
||||
*,
|
||||
value: object,
|
||||
chunks: list[str],
|
||||
depth: int,
|
||||
max_depth: int,
|
||||
) -> None:
|
||||
if depth > max_depth or len(chunks) >= 120:
|
||||
return
|
||||
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
if stripped:
|
||||
chunks.append(stripped)
|
||||
return
|
||||
|
||||
if isinstance(value, dict):
|
||||
preferred_keys = {
|
||||
"operation_id",
|
||||
"method",
|
||||
"path",
|
||||
"base_url",
|
||||
"summary",
|
||||
"description",
|
||||
"tags",
|
||||
"source_filename",
|
||||
"required_inputs",
|
||||
"request_content_types",
|
||||
"response_content_types",
|
||||
"response_status_codes",
|
||||
"security_requirements",
|
||||
"parameter_names_by_location",
|
||||
"path_segments",
|
||||
"input_signals",
|
||||
"output_signals",
|
||||
}
|
||||
for key, item in value.items():
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
if key not in preferred_keys:
|
||||
continue
|
||||
chunks.append(key)
|
||||
self._collect_text_chunks(
|
||||
value=item,
|
||||
chunks=chunks,
|
||||
depth=depth + 1,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(value, list):
|
||||
for item in value[:30]:
|
||||
self._collect_text_chunks(
|
||||
value=item,
|
||||
chunks=chunks,
|
||||
depth=depth + 1,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
|
||||
def _resolve_confidence_tier(self, top_score: float, margin: float) -> str:
|
||||
if margin < self.LOW_MARGIN_THRESHOLD:
|
||||
return "low"
|
||||
if top_score >= self.HIGH_CONFIDENCE_THRESHOLD:
|
||||
return "high"
|
||||
if top_score >= self.MEDIUM_CONFIDENCE_THRESHOLD:
|
||||
return "medium"
|
||||
return "low"
|
||||
|
||||
def _generic_capability_penalty(self, tokens: set[str]) -> float:
|
||||
if not tokens:
|
||||
return 0.0
|
||||
generic_share = len(tokens & self.GENERIC_TOKENS) / len(tokens)
|
||||
if generic_share >= 0.65:
|
||||
return 0.14
|
||||
if generic_share >= 0.5:
|
||||
return 0.09
|
||||
if generic_share >= 0.35:
|
||||
return 0.04
|
||||
return 0.0
|
||||
|
||||
def _tokenize(self, value: str) -> set[str]:
|
||||
tokens = set(re.findall(r"[a-zA-Zа-яА-Я0-9]+", value.lower()))
|
||||
return {
|
||||
token
|
||||
for token in tokens
|
||||
if len(token) >= 3 and token not in self._STOPWORDS
|
||||
}
|
||||
|
||||
def _is_executable_capability(self, capability: Capability) -> bool:
    """Whether this capability can actually be executed.

    Composite capabilities need a runnable v1 recipe; atomic ones need a
    linked source action. Any other type is never executable.
    """
    kind = self._capability_type_value(capability)
    if kind == CapabilityType.COMPOSITE.value:
        return self._recipe_is_executable(getattr(capability, "recipe", None))
    if kind == CapabilityType.ATOMIC.value:
        return getattr(capability, "action_id", None) is not None
    return False
def _recipe_is_executable(self, recipe: Any) -> bool:
|
||||
if not isinstance(recipe, dict):
|
||||
return False
|
||||
if recipe.get("version") != 1:
|
||||
return False
|
||||
steps = recipe.get("steps")
|
||||
return isinstance(steps, list) and bool(steps)
|
||||
|
||||
def _capability_type_value(self, capability: Capability) -> str:
|
||||
raw = getattr(capability, "type", None)
|
||||
if isinstance(raw, CapabilityType):
|
||||
return raw.value
|
||||
if isinstance(raw, str):
|
||||
return raw
|
||||
if hasattr(raw, "value"):
|
||||
return str(raw.value)
|
||||
return CapabilityType.ATOMIC.value
|
||||
|
||||
def _expand_tokens(self, tokens: set[str]) -> set[str]:
|
||||
expanded: set[str] = set()
|
||||
for token in tokens:
|
||||
expanded.add(token)
|
||||
normalized_variants = self._normalized_variants(token)
|
||||
expanded.update(normalized_variants)
|
||||
for variant in normalized_variants | {token}:
|
||||
for key, aliases in self._ALIAS_EXPANSIONS.items():
|
||||
if variant == key or variant.startswith(key):
|
||||
expanded.update(aliases)
|
||||
return expanded
|
||||
|
||||
def _normalized_variants(self, token: str) -> set[str]:
|
||||
variants = {token}
|
||||
if len(token) >= 5:
|
||||
for suffix in (
|
||||
"иями",
|
||||
"ями",
|
||||
"ами",
|
||||
"ов",
|
||||
"ев",
|
||||
"ей",
|
||||
"ам",
|
||||
"ям",
|
||||
"ах",
|
||||
"ях",
|
||||
"ые",
|
||||
"ий",
|
||||
"ый",
|
||||
"ая",
|
||||
"ое",
|
||||
"ой",
|
||||
"а",
|
||||
"я",
|
||||
"ы",
|
||||
"и",
|
||||
"у",
|
||||
"ю",
|
||||
"е",
|
||||
"о",
|
||||
):
|
||||
if token.endswith(suffix) and len(token) > len(suffix) + 2:
|
||||
variants.add(token[: -len(suffix)])
|
||||
|
||||
if token.endswith("ies") and len(token) > 4:
|
||||
variants.add(token[:-3] + "y")
|
||||
if token.endswith("s") and len(token) > 3:
|
||||
variants.add(token[:-1])
|
||||
return variants
|
||||
Reference in New Issue
Block a user