upload
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
from app.services.openapi_service import OpenAPIService
|
||||
from app.services.capability_service import CapabilityService
|
||||
from app.services.execution_service import ExecutionService
|
||||
from app.services.pipeline_service import PipelineService
|
||||
|
||||
# Public API of the services package: the four service classes re-exported here.
__all__ = [
    "OpenAPIService",
    "CapabilityService",
    "ExecutionService",
    "PipelineService",
]
|
||||
@@ -0,0 +1,758 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import Action, Capability
|
||||
from app.models.capability import CapabilityType
|
||||
|
||||
|
||||
class CompositeRecipeValidationError(ValueError):
    """Raised when a composite recipe fails validation.

    Collects every validation message; str(exc) joins them with "; ".
    """

    def __init__(self, errors: list[str]) -> None:
        joined = "; ".join(errors)
        super().__init__(joined)
        self.errors = errors
|
||||
|
||||
|
||||
class CapabilityService:
    """Builds and queries Capability rows: atomic ones derived from Actions and
    composite ones assembled from validated multi-step recipes."""

    def __init__(self, session: AsyncSession) -> None:
        # Async SQLAlchemy session shared by all persistence helpers below.
        self.session = session
|
||||
|
||||
@staticmethod
|
||||
def build_from_actions(
|
||||
actions: list[Action],
|
||||
*,
|
||||
owner_user_id: UUID,
|
||||
) -> list[Capability]:
|
||||
capabilities: list[Capability] = []
|
||||
for action in actions:
|
||||
capability_payload = CapabilityService._build_capability_payload(action)
|
||||
capabilities.append(
|
||||
Capability(
|
||||
user_id=owner_user_id,
|
||||
action_id=action.id,
|
||||
type=CapabilityType.ATOMIC,
|
||||
name=capability_payload["name"],
|
||||
description=capability_payload.get("description"),
|
||||
input_schema=capability_payload.get("input_schema"),
|
||||
output_schema=capability_payload.get("output_schema"),
|
||||
data_format=capability_payload.get("data_format"),
|
||||
llm_payload=capability_payload.get("llm_payload"),
|
||||
)
|
||||
)
|
||||
|
||||
return capabilities
|
||||
|
||||
async def create_composite_capability(
|
||||
self,
|
||||
*,
|
||||
owner_user_id: UUID,
|
||||
name: str,
|
||||
description: str | None = None,
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
output_schema: dict[str, Any] | None = None,
|
||||
recipe: dict[str, Any],
|
||||
llm_payload: dict[str, Any] | None = None,
|
||||
data_format: dict[str, Any] | None = None,
|
||||
) -> Capability:
|
||||
capability = Capability(
|
||||
user_id=owner_user_id,
|
||||
type=CapabilityType.COMPOSITE,
|
||||
name=name,
|
||||
description=description,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
recipe=recipe,
|
||||
llm_payload=llm_payload,
|
||||
data_format=data_format,
|
||||
)
|
||||
self.session.add(capability)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(capability)
|
||||
return capability
|
||||
|
||||
async def create_validated_composite_capability(
|
||||
self,
|
||||
*,
|
||||
owner_user_id: UUID,
|
||||
name: str,
|
||||
description: str | None = None,
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
output_schema: dict[str, Any] | None = None,
|
||||
recipe: dict[str, Any],
|
||||
include_all: bool = False,
|
||||
) -> Capability:
|
||||
normalized_recipe, step_capabilities = await self.validate_composite_recipe(
|
||||
recipe=recipe,
|
||||
owner_user_id=owner_user_id,
|
||||
include_all=include_all,
|
||||
)
|
||||
llm_payload = self._build_composite_llm_payload(step_capabilities)
|
||||
data_format = {
|
||||
"request_schema_type": input_schema.get("type")
|
||||
if isinstance(input_schema, dict)
|
||||
else None,
|
||||
"response_schema_types": [output_schema.get("type")]
|
||||
if isinstance(output_schema, dict)
|
||||
and isinstance(output_schema.get("type"), str)
|
||||
else [],
|
||||
"composite": {
|
||||
"version": normalized_recipe.get("version"),
|
||||
"steps_count": len(normalized_recipe.get("steps", [])),
|
||||
"step_capability_names": [
|
||||
str(getattr(capability, "name", ""))
|
||||
for capability in step_capabilities
|
||||
],
|
||||
},
|
||||
}
|
||||
return await self.create_composite_capability(
|
||||
owner_user_id=owner_user_id,
|
||||
name=name,
|
||||
description=description,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
recipe=normalized_recipe,
|
||||
llm_payload=llm_payload,
|
||||
data_format=data_format,
|
||||
)
|
||||
|
||||
    async def validate_composite_recipe(
        self,
        *,
        recipe: dict[str, Any],
        owner_user_id: UUID,
        include_all: bool = False,
    ) -> tuple[dict[str, Any], list[Capability]]:
        """Validate a composite recipe and resolve its step capabilities.

        Validation runs in phases, accumulating messages and raising
        CompositeRecipeValidationError at three checkpoints: after structural
        per-step checks, (implicitly) never between binding/DB checks, and once
        more at the end after DB-backed checks.

        Returns (normalized_recipe, step_capabilities) where the recipe has
        steps sorted by step number and capabilities are in step order.
        """
        errors: list[str] = []
        if not isinstance(recipe, dict):
            raise CompositeRecipeValidationError(["recipe must be an object"])

        version = recipe.get("version")
        if version != 1:
            errors.append("recipe.version must be 1")

        raw_steps = recipe.get("steps")
        if not isinstance(raw_steps, list) or not raw_steps:
            # Without a step list nothing else can be checked; fail fast.
            errors.append("recipe.steps must be a non-empty list")
            raise CompositeRecipeValidationError(errors)

        # Phase 1: per-step structural validation and normalization.
        normalized_steps: list[dict[str, Any]] = []
        seen_step_numbers: set[int] = set()
        for index, raw_step in enumerate(raw_steps):
            if not isinstance(raw_step, dict):
                errors.append(f"recipe.steps[{index}] must be an object")
                continue

            step_number = raw_step.get("step")
            if not isinstance(step_number, int) or step_number < 1:
                errors.append(f"recipe.steps[{index}].step must be positive integer")
                continue

            # Duplicate step numbers are reported but the step is still processed,
            # so later checks can surface additional problems in the same pass.
            if step_number in seen_step_numbers:
                errors.append(f"recipe.steps[{index}].step duplicates step {step_number}")
            seen_step_numbers.add(step_number)

            capability_uuid = self._to_uuid(raw_step.get("capability_id"))
            if capability_uuid is None:
                errors.append(f"recipe.steps[{index}].capability_id must be UUID")
                continue

            raw_inputs = raw_step.get("inputs", {})
            if raw_inputs is None:
                raw_inputs = {}
            if not isinstance(raw_inputs, dict):
                # Report but continue with empty inputs so the step is still recorded.
                errors.append(f"recipe.steps[{index}].inputs must be an object")
                raw_inputs = {}

            # Each input binding must be a non-empty string matching a supported
            # "$run.<path>" or "$step.<n>.<path>" expression.
            normalized_inputs: dict[str, str] = {}
            for input_name, binding in raw_inputs.items():
                if not isinstance(input_name, str) or not input_name.strip():
                    errors.append(f"recipe.steps[{index}].inputs has invalid key")
                    continue
                if not isinstance(binding, str):
                    errors.append(
                        f"recipe.steps[{index}].inputs.{input_name} must be string binding"
                    )
                    continue
                normalized_binding = binding.strip()
                if not normalized_binding:
                    errors.append(
                        f"recipe.steps[{index}].inputs.{input_name} must be non-empty binding"
                    )
                    continue
                if not self._is_supported_binding_expression(normalized_binding):
                    errors.append(
                        f"recipe.steps[{index}].inputs.{input_name} has unsupported binding '{normalized_binding}'"
                    )
                    continue
                normalized_inputs[input_name] = normalized_binding

            normalized_steps.append(
                {
                    "step": step_number,
                    "capability_id": str(capability_uuid),
                    "inputs": normalized_inputs,
                }
            )

        if errors:
            raise CompositeRecipeValidationError(errors)

        # Phase 2: ordering and dataflow checks over the structurally valid steps.
        normalized_steps.sort(key=lambda item: item["step"])
        # NOTE(review): after the sort above this check can only trip on equal
        # step numbers, which were already reported (and raised) as duplicates
        # in phase 1 — it appears to be defensive/unreachable; confirm intent.
        for idx in range(1, len(normalized_steps)):
            if normalized_steps[idx]["step"] <= normalized_steps[idx - 1]["step"]:
                errors.append("recipe.steps must be strictly increasing by step")
                break

        # Every "$step.<n>.*" binding must reference an existing, *earlier* step.
        known_steps = {item["step"] for item in normalized_steps}
        for item in normalized_steps:
            for binding in item["inputs"].values():
                if not binding.startswith("$step."):
                    continue
                source_step = self._extract_binding_source_step(binding)
                if source_step is None:
                    errors.append(
                        f"step {item['step']}: invalid step binding '{binding}'"
                    )
                    continue
                if source_step not in known_steps:
                    errors.append(
                        f"step {item['step']}: binding references missing step {source_step}"
                    )
                    continue
                if source_step >= item["step"]:
                    errors.append(
                        f"step {item['step']}: binding references non-previous step {source_step}"
                    )

        # Phase 3: resolve referenced capabilities, honoring the ownership rules
        # of get_capabilities (include_all bypasses scoping).
        capability_ids = [UUID(item["capability_id"]) for item in normalized_steps]
        capabilities = await self.get_capabilities(
            capability_ids=capability_ids,
            owner_user_id=owner_user_id,
            include_all=include_all,
        )
        capabilities_by_id = {str(item.id): item for item in capabilities}
        for item in normalized_steps:
            capability = capabilities_by_id.get(item["capability_id"])
            if capability is None:
                errors.append(
                    f"step {item['step']}: capability {item['capability_id']} not found or not accessible"
                )
                continue

            # Only atomic capabilities may appear as steps: no nested composites.
            capability_type = self._capability_type_value(capability)
            if capability_type != CapabilityType.ATOMIC.value:
                errors.append(
                    f"step {item['step']}: nested composite is not allowed ({item['capability_id']})"
                )
                continue
            if getattr(capability, "action_id", None) is None:
                errors.append(
                    f"step {item['step']}: atomic capability {item['capability_id']} has no action_id"
                )

        if errors:
            raise CompositeRecipeValidationError(errors)

        normalized_recipe = {
            "version": 1,
            "steps": normalized_steps,
        }
        # Capabilities returned in step order (all ids resolve here, or we raised).
        ordered_caps = [
            capabilities_by_id[item["capability_id"]]
            for item in normalized_steps
            if item["capability_id"] in capabilities_by_id
        ]
        return normalized_recipe, ordered_caps
|
||||
|
||||
async def create_from_actions(
|
||||
self,
|
||||
actions: list[Action],
|
||||
*,
|
||||
owner_user_id: UUID,
|
||||
refresh: bool = True,
|
||||
) -> list[Capability]:
|
||||
capabilities = self.build_from_actions(actions, owner_user_id=owner_user_id)
|
||||
if not capabilities:
|
||||
return []
|
||||
|
||||
self.session.add_all(capabilities)
|
||||
await self.session.flush()
|
||||
|
||||
if refresh:
|
||||
for capability in capabilities:
|
||||
await self.session.refresh(capability)
|
||||
|
||||
return capabilities
|
||||
|
||||
async def get_capabilities(
|
||||
self,
|
||||
*,
|
||||
capability_ids: list[UUID] | None = None,
|
||||
action_ids: list[UUID] | None = None,
|
||||
owner_user_id: UUID | None = None,
|
||||
include_all: bool = False,
|
||||
limit: int | None = None,
|
||||
offset: int = 0,
|
||||
) -> list[Capability]:
|
||||
query = select(Capability).order_by(Capability.created_at.asc())
|
||||
|
||||
if not include_all and owner_user_id is not None:
|
||||
# Legacy compatibility: some old rows may have user_id=NULL while action is user-owned.
|
||||
query = query.outerjoin(Action, Capability.action_id == Action.id).where(
|
||||
or_(
|
||||
Capability.user_id == owner_user_id,
|
||||
and_(
|
||||
Capability.user_id.is_(None),
|
||||
Action.user_id == owner_user_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if capability_ids:
|
||||
query = query.where(Capability.id.in_(capability_ids))
|
||||
|
||||
if action_ids:
|
||||
query = query.where(Capability.action_id.in_(action_ids))
|
||||
|
||||
if offset:
|
||||
query = query.offset(offset)
|
||||
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_capability(
|
||||
self,
|
||||
capability_id: UUID,
|
||||
*,
|
||||
owner_user_id: UUID | None = None,
|
||||
include_all: bool = False,
|
||||
) -> Capability | None:
|
||||
query = select(Capability).where(Capability.id == capability_id)
|
||||
if not include_all and owner_user_id is not None:
|
||||
query = query.outerjoin(Action, Capability.action_id == Action.id).where(
|
||||
or_(
|
||||
Capability.user_id == owner_user_id,
|
||||
and_(
|
||||
Capability.user_id.is_(None),
|
||||
Action.user_id == owner_user_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
def _is_supported_binding_expression(value: str) -> bool:
|
||||
if re.fullmatch(r"\$run\.[A-Za-z0-9_][A-Za-z0-9_\.]*", value):
|
||||
return True
|
||||
if re.fullmatch(r"\$step\.\d+\.[A-Za-z0-9_][A-Za-z0-9_\.]*", value):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _extract_binding_source_step(value: str) -> int | None:
|
||||
match = re.fullmatch(r"\$step\.(\d+)\.[A-Za-z0-9_][A-Za-z0-9_\.]*", value)
|
||||
if not match:
|
||||
return None
|
||||
return int(match.group(1))
|
||||
|
||||
    @staticmethod
    def _to_uuid(value: Any) -> UUID | None:
        """Coerce any value to a UUID via its string form; None when unparseable."""
        try:
            return UUID(str(value))
        except (TypeError, ValueError):
            return None
|
||||
|
||||
@staticmethod
|
||||
def _capability_type_value(capability: Capability) -> str:
|
||||
cap_type = getattr(capability, "type", None)
|
||||
if isinstance(cap_type, CapabilityType):
|
||||
return cap_type.value
|
||||
if isinstance(cap_type, str):
|
||||
return cap_type
|
||||
if hasattr(cap_type, "value"):
|
||||
return str(cap_type.value)
|
||||
return CapabilityType.ATOMIC.value
|
||||
|
||||
@staticmethod
|
||||
def _build_composite_llm_payload(step_capabilities: list[Capability]) -> dict[str, Any]:
|
||||
step_names = [
|
||||
str(getattr(capability, "name", "") or "")
|
||||
for capability in step_capabilities
|
||||
if str(getattr(capability, "name", "") or "").strip()
|
||||
]
|
||||
return {
|
||||
"source": "composite",
|
||||
"recipe_summary": {
|
||||
"steps_count": len(step_capabilities),
|
||||
"step_names": step_names,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_capability_payload(action: Action) -> dict[str, Any]:
|
||||
input_schema = CapabilityService._build_input_schema(action)
|
||||
output_schema = getattr(action, "response_schema", None)
|
||||
data_format = CapabilityService._build_data_format(action)
|
||||
action_context = CapabilityService._build_action_context(
|
||||
action=action,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
data_format=data_format,
|
||||
)
|
||||
openapi_hints = CapabilityService._build_openapi_hints(
|
||||
action=action,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
)
|
||||
return {
|
||||
"name": CapabilityService._build_capability_name(action),
|
||||
"description": CapabilityService._build_capability_description(action),
|
||||
"input_schema": input_schema,
|
||||
"output_schema": output_schema,
|
||||
"data_format": data_format,
|
||||
"llm_payload": {
|
||||
"source": "deterministic",
|
||||
"action_context_version": "v2",
|
||||
"action_context": action_context,
|
||||
"action_context_brief": CapabilityService._build_action_context_brief(
|
||||
action_context=action_context,
|
||||
openapi_hints=openapi_hints,
|
||||
),
|
||||
"openapi_hints": openapi_hints,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_action_context(
|
||||
*,
|
||||
action: Action,
|
||||
input_schema: dict[str, Any] | None,
|
||||
output_schema: dict[str, Any] | None,
|
||||
data_format: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
method = getattr(action, "method", None)
|
||||
method_value = method.value if hasattr(method, "value") else str(method or "")
|
||||
parameter_names = CapabilityService._extract_parameter_names_by_location(
|
||||
getattr(action, "parameters_schema", None)
|
||||
)
|
||||
request_property_names = CapabilityService._extract_schema_property_names(
|
||||
getattr(action, "request_body_schema", None)
|
||||
)
|
||||
response_property_names = CapabilityService._extract_schema_property_names(
|
||||
getattr(action, "response_schema", None)
|
||||
)
|
||||
|
||||
return {
|
||||
"action_id": str(getattr(action, "id", "")),
|
||||
"operation_id": getattr(action, "operation_id", None),
|
||||
"method": method_value,
|
||||
"path": getattr(action, "path", None),
|
||||
"base_url": getattr(action, "base_url", None),
|
||||
"summary": getattr(action, "summary", None),
|
||||
"description": getattr(action, "description", None),
|
||||
"tags": getattr(action, "tags", None) or [],
|
||||
"source_filename": getattr(action, "source_filename", None),
|
||||
"input_schema": input_schema,
|
||||
"output_schema": output_schema,
|
||||
"parameters_schema": getattr(action, "parameters_schema", None),
|
||||
"request_body_schema": getattr(action, "request_body_schema", None),
|
||||
"response_schema": getattr(action, "response_schema", None),
|
||||
"raw_spec": getattr(action, "raw_spec", None),
|
||||
"data_format": data_format,
|
||||
"input_signals": {
|
||||
"required_inputs": CapabilityService._extract_required_inputs(input_schema),
|
||||
"parameter_names_by_location": parameter_names,
|
||||
"request_property_names": request_property_names,
|
||||
},
|
||||
"output_signals": {
|
||||
"response_property_names": response_property_names,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_openapi_hints(
|
||||
*,
|
||||
action: Action,
|
||||
input_schema: dict[str, Any] | None,
|
||||
output_schema: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
raw_spec = getattr(action, "raw_spec", None)
|
||||
if not isinstance(raw_spec, dict):
|
||||
raw_spec = {}
|
||||
|
||||
request_content_types = CapabilityService._extract_content_types_from_request(raw_spec)
|
||||
response_status_codes, response_content_types = (
|
||||
CapabilityService._extract_response_hints(raw_spec)
|
||||
)
|
||||
security_requirements = (
|
||||
raw_spec.get("security") if isinstance(raw_spec.get("security"), list) else []
|
||||
)
|
||||
parameter_names = CapabilityService._extract_parameter_names_by_location(
|
||||
getattr(action, "parameters_schema", None)
|
||||
)
|
||||
vendor_extensions = {
|
||||
key: value
|
||||
for key, value in raw_spec.items()
|
||||
if isinstance(key, str) and key.startswith("x-")
|
||||
}
|
||||
path_value = str(getattr(action, "path", "") or "")
|
||||
path_segments = [
|
||||
segment
|
||||
for segment in path_value.strip("/").split("/")
|
||||
if segment and not segment.startswith("{")
|
||||
]
|
||||
|
||||
return {
|
||||
"deprecated": bool(raw_spec.get("deprecated")),
|
||||
"security_requirements": security_requirements,
|
||||
"request_content_types": request_content_types,
|
||||
"response_content_types": response_content_types,
|
||||
"response_status_codes": response_status_codes,
|
||||
"has_request_body": bool(getattr(action, "request_body_schema", None)),
|
||||
"has_response_body": bool(output_schema),
|
||||
"required_inputs": CapabilityService._extract_required_inputs(input_schema),
|
||||
"parameter_names_by_location": parameter_names,
|
||||
"path_segments": path_segments,
|
||||
"tags": getattr(action, "tags", None) or [],
|
||||
"vendor_extensions": vendor_extensions,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_action_context_brief(
|
||||
*,
|
||||
action_context: dict[str, Any],
|
||||
openapi_hints: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"operation_id": action_context.get("operation_id"),
|
||||
"method": action_context.get("method"),
|
||||
"path": action_context.get("path"),
|
||||
"base_url": action_context.get("base_url"),
|
||||
"summary": action_context.get("summary"),
|
||||
"description": action_context.get("description"),
|
||||
"tags": action_context.get("tags") or [],
|
||||
"required_inputs": (action_context.get("input_signals") or {}).get("required_inputs") or [],
|
||||
"parameter_names_by_location": (action_context.get("input_signals") or {}).get(
|
||||
"parameter_names_by_location"
|
||||
)
|
||||
or {},
|
||||
"request_content_types": openapi_hints.get("request_content_types") or [],
|
||||
"response_content_types": openapi_hints.get("response_content_types") or [],
|
||||
"response_status_codes": openapi_hints.get("response_status_codes") or [],
|
||||
"security_requirements": openapi_hints.get("security_requirements") or [],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_capability_name(action: Action) -> str:
|
||||
operation_id = getattr(action, "operation_id", None)
|
||||
if operation_id:
|
||||
return str(operation_id)
|
||||
|
||||
method = getattr(action, "method", None)
|
||||
method_value = method.value.lower() if method is not None else "call"
|
||||
path = getattr(action, "path", "") or ""
|
||||
normalized_path = re.sub(r"[{}]", "", path).strip("/")
|
||||
normalized_path = re.sub(r"[^a-zA-Z0-9/]+", "_", normalized_path)
|
||||
normalized_path = normalized_path.replace("/", "_") or "root"
|
||||
return f"{method_value}_{normalized_path.lower()}"
|
||||
|
||||
@staticmethod
|
||||
def _build_capability_description(action: Action) -> str:
|
||||
summary = getattr(action, "summary", None)
|
||||
description = getattr(action, "description", None)
|
||||
operation_id = getattr(action, "operation_id", None)
|
||||
return str(
|
||||
summary
|
||||
or description
|
||||
or operation_id
|
||||
or CapabilityService._build_capability_name(action)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_input_schema(action: Action) -> dict[str, Any] | None:
|
||||
parameters_schema = getattr(action, "parameters_schema", None)
|
||||
request_body_schema = getattr(action, "request_body_schema", None)
|
||||
|
||||
if parameters_schema and request_body_schema:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"parameters": parameters_schema,
|
||||
"request_body": request_body_schema,
|
||||
},
|
||||
}
|
||||
if parameters_schema:
|
||||
return parameters_schema
|
||||
if request_body_schema:
|
||||
return request_body_schema
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _build_data_format(action: Action) -> dict[str, Any]:
|
||||
parameters_schema = getattr(action, "parameters_schema", None) or {}
|
||||
request_body_schema = getattr(action, "request_body_schema", None) or {}
|
||||
response_schema = getattr(action, "response_schema", None) or {}
|
||||
|
||||
parameter_locations: list[str] = []
|
||||
if isinstance(parameters_schema, dict):
|
||||
properties = parameters_schema.get("properties", {})
|
||||
if isinstance(properties, dict):
|
||||
for property_schema in properties.values():
|
||||
if not isinstance(property_schema, dict):
|
||||
continue
|
||||
location = property_schema.get("x-parameter-location")
|
||||
if isinstance(location, str) and location not in parameter_locations:
|
||||
parameter_locations.append(location)
|
||||
|
||||
request_content_type = (
|
||||
request_body_schema.get("x-content-type")
|
||||
if isinstance(request_body_schema, dict)
|
||||
else None
|
||||
)
|
||||
response_content_type = (
|
||||
response_schema.get("x-content-type")
|
||||
if isinstance(response_schema, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
return {
|
||||
"parameter_locations": parameter_locations,
|
||||
"request_content_types": [request_content_type]
|
||||
if isinstance(request_content_type, str)
|
||||
else [],
|
||||
"request_schema_type": request_body_schema.get("type")
|
||||
if isinstance(request_body_schema, dict)
|
||||
else None,
|
||||
"response_content_types": [response_content_type]
|
||||
if isinstance(response_content_type, str)
|
||||
else [],
|
||||
"response_schema_types": [response_schema.get("type")]
|
||||
if isinstance(response_schema, dict)
|
||||
and isinstance(response_schema.get("type"), str)
|
||||
else [],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _extract_required_inputs(input_schema: dict[str, Any] | None) -> list[str]:
|
||||
if not isinstance(input_schema, dict):
|
||||
return []
|
||||
|
||||
required = input_schema.get("required")
|
||||
if isinstance(required, list):
|
||||
return [str(item) for item in required if isinstance(item, str) and item]
|
||||
|
||||
# Nested schemas: {"properties":{"parameters":{"required":[...]}, "request_body":{"required":[...]}}}
|
||||
nested_required: list[str] = []
|
||||
properties = input_schema.get("properties")
|
||||
if isinstance(properties, dict):
|
||||
for nested_name in ("parameters", "request_body"):
|
||||
nested_schema = properties.get(nested_name)
|
||||
if not isinstance(nested_schema, dict):
|
||||
continue
|
||||
nested = nested_schema.get("required")
|
||||
if isinstance(nested, list):
|
||||
for value in nested:
|
||||
if isinstance(value, str) and value and value not in nested_required:
|
||||
nested_required.append(value)
|
||||
return nested_required
|
||||
|
||||
@staticmethod
|
||||
def _extract_parameter_names_by_location(
|
||||
parameters_schema: dict[str, Any] | None,
|
||||
) -> dict[str, list[str]]:
|
||||
names_by_location: dict[str, list[str]] = {
|
||||
"path": [],
|
||||
"query": [],
|
||||
"header": [],
|
||||
"cookie": [],
|
||||
}
|
||||
if not isinstance(parameters_schema, dict):
|
||||
return names_by_location
|
||||
|
||||
properties = parameters_schema.get("properties")
|
||||
if not isinstance(properties, dict):
|
||||
return names_by_location
|
||||
|
||||
for name, schema in properties.items():
|
||||
if not isinstance(name, str):
|
||||
continue
|
||||
location = "query"
|
||||
if isinstance(schema, dict):
|
||||
location_raw = schema.get("x-parameter-location")
|
||||
if isinstance(location_raw, str) and location_raw in names_by_location:
|
||||
location = location_raw
|
||||
if name not in names_by_location[location]:
|
||||
names_by_location[location].append(name)
|
||||
return names_by_location
|
||||
|
||||
@staticmethod
|
||||
def _extract_schema_property_names(
|
||||
schema: dict[str, Any] | None,
|
||||
*,
|
||||
limit: int = 64,
|
||||
) -> list[str]:
|
||||
if not isinstance(schema, dict):
|
||||
return []
|
||||
|
||||
result: list[str] = []
|
||||
queue: list[dict[str, Any]] = [schema]
|
||||
seen: set[str] = set()
|
||||
|
||||
while queue and len(result) < limit:
|
||||
current = queue.pop(0)
|
||||
properties = current.get("properties")
|
||||
if isinstance(properties, dict):
|
||||
for key, value in properties.items():
|
||||
if isinstance(key, str) and key not in seen:
|
||||
seen.add(key)
|
||||
result.append(key)
|
||||
if len(result) >= limit:
|
||||
break
|
||||
if isinstance(value, dict):
|
||||
queue.append(value)
|
||||
items = current.get("items")
|
||||
if isinstance(items, dict):
|
||||
queue.append(items)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _extract_content_types_from_request(raw_spec: dict[str, Any]) -> list[str]:
|
||||
request_body = raw_spec.get("requestBody")
|
||||
if not isinstance(request_body, dict):
|
||||
return []
|
||||
content = request_body.get("content")
|
||||
if not isinstance(content, dict):
|
||||
return []
|
||||
return [str(content_type) for content_type in content.keys() if isinstance(content_type, str)]
|
||||
|
||||
@staticmethod
|
||||
def _extract_response_hints(raw_spec: dict[str, Any]) -> tuple[list[str], list[str]]:
|
||||
responses = raw_spec.get("responses")
|
||||
if not isinstance(responses, dict):
|
||||
return [], []
|
||||
|
||||
response_status_codes: list[str] = []
|
||||
response_content_types: list[str] = []
|
||||
for status_code, response_payload in responses.items():
|
||||
status_value = str(status_code)
|
||||
if status_value not in response_status_codes:
|
||||
response_status_codes.append(status_value)
|
||||
if not isinstance(response_payload, dict):
|
||||
continue
|
||||
content = response_payload.get("content")
|
||||
if not isinstance(content, dict):
|
||||
continue
|
||||
for content_type in content.keys():
|
||||
if isinstance(content_type, str) and content_type not in response_content_types:
|
||||
response_content_types.append(content_type)
|
||||
|
||||
return response_status_codes, response_content_types
|
||||
@@ -0,0 +1,88 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from redis import asyncio as aioredis
|
||||
except ModuleNotFoundError:
|
||||
aioredis = None
|
||||
|
||||
from app.utils.ollama_client import chat_json, summarize_dialog_text
|
||||
|
||||
|
||||
class DialogMemoryService:
    """Redis-backed short-term dialog memory: message history plus a rolling summary."""

    def __init__(self) -> None:
        # REDIS_URL wins; otherwise the URL is composed from host/port env vars.
        host = os.getenv("REDIS_HOST", "localhost")
        port = os.getenv("REDIS_PORT", "6379")
        default_url = f"redis://{host}:{port}"
        self.redis_url = os.getenv("REDIS_URL", default_url)
        # TTL applied to both the message list and the summary keys (default: 1 day).
        self.ttl_seconds = int(os.getenv("DIALOG_TTL_SECONDS", "86400"))
|
||||
|
||||
    async def get_context(self, dialog_id: str) -> tuple[list[dict[str, Any]], str | None]:
        """Return (messages, summary) for a dialog; ([], None) when Redis is unavailable."""
        redis = await self._get_redis()
        if redis is None:
            # Best-effort memory: a missing Redis degrades to an empty context.
            return [], None

        messages_raw = await redis.get(self._messages_key(dialog_id))
        summary = await redis.get(self._summary_key(dialog_id))
        messages = self._decode_messages(messages_raw)
        return messages, summary
|
||||
|
||||
    async def append_and_summarize(self, dialog_id: str, role: str, content: str) -> str | None:
        """Append a message to the dialog history and store/return a refreshed summary.

        Returns None when Redis is unavailable. Summarization is best-effort:
        a summarizer failure falls back to concatenating recent messages.
        """
        redis = await self._get_redis()
        if redis is None:
            return None

        messages_key = self._messages_key(dialog_id)
        summary_key = self._summary_key(dialog_id)

        current_messages = self._decode_messages(await redis.get(messages_key))
        current_messages.append({"role": role, "content": content})
        # Rewrite the whole history and refresh its TTL on every append.
        await redis.set(messages_key, json.dumps(current_messages, ensure_ascii=False), ex=self.ttl_seconds)

        try:
            summary = await summarize_dialog_text(current_messages)
        except Exception:
            # Deliberate broad catch: a summarizer outage must not break appends.
            summary = None
        if summary is None:
            summary = self._fallback_summary(current_messages)
        await redis.set(summary_key, summary, ex=self.ttl_seconds)
        return summary
|
||||
|
||||
    async def reset(self, dialog_id: str) -> None:
        """Delete both the stored messages and the summary for a dialog (no-op without Redis)."""
        redis = await self._get_redis()
        if redis is None:
            return
        await redis.delete(self._messages_key(dialog_id), self._summary_key(dialog_id))
|
||||
|
||||
    async def _get_redis(self):
        """Open and ping a Redis client; None when redis is missing or unreachable.

        NOTE(review): a new connection is created on every call and never
        explicitly closed — presumably relying on client GC; consider caching
        the client or closing it after use.
        """
        if aioredis is None:
            # redis package not installed; memory features silently degrade.
            return None
        try:
            redis = aioredis.from_url(self.redis_url, encoding="utf8", decode_responses=True)
            await redis.ping()
            return redis
        except Exception:
            # Any connection failure (DNS, auth, timeout) disables memory features.
            return None
|
||||
|
||||
def _messages_key(self, dialog_id: str) -> str:
|
||||
return f"dialog:{dialog_id}:messages"
|
||||
|
||||
def _summary_key(self, dialog_id: str) -> str:
|
||||
return f"dialog:{dialog_id}:summary"
|
||||
|
||||
def _decode_messages(self, payload: str | None) -> list[dict[str, Any]]:
|
||||
if not payload:
|
||||
return []
|
||||
try:
|
||||
parsed = json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
if not isinstance(parsed, list):
|
||||
return []
|
||||
return [item for item in parsed if isinstance(item, dict)]
|
||||
|
||||
def _fallback_summary(self, messages: list[dict[str, Any]]) -> str:
|
||||
chunks = [str(item.get("content", "")) for item in messages[-4:]]
|
||||
return "\n".join(chunk for chunk in chunks if chunk)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,371 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from app.models import ActionIngestStatus, HttpMethod
|
||||
|
||||
|
||||
class OpenAPIService:
    """Parses OpenAPI 3.x documents and extracts per-operation action definitions."""

    # Lowercase HTTP verb -> HttpMethod member, used to recognize operation keys.
    SUPPORTED_METHODS = {method.value.lower(): method for method in HttpMethod}
    # Content types treated as JSON when selecting request/response schemas.
    JSON_CONTENT_TYPES = ("application/json", "application/*+json")
|
||||
|
||||
@staticmethod
|
||||
def load_document(raw_bytes: bytes) -> dict[str, Any]:
|
||||
if not raw_bytes:
|
||||
raise ValueError("OpenAPI file is empty")
|
||||
|
||||
try:
|
||||
document = yaml.safe_load(raw_bytes.decode("utf-8"))
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ValueError("OpenAPI file must be UTF-8 encoded") from exc
|
||||
except yaml.YAMLError as exc:
|
||||
raise ValueError("OpenAPI file is not valid YAML or JSON") from exc
|
||||
|
||||
if not isinstance(document, dict):
|
||||
raise ValueError("OpenAPI root must be an object")
|
||||
|
||||
openapi_version = document.get("openapi")
|
||||
if not isinstance(openapi_version, str) or not openapi_version.startswith("3."):
|
||||
raise ValueError("Only OpenAPI 3.x documents are supported")
|
||||
|
||||
if not isinstance(document.get("paths"), dict) or not document["paths"]:
|
||||
raise ValueError("OpenAPI file must contain a non-empty paths section")
|
||||
|
||||
base_url = OpenAPIService._extract_base_url(document)
|
||||
if base_url is None:
|
||||
raise ValueError(
|
||||
"OpenAPI file must contain servers[0].url (base_url)"
|
||||
)
|
||||
|
||||
return document
|
||||
|
||||
@classmethod
|
||||
def extract_actions(
|
||||
cls,
|
||||
document: dict[str, Any],
|
||||
*,
|
||||
source_filename: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
return cls.extract_actions_with_failures(document, source_filename=source_filename)["succeeded"]
|
||||
|
||||
@classmethod
|
||||
def extract_actions_with_failures(
|
||||
cls,
|
||||
document: dict[str, Any],
|
||||
*,
|
||||
source_filename: str | None = None,
|
||||
) -> dict[str, list[dict[str, Any]]]:
|
||||
base_url = cls._extract_base_url(document)
|
||||
succeeded_actions: list[dict[str, Any]] = []
|
||||
failed_actions: list[dict[str, Any]] = []
|
||||
|
||||
for path, path_item in document.get("paths", {}).items():
|
||||
if not isinstance(path_item, dict):
|
||||
continue
|
||||
|
||||
shared_parameters = path_item.get("parameters", [])
|
||||
|
||||
for method_name, operation in path_item.items():
|
||||
if method_name not in cls.SUPPORTED_METHODS:
|
||||
continue
|
||||
if not isinstance(operation, dict):
|
||||
failed_actions.append(
|
||||
cls._build_failed_action_payload(
|
||||
method_name=method_name,
|
||||
path=path,
|
||||
base_url=base_url,
|
||||
source_filename=source_filename,
|
||||
raw_spec=operation,
|
||||
error_message="Operation definition must be an object",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
succeeded_actions.append(
|
||||
cls._build_succeeded_action_payload(
|
||||
method_name=method_name,
|
||||
path=path,
|
||||
operation=operation,
|
||||
shared_parameters=shared_parameters,
|
||||
document=document,
|
||||
base_url=base_url,
|
||||
source_filename=source_filename,
|
||||
)
|
||||
)
|
||||
except ValueError as exc:
|
||||
failed_actions.append(
|
||||
cls._build_failed_action_payload(
|
||||
method_name=method_name,
|
||||
path=path,
|
||||
base_url=base_url,
|
||||
source_filename=source_filename,
|
||||
raw_spec=operation,
|
||||
error_message=str(exc),
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"succeeded": succeeded_actions,
|
||||
"failed": failed_actions,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _build_succeeded_action_payload(
|
||||
cls,
|
||||
*,
|
||||
method_name: str,
|
||||
path: str,
|
||||
operation: dict[str, Any],
|
||||
shared_parameters: list[Any] | None,
|
||||
document: dict[str, Any],
|
||||
base_url: str | None,
|
||||
source_filename: str | None,
|
||||
) -> dict[str, Any]:
|
||||
normalized_operation = cls._dereference(operation, document)
|
||||
parameters = cls._merge_parameters(shared_parameters, normalized_operation.get("parameters", []), document)
|
||||
|
||||
return {
|
||||
"operation_id": normalized_operation.get("operationId") or cls._build_operation_id(method_name, path),
|
||||
"method": cls.SUPPORTED_METHODS[method_name],
|
||||
"path": path,
|
||||
"base_url": base_url,
|
||||
"summary": normalized_operation.get("summary"),
|
||||
"description": normalized_operation.get("description"),
|
||||
"tags": normalized_operation.get("tags"),
|
||||
"parameters_schema": cls._build_parameters_schema(parameters, document),
|
||||
"request_body_schema": cls._extract_request_body_schema(normalized_operation, document),
|
||||
"response_schema": cls._extract_response_schema(normalized_operation, document),
|
||||
"source_filename": source_filename,
|
||||
"raw_spec": normalized_operation,
|
||||
"ingest_status": ActionIngestStatus.SUCCEEDED,
|
||||
"ingest_error": None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _build_failed_action_payload(
|
||||
cls,
|
||||
*,
|
||||
method_name: str,
|
||||
path: str,
|
||||
base_url: str | None,
|
||||
source_filename: str | None,
|
||||
raw_spec: Any,
|
||||
error_message: str,
|
||||
) -> dict[str, Any]:
|
||||
operation = raw_spec if isinstance(raw_spec, dict) else {}
|
||||
|
||||
return {
|
||||
"operation_id": operation.get("operationId") or cls._build_operation_id(method_name, path),
|
||||
"method": cls.SUPPORTED_METHODS[method_name],
|
||||
"path": path,
|
||||
"base_url": base_url,
|
||||
"summary": operation.get("summary"),
|
||||
"description": operation.get("description"),
|
||||
"tags": operation.get("tags"),
|
||||
"parameters_schema": None,
|
||||
"request_body_schema": None,
|
||||
"response_schema": None,
|
||||
"source_filename": source_filename,
|
||||
"raw_spec": operation or None,
|
||||
"ingest_status": ActionIngestStatus.FAILED,
|
||||
"ingest_error": error_message,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _extract_base_url(document: dict[str, Any]) -> str | None:
|
||||
servers = document.get("servers")
|
||||
if isinstance(servers, list) and servers:
|
||||
first_server = servers[0]
|
||||
if isinstance(first_server, dict):
|
||||
url = first_server.get("url")
|
||||
if isinstance(url, str):
|
||||
normalized_url = url.strip()
|
||||
if normalized_url:
|
||||
return normalized_url
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _merge_parameters(
|
||||
cls,
|
||||
path_parameters: list[Any] | None,
|
||||
operation_parameters: list[Any] | None,
|
||||
document: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
merged: dict[tuple[str | None, str | None], dict[str, Any]] = {}
|
||||
|
||||
for raw_parameter in (path_parameters or []) + (operation_parameters or []):
|
||||
parameter = cls._dereference(raw_parameter, document)
|
||||
if not isinstance(parameter, dict):
|
||||
continue
|
||||
key = (parameter.get("name"), parameter.get("in"))
|
||||
merged[key] = parameter
|
||||
|
||||
return list(merged.values())
|
||||
|
||||
@classmethod
|
||||
def _build_parameters_schema(
|
||||
cls,
|
||||
parameters: list[dict[str, Any]],
|
||||
document: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
if not parameters:
|
||||
return None
|
||||
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
|
||||
for parameter in parameters:
|
||||
name = parameter.get("name")
|
||||
if not name:
|
||||
continue
|
||||
if parameter.get("in") not in {"query", "path", "header", "cookie"}:
|
||||
continue
|
||||
|
||||
schema = parameter.get("schema")
|
||||
if schema is None:
|
||||
schema = cls._extract_schema_from_content(parameter.get("content"), document)
|
||||
else:
|
||||
schema = cls._dereference(schema, document)
|
||||
|
||||
property_schema = schema if isinstance(schema, dict) else {"type": "string"}
|
||||
property_schema = {
|
||||
**property_schema,
|
||||
"x-parameter-location": parameter.get("in"),
|
||||
}
|
||||
|
||||
if parameter.get("description"):
|
||||
property_schema["description"] = parameter["description"]
|
||||
|
||||
properties[name] = property_schema
|
||||
|
||||
if parameter.get("required"):
|
||||
required.append(name)
|
||||
|
||||
if not properties:
|
||||
return None
|
||||
|
||||
schema: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
}
|
||||
if required:
|
||||
schema["required"] = required
|
||||
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def _extract_request_body_schema(
|
||||
cls,
|
||||
operation: dict[str, Any],
|
||||
document: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
request_body = operation.get("requestBody")
|
||||
if not isinstance(request_body, dict):
|
||||
return None
|
||||
request_body = cls._dereference(request_body, document)
|
||||
schema = cls._extract_schema_from_content(request_body.get("content"), document)
|
||||
if not isinstance(schema, dict):
|
||||
return None
|
||||
|
||||
if request_body.get("required"):
|
||||
schema = {**schema, "x-required": True}
|
||||
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def _extract_response_schema(
|
||||
cls,
|
||||
operation: dict[str, Any],
|
||||
document: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
responses = operation.get("responses")
|
||||
if not isinstance(responses, dict):
|
||||
return None
|
||||
|
||||
for status_code, response in responses.items():
|
||||
if not str(status_code).startswith("2"):
|
||||
continue
|
||||
|
||||
normalized_response = cls._dereference(response, document)
|
||||
if not isinstance(normalized_response, dict):
|
||||
continue
|
||||
|
||||
schema = cls._extract_schema_from_content(normalized_response.get("content"), document)
|
||||
if isinstance(schema, dict):
|
||||
return schema
|
||||
|
||||
if normalized_response.get("description"):
|
||||
return {"description": normalized_response["description"]}
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_schema_from_content(cls, content: Any, document: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if not isinstance(content, dict):
|
||||
return None
|
||||
|
||||
preferred_content_type = next((content_type for content_type in cls.JSON_CONTENT_TYPES if content_type in content), None)
|
||||
items = []
|
||||
if preferred_content_type:
|
||||
items.append((preferred_content_type, content[preferred_content_type]))
|
||||
items.extend((content_type, value) for content_type, value in content.items() if content_type != preferred_content_type)
|
||||
|
||||
for content_type, value in items:
|
||||
if not isinstance(value, dict):
|
||||
continue
|
||||
schema = value.get("schema")
|
||||
if not isinstance(schema, dict):
|
||||
continue
|
||||
|
||||
normalized_schema = cls._dereference(schema, document)
|
||||
if isinstance(normalized_schema, dict):
|
||||
return {
|
||||
**normalized_schema,
|
||||
"x-content-type": content_type,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _dereference(cls, value: Any, document: dict[str, Any]) -> Any:
|
||||
if isinstance(value, list):
|
||||
return [cls._dereference(item, document) for item in value]
|
||||
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
if "$ref" in value:
|
||||
resolved = cls._resolve_ref(value["$ref"], document)
|
||||
merged = cls._dereference(resolved, document)
|
||||
if not isinstance(merged, dict):
|
||||
return merged
|
||||
|
||||
sibling_fields = {key: cls._dereference(item, document) for key, item in value.items() if key != "$ref"}
|
||||
return {**merged, **sibling_fields}
|
||||
|
||||
return {key: cls._dereference(item, document) for key, item in value.items()}
|
||||
|
||||
@staticmethod
|
||||
def _resolve_ref(ref: str, document: dict[str, Any]) -> Any:
|
||||
if not ref.startswith("#/"):
|
||||
raise ValueError(f"Only local $ref values are supported, got: {ref}")
|
||||
|
||||
current: Any = document
|
||||
for part in ref[2:].split("/"):
|
||||
token = part.replace("~1", "/").replace("~0", "~")
|
||||
if not isinstance(current, dict) or token not in current:
|
||||
raise ValueError(f"Could not resolve OpenAPI reference: {ref}")
|
||||
current = current[token]
|
||||
|
||||
return current
|
||||
|
||||
@staticmethod
|
||||
def _build_operation_id(method_name: str, path: str) -> str:
|
||||
normalized_path = re.sub(r"[{}]", "", path).strip("/")
|
||||
normalized_path = re.sub(r"[^a-zA-Z0-9/]+", "_", normalized_path)
|
||||
normalized_path = normalized_path.replace("/", "_") or "root"
|
||||
return f"{method_name.lower()}_{normalized_path.lower()}"
|
||||
@@ -0,0 +1,176 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import DialogMessageRole, PipelineDialog, PipelineDialogMessage
|
||||
|
||||
|
||||
class DialogAccessError(Exception):
    """Raised when a dialog does not exist or belongs to another user."""
|
||||
|
||||
|
||||
class PipelineDialogService:
    """Persistence layer for pipeline dialogs and their messages.

    All reads enforce ownership (``user_id``) and raise ``DialogAccessError``
    on missing or foreign dialogs. Writes go through ``_append_message``,
    which also maintains denormalized dialog fields (title, preview, status).
    """

    def __init__(self, session: AsyncSession) -> None:
        # The session is caller-provided; this service commits in
        # _append_message but never closes the session.
        self.session = session

    async def list_dialogs(
        self,
        *,
        user_id: UUID,
        limit: int,
        offset: int,
    ) -> list[PipelineDialog]:
        """Return the user's dialogs, most recently updated first."""
        query = (
            select(PipelineDialog)
            .where(PipelineDialog.user_id == user_id)
            .order_by(PipelineDialog.updated_at.desc())
            .limit(limit)
            .offset(offset)
        )
        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def get_history(
        self,
        *,
        dialog_id: UUID,
        user_id: UUID,
        limit: int,
        offset: int,
    ) -> tuple[PipelineDialog, list[PipelineDialogMessage]]:
        """Return the dialog and a page of its messages in chronological order.

        Pagination is applied on the newest-first ordering, then the page is
        reversed, so offset=0 yields the most recent ``limit`` messages
        (oldest of those first). Raises ``DialogAccessError`` on bad access.
        """
        dialog = await self._get_dialog_owned_by_user(dialog_id=dialog_id, user_id=user_id)

        query = (
            select(PipelineDialogMessage)
            .where(PipelineDialogMessage.dialog_id == dialog.id)
            .order_by(PipelineDialogMessage.created_at.desc())
            .limit(limit)
            .offset(offset)
        )
        result = await self.session.execute(query)
        messages_desc = list(result.scalars().all())
        return dialog, list(reversed(messages_desc))

    async def get_dialog(
        self,
        *,
        dialog_id: UUID,
        user_id: UUID,
    ) -> PipelineDialog:
        """Return the dialog if owned by *user_id*; raise ``DialogAccessError`` otherwise."""
        return await self._get_dialog_owned_by_user(dialog_id=dialog_id, user_id=user_id)

    async def append_user_message(
        self,
        *,
        dialog_id: UUID,
        user_id: UUID,
        content: str,
    ) -> PipelineDialogMessage:
        """Append a USER message; creates the dialog on first message."""
        return await self._append_message(
            dialog_id=dialog_id,
            user_id=user_id,
            role=DialogMessageRole.USER,
            content=content,
            assistant_payload=None,
            create_dialog_if_missing=True,
        )

    async def append_assistant_message(
        self,
        *,
        dialog_id: UUID,
        user_id: UUID,
        content: str,
        assistant_payload: dict[str, Any],
    ) -> PipelineDialogMessage:
        """Append an ASSISTANT message; the dialog must already exist."""
        return await self._append_message(
            dialog_id=dialog_id,
            user_id=user_id,
            role=DialogMessageRole.ASSISTANT,
            content=content,
            assistant_payload=assistant_payload,
            create_dialog_if_missing=False,
        )

    async def _append_message(
        self,
        *,
        dialog_id: UUID,
        user_id: UUID,
        role: DialogMessageRole,
        content: str,
        assistant_payload: dict[str, Any] | None,
        create_dialog_if_missing: bool,
    ) -> PipelineDialogMessage:
        """Persist one message and update the dialog's denormalized fields.

        Creates the dialog (client-supplied id) when allowed, enforces
        ownership, refreshes title/preview, and — for assistant messages —
        mirrors ``status`` and ``pipeline_id`` from the payload. Commits the
        session before returning. Raises ``DialogAccessError`` on bad access.
        """
        dialog = await self.session.get(PipelineDialog, dialog_id)
        if dialog is None:
            if not create_dialog_if_missing:
                raise DialogAccessError("Dialog not found")
            dialog = PipelineDialog(
                id=dialog_id,
                user_id=user_id,
                title=self._build_title(content),
            )
            self.session.add(dialog)
            # Flush so dialog.id is usable for the message row below.
            await self.session.flush()
        elif dialog.user_id != user_id:
            raise DialogAccessError("Dialog access denied")

        # Backfill a title from the first user message of a titleless dialog.
        if role == DialogMessageRole.USER and not dialog.title:
            dialog.title = self._build_title(content)

        message = PipelineDialogMessage(
            dialog_id=dialog.id,
            role=role,
            content=content,
            assistant_payload=assistant_payload,
        )
        self.session.add(message)

        dialog.last_message_preview = self._build_preview(content)
        if role == DialogMessageRole.ASSISTANT and assistant_payload:
            status = assistant_payload.get("status")
            if isinstance(status, str):
                dialog.last_status = status
            pipeline_id = self._parse_uuid(assistant_payload.get("pipeline_id"))
            if pipeline_id is not None:
                # Preserve the last valid graph reference for non-ready statuses.
                dialog.last_pipeline_id = pipeline_id

        await self.session.commit()
        return message

    async def _get_dialog_owned_by_user(
        self,
        *,
        dialog_id: UUID,
        user_id: UUID,
    ) -> PipelineDialog:
        """Load a dialog and verify ownership; raise ``DialogAccessError`` otherwise."""
        dialog = await self.session.get(PipelineDialog, dialog_id)
        if dialog is None:
            raise DialogAccessError("Dialog not found")
        if dialog.user_id != user_id:
            raise DialogAccessError("Dialog access denied")
        return dialog

    def _build_title(self, content: str) -> str:
        """Derive a one-line title (max 120 chars) with a non-empty fallback."""
        text = (content or "").strip().replace("\n", " ")
        return (text[:120] or "Pipeline dialog")

    def _build_preview(self, content: str) -> str:
        """Derive a one-line preview of the message, capped at 280 chars."""
        text = (content or "").strip().replace("\n", " ")
        return text[:280]

    def _parse_uuid(self, value: Any) -> UUID | None:
        """Coerce a UUID or UUID-formatted string to ``UUID``; else ``None``."""
        if isinstance(value, UUID):
            return value
        if isinstance(value, str):
            try:
                return UUID(value)
            except ValueError:
                return None
        return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,491 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, NamedTuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import Action, Capability
|
||||
from app.models.capability import CapabilityType
|
||||
|
||||
|
||||
class SelectedCapability(NamedTuple):
    """A capability chosen for a user query, with its ranking metadata."""

    # The matched Capability record.
    capability: Capability
    # Lexical relevance score (higher is better).
    score: float
    # "high" / "medium" / "low"; the same tier is applied to a whole batch
    # returned by SemanticSelectionService.select_capabilities.
    confidence_tier: str = "high"
|
||||
|
||||
|
||||
class SemanticSelectionService:
    """Lexical (token-overlap) capability selection for a free-text query.

    Despite the name there is no embedding model here: matching is based on
    tokenization, Russian/English suffix stripping, alias expansion and a set
    of hand-tuned bonuses/penalties. Scores feed a coarse confidence tier.
    """

    # Score thresholds for the confidence tiers (see _resolve_confidence_tier).
    HIGH_CONFIDENCE_THRESHOLD = 0.45
    MEDIUM_CONFIDENCE_THRESHOLD = 0.30
    # If the top two scores are this close, confidence drops to "low".
    LOW_MARGIN_THRESHOLD = 0.05
    # Domain (CRM/marketing) vocabulary, English + Russian; matching these in
    # both query and capability earns a bonus in _score_capability.
    CRM_TOKENS = {
        "crm", "segment", "segments", "audience", "campaign", "campaigns",
        "mailing", "newsletter", "lead", "leads", "retention", "cohort",
        "churn", "conversion", "promo", "offer", "offers", "email", "emails",
        "push", "sale", "sales",
        "сегмент", "сегменты", "аудитория", "кампания", "кампании",
        "рассылка", "лид", "лиды", "ретеншн", "конверсия", "оффер", "офферы",
        "пуш", "продажи", "клиент", "клиенты",
    }
    # Low-information CRUD/API vocabulary; a capability described mostly by
    # these is penalized (see _generic_capability_penalty).
    GENERIC_TOKENS = {
        "get", "list", "create", "update", "delete", "call", "data", "info",
        "items", "resource", "resources", "service", "api", "handle",
        "handler", "manage", "process", "method", "action", "fetch",
        "general", "common",
        "получить", "список", "создать", "обновить", "удалить", "данные",
        "инфо", "ресурс", "сервис", "метод", "действие", "общее",
    }
    # Tokens dropped entirely during tokenization (function words plus
    # pipeline-building boilerplate like "build"/"построй").
    _STOPWORDS = {
        "and", "the", "for", "with", "from", "into", "that", "this",
        "что", "это", "как", "для", "или", "при", "про", "надо", "нужно",
        "хочу",
        "build", "pipeline", "workflow", "scenario", "automation",
        "пайплайн", "сценарий", "автоматизация", "построй", "собери",
    }
    # Stem/prefix -> cross-language synonym sets used by _expand_tokens to
    # bridge Russian queries and English operation names (and vice versa).
    _ALIAS_EXPANSIONS = {
        "польз": {"user", "users", "client", "clients", "пользователь", "пользователи"},
        "клиент": {"client", "clients", "user", "users", "клиент", "клиенты"},
        "юзер": {"user", "users", "пользователь", "пользователи"},
        "получ": {"get", "fetch", "list", "retrieve", "получить", "список"},
        "спис": {"list", "get", "fetch", "список", "получить"},
        "созд": {"create", "add", "post", "создать"},
        "обнов": {"update", "patch", "put", "обновить"},
        "удал": {"delete", "remove", "del", "удалить"},
        "рассыл": {"mailing", "newsletter", "broadcast", "email", "рассылка"},
        "сегмент": {"segment", "segments", "сегмент", "сегменты"},
        "лид": {"lead", "leads", "лид", "лиды"},
        "отчет": {"report", "analytics", "отчет", "отчёт"},
        "отчёт": {"report", "analytics", "отчет", "отчёт"},
        "user": {"пользователь", "пользователи", "user", "users"},
        "users": {"пользователь", "пользователи", "user", "users"},
        "get": {"получить", "список", "get", "fetch", "list"},
        "fetch": {"получить", "список", "get", "fetch", "list"},
        "list": {"получить", "список", "get", "fetch", "list"},
    }

    async def select_capabilities(
        self,
        session: AsyncSession,
        user_query: str,
        owner_user_id: UUID | None = None,
        limit: int = 10,
    ) -> list[SelectedCapability]:
        """Rank up to *limit* executable capabilities against *user_query*.

        Loads at most 200 capabilities (optionally user-scoped), filters to
        executable ones, scores each by token overlap, and attaches one
        confidence tier to the whole result. When nothing scores above zero
        the first *limit* candidates are returned at score 0.01 / "low" so
        downstream generation still has material to work with.
        """
        query_tokens = self._tokenize(user_query)
        if not query_tokens:
            return []

        query = select(Capability).order_by(Capability.created_at.asc())
        if owner_user_id is not None:
            # User-scoped with legacy compatibility:
            # some old capabilities may have user_id=NULL while their source action has owner.
            query = query.outerjoin(Action, Capability.action_id == Action.id).where(
                or_(
                    Capability.user_id == owner_user_id,
                    and_(
                        Capability.user_id.is_(None),
                        Action.user_id == owner_user_id,
                    ),
                )
            )
        # Cap the candidate pool; scoring is done in Python, not SQL.
        query = query.limit(200)
        result = await session.execute(query)
        capabilities = list(result.scalars().all())

        executable_capabilities = [
            capability
            for capability in capabilities
            if self._is_executable_capability(capability)
        ]
        candidates = executable_capabilities
        if not candidates:
            return []

        query_tokens_expanded = self._expand_tokens(query_tokens)
        ranked: list[SelectedCapability] = []
        for capability in candidates:
            score = self._score_capability(query_tokens, query_tokens_expanded, capability)
            if score <= 0:
                continue
            ranked.append(SelectedCapability(capability=capability, score=score))

        ranked.sort(key=lambda item: item.score, reverse=True)
        if not ranked:
            if candidates:
                # Fallback: keep generation moving even when lexical matching is weak.
                return [
                    SelectedCapability(
                        capability=capability,
                        score=0.01,
                        confidence_tier="low",
                    )
                    for capability in candidates[:limit]
                ]
            return []

        # Confidence is derived from the winner's score and its lead over the
        # runner-up, then applied uniformly to every returned item.
        top_score = ranked[0].score
        second_score = ranked[1].score if len(ranked) > 1 else 0.0
        margin = top_score - second_score
        confidence_tier = self._resolve_confidence_tier(top_score, margin)

        return [
            SelectedCapability(
                capability=item.capability,
                score=item.score,
                confidence_tier=confidence_tier,
            )
            for item in ranked[:limit]
        ]

    def _score_capability(
        self,
        query_tokens: set[str],
        query_tokens_expanded: set[str],
        capability: Capability,
    ) -> float:
        """Score one capability against the (expanded) query tokens.

        Base score is the best of three overlap ratios (all tokens, name
        tokens x1.12, context tokens x0.95), plus bonuses for exact
        containment, context/entity/CRM overlap, minus a penalty for mostly
        generic vocabulary. Returns 0.0 when there is no overlap at all.
        """
        name = str(getattr(capability, "name", "") or "")
        description = str(getattr(capability, "description", "") or "")
        name_tokens = self._tokenize(name)
        description_tokens = self._tokenize(description)
        context_tokens = self._extract_context_tokens(capability)
        recipe_tokens = self._extract_recipe_tokens(capability)
        combined_tokens = name_tokens | description_tokens | context_tokens | recipe_tokens
        if not combined_tokens:
            return 0.0

        combined_tokens_expanded = self._expand_tokens(combined_tokens)
        overlap = query_tokens_expanded & combined_tokens_expanded
        if not overlap:
            return 0.0

        overlap_ratio = len(overlap) / len(query_tokens_expanded)
        name_tokens_expanded = self._expand_tokens(name_tokens)
        name_ratio = len(query_tokens_expanded & name_tokens_expanded) / len(query_tokens_expanded)
        # Bonus when the capability's vocabulary fully covers the query.
        exact_bonus = 0.22 if query_tokens_expanded <= combined_tokens_expanded else 0.0
        context_ratio = 0.0
        context_bonus = 0.0
        if context_tokens:
            context_tokens_expanded = self._expand_tokens(context_tokens)
            context_overlap = query_tokens_expanded & context_tokens_expanded
            context_ratio = len(context_overlap) / len(query_tokens_expanded)
            context_bonus = min(0.16, len(context_overlap) * 0.03)

        # Reward overlap on non-generic ("entity") tokens specifically.
        generic_expanded = self._expand_tokens(self.GENERIC_TOKENS)
        entity_overlap = overlap - generic_expanded
        entity_bonus = min(0.18, len(entity_overlap) * 0.06) if entity_overlap else 0.0

        # Extra boost when both sides speak the CRM domain vocabulary.
        query_crm_tokens = query_tokens_expanded & self.CRM_TOKENS
        capability_crm_tokens = combined_tokens_expanded & self.CRM_TOKENS
        crm_bonus = 0.0
        if query_crm_tokens and capability_crm_tokens:
            crm_overlap = len(query_crm_tokens & capability_crm_tokens)
            crm_bonus = 0.12 + min(0.14, crm_overlap * 0.04)

        generic_penalty = self._generic_capability_penalty(combined_tokens)

        return (
            max(overlap_ratio, name_ratio * 1.12, context_ratio * 0.95)
            + exact_bonus
            + context_bonus
            + entity_bonus
            + crm_bonus
            - generic_penalty
        )

    def _extract_context_tokens(self, capability: Capability) -> set[str]:
        """Tokenize selected text fields from the capability's ``llm_payload``.

        Only a few known payload keys are mined; text collection is capped at
        120 chunks by _collect_text_chunks.
        """
        llm_payload = getattr(capability, "llm_payload", None)
        if not isinstance(llm_payload, dict):
            return set()

        chunks: list[str] = []
        for key in (
            "action_context_brief",
            "openapi_hints",
            "action_context",
            "recipe_summary",
            "composite_context",
        ):
            value = llm_payload.get(key)
            if value is None:
                continue
            self._collect_text_chunks(value=value, chunks=chunks, depth=0, max_depth=4)

        tokens: set[str] = set()
        for chunk in chunks[:120]:
            tokens.update(self._tokenize(chunk))
        return tokens

    def _extract_recipe_tokens(self, capability: Capability) -> set[str]:
        """Tokenize string keys/values from the first 30 recipe-step inputs."""
        recipe = getattr(capability, "recipe", None)
        if not isinstance(recipe, dict):
            return set()
        steps = recipe.get("steps")
        if not isinstance(steps, list):
            return set()

        chunks: list[str] = []
        for raw_step in steps[:30]:
            if not isinstance(raw_step, dict):
                continue
            inputs = raw_step.get("inputs")
            if not isinstance(inputs, dict):
                continue
            for key, value in inputs.items():
                if isinstance(key, str):
                    chunks.append(key)
                if isinstance(value, str):
                    chunks.append(value)

        tokens: set[str] = set()
        for chunk in chunks:
            tokens.update(self._tokenize(chunk))
        return tokens

    def _collect_text_chunks(
        self,
        *,
        value: object,
        chunks: list[str],
        depth: int,
        max_depth: int,
    ) -> None:
        """Recursively gather text from nested payload values into *chunks*.

        Depth-limited and capped at 120 chunks. For dicts only a whitelist of
        keys is followed (the key itself is also recorded); lists are
        traversed up to 30 items. Non-string scalars are ignored.
        """
        if depth > max_depth or len(chunks) >= 120:
            return

        if isinstance(value, str):
            stripped = value.strip()
            if stripped:
                chunks.append(stripped)
            return

        if isinstance(value, dict):
            preferred_keys = {
                "operation_id",
                "method",
                "path",
                "base_url",
                "summary",
                "description",
                "tags",
                "source_filename",
                "required_inputs",
                "request_content_types",
                "response_content_types",
                "response_status_codes",
                "security_requirements",
                "parameter_names_by_location",
                "path_segments",
                "input_signals",
                "output_signals",
            }
            for key, item in value.items():
                if not isinstance(key, str):
                    continue
                if key not in preferred_keys:
                    continue
                chunks.append(key)
                self._collect_text_chunks(
                    value=item,
                    chunks=chunks,
                    depth=depth + 1,
                    max_depth=max_depth,
                )
            return

        if isinstance(value, list):
            for item in value[:30]:
                self._collect_text_chunks(
                    value=item,
                    chunks=chunks,
                    depth=depth + 1,
                    max_depth=max_depth,
                )

    def _resolve_confidence_tier(self, top_score: float, margin: float) -> str:
        """Map the winner's score and its lead over #2 to high/medium/low.

        A narrow margin forces "low" regardless of the absolute score.
        """
        if margin < self.LOW_MARGIN_THRESHOLD:
            return "low"
        if top_score >= self.HIGH_CONFIDENCE_THRESHOLD:
            return "high"
        if top_score >= self.MEDIUM_CONFIDENCE_THRESHOLD:
            return "medium"
        return "low"

    def _generic_capability_penalty(self, tokens: set[str]) -> float:
        """Penalty that grows with the share of generic CRUD/API tokens."""
        if not tokens:
            return 0.0
        generic_share = len(tokens & self.GENERIC_TOKENS) / len(tokens)
        if generic_share >= 0.65:
            return 0.14
        if generic_share >= 0.5:
            return 0.09
        if generic_share >= 0.35:
            return 0.04
        return 0.0

    def _tokenize(self, value: str) -> set[str]:
        """Lowercase alphanumeric tokens (Latin/Cyrillic), len >= 3, minus stopwords."""
        tokens = set(re.findall(r"[a-zA-Zа-яА-Я0-9]+", value.lower()))
        return {
            token
            for token in tokens
            if len(token) >= 3 and token not in self._STOPWORDS
        }

    def _is_executable_capability(self, capability: Capability) -> bool:
        """True if the capability can actually run: atomic needs an action_id,
        composite needs a well-formed recipe; anything else is rejected."""
        cap_type = self._capability_type_value(capability)
        if cap_type == CapabilityType.ATOMIC.value:
            return getattr(capability, "action_id", None) is not None
        if cap_type == CapabilityType.COMPOSITE.value:
            return self._recipe_is_executable(getattr(capability, "recipe", None))
        return False

    def _recipe_is_executable(self, recipe: Any) -> bool:
        """True for a version-1 recipe dict with a non-empty steps list."""
        if not isinstance(recipe, dict):
            return False
        if recipe.get("version") != 1:
            return False
        steps = recipe.get("steps")
        return isinstance(steps, list) and bool(steps)

    def _capability_type_value(self, capability: Capability) -> str:
        """Normalize capability.type (enum, str, or enum-like) to its string value.

        Missing/unknown values default to ATOMIC.
        """
        raw = getattr(capability, "type", None)
        if isinstance(raw, CapabilityType):
            return raw.value
        if isinstance(raw, str):
            return raw
        if hasattr(raw, "value"):
            return str(raw.value)
        return CapabilityType.ATOMIC.value

    def _expand_tokens(self, tokens: set[str]) -> set[str]:
        """Expand tokens with stemmed variants and cross-language aliases.

        Alias sets fire on exact match or prefix match of any variant against
        an _ALIAS_EXPANSIONS key.
        """
        expanded: set[str] = set()
        for token in tokens:
            expanded.add(token)
            normalized_variants = self._normalized_variants(token)
            expanded.update(normalized_variants)
            for variant in normalized_variants | {token}:
                for key, aliases in self._ALIAS_EXPANSIONS.items():
                    if variant == key or variant.startswith(key):
                        expanded.update(aliases)
        return expanded

    def _normalized_variants(self, token: str) -> set[str]:
        """Return the token plus crude stems.

        For Russian, strips a fixed list of inflection suffixes (longest
        first) when enough stem remains; for English, handles -ies -> -y and
        trailing -s. Always includes the original token.
        """
        variants = {token}
        if len(token) >= 5:
            for suffix in (
                "иями", "ями", "ами", "ов", "ев", "ей", "ам", "ям", "ах",
                "ях", "ые", "ий", "ый", "ая", "ое", "ой",
                "а", "я", "ы", "и", "у", "ю", "е", "о",
            ):
                if token.endswith(suffix) and len(token) > len(suffix) + 2:
                    variants.add(token[: -len(suffix)])

        if token.endswith("ies") and len(token) > 4:
            variants.add(token[:-3] + "y")
        if token.endswith("s") and len(token) > 3:
            variants.add(token[:-1])
        return variants
|
||||
Reference in New Issue
Block a user