upload
This commit is contained in:
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from app.utils.log_context import get_log_context
|
||||
|
||||
|
||||
business_logger = logging.getLogger("app.business")
|
||||
EVENT_SCHEMA_VERSION = "1.0"
|
||||
SERVICE_NAME = os.getenv("APP_SERVICE_NAME", "backend-api")
|
||||
|
||||
|
||||
def _derive_event_group(event: str) -> tuple[str, str | None]:
|
||||
normalized = (event or "").strip().lower()
|
||||
|
||||
if normalized.startswith("auth_"):
|
||||
return "auth", None
|
||||
|
||||
if normalized.startswith("action_") or normalized.startswith("actions_"):
|
||||
return "actions", None
|
||||
|
||||
if (
|
||||
normalized.startswith("capability_")
|
||||
or normalized.startswith("capabilities_")
|
||||
or normalized.startswith("composite_capability_")
|
||||
):
|
||||
return "capabilities", None
|
||||
|
||||
if normalized.startswith("pipeline_prompt_"):
|
||||
return "pipelines", "prompt"
|
||||
if normalized.startswith("pipeline_run_"):
|
||||
return "pipelines", "run"
|
||||
if normalized.startswith("pipeline_dialog_"):
|
||||
return "pipelines", "dialog"
|
||||
if normalized.startswith("pipeline_") or normalized.startswith("pipelines_"):
|
||||
return "pipelines", None
|
||||
|
||||
if normalized.startswith("execution_run_"):
|
||||
return "executions", "run"
|
||||
if normalized.startswith("execution_step_"):
|
||||
return "executions", "step"
|
||||
if normalized.startswith("execution_") or normalized.startswith("executions_"):
|
||||
return "executions", None
|
||||
|
||||
if normalized.startswith("user_") or normalized.startswith("users_"):
|
||||
return "users", None
|
||||
|
||||
return "other", None
|
||||
|
||||
|
||||
def _derive_event_outcome(event: str) -> str:
|
||||
normalized = (event or "").strip().lower()
|
||||
for suffix, outcome in (
|
||||
("_succeeded", "success"),
|
||||
("_created", "success"),
|
||||
("_updated", "success"),
|
||||
("_deleted", "success"),
|
||||
("_processed", "success"),
|
||||
("_finished", "success"),
|
||||
("_failed", "failure"),
|
||||
("_rejected", "failure"),
|
||||
("_blocked", "failure"),
|
||||
("_started", "progress"),
|
||||
("_queued", "progress"),
|
||||
("_received", "progress"),
|
||||
("_listed", "read"),
|
||||
("_fetched", "read"),
|
||||
("_viewed", "read"),
|
||||
):
|
||||
if normalized.endswith(suffix):
|
||||
return outcome
|
||||
return "unknown"
|
||||
|
||||
|
||||
def log_business_event(event: str, **fields: Any) -> None:
|
||||
safe_fields: dict[str, Any] = {
|
||||
"event": event,
|
||||
"event_schema_version": EVENT_SCHEMA_VERSION,
|
||||
"service_name": SERVICE_NAME,
|
||||
}
|
||||
event_group, event_subgroup = _derive_event_group(event)
|
||||
event_outcome = _derive_event_outcome(event)
|
||||
|
||||
if "event_group" not in fields:
|
||||
safe_fields["event_group"] = event_group
|
||||
if event_subgroup is not None and "event_subgroup" not in fields:
|
||||
safe_fields["event_subgroup"] = event_subgroup
|
||||
if "event_outcome" not in fields:
|
||||
safe_fields["event_outcome"] = event_outcome
|
||||
|
||||
for key, value in get_log_context().items():
|
||||
if key not in fields:
|
||||
safe_fields[key] = value
|
||||
|
||||
for key, value in fields.items():
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
safe_fields[key] = value
|
||||
else:
|
||||
safe_fields[key] = str(value)
|
||||
|
||||
business_logger.info(event, extra=safe_fields)
|
||||
@@ -0,0 +1,124 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
from fastapi import Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
|
||||
trace_id = getattr(request.state, "traceId", str(uuid.uuid4()))
|
||||
is_json_error = any(e.get("type") in ("json_invalid", "json_decode", "value_error.jsondecode") for e in exc.errors())
|
||||
|
||||
if is_json_error:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
"code": "BAD_REQUEST",
|
||||
"message": "Невалидный JSON",
|
||||
"traceId": trace_id,
|
||||
"timestamp": now_iso(),
|
||||
"path": request.url.path,
|
||||
"details": {"hint": "Проверьте запятые/кавычки"},
|
||||
},
|
||||
)
|
||||
|
||||
field_errors: list[dict[str, Any]] = []
|
||||
for err in exc.errors():
|
||||
loc = [str(x) for x in err.get("loc", []) if x != "body"]
|
||||
field_name = ".".join(loc) if loc else "unknown"
|
||||
|
||||
msg = err.get("msg", "invalid")
|
||||
if msg.startswith("Value error, "):
|
||||
msg = msg.replace("Value error, ", "")
|
||||
|
||||
field_errors.append({
|
||||
"field": field_name,
|
||||
"issue": msg,
|
||||
"rejectedValue": err.get("input", None),
|
||||
})
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content={
|
||||
"code": "VALIDATION_FAILED",
|
||||
"message": "Некоторые поля не прошли валидацию",
|
||||
"traceId": trace_id,
|
||||
"timestamp": now_iso(),
|
||||
"path": request.url.path,
|
||||
"fieldErrors": field_errors,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
trace_id = getattr(request.state, "traceId", str(uuid.uuid4()))
|
||||
|
||||
message = str(exc.detail)
|
||||
details = None
|
||||
|
||||
if isinstance(exc.detail, dict):
|
||||
message = exc.detail.get("message", str(exc.detail))
|
||||
details_data = {k: v for k, v in exc.detail.items() if k != "message"}
|
||||
if details_data:
|
||||
details = details_data
|
||||
|
||||
code = "HTTP_ERROR"
|
||||
if exc.status_code == status.HTTP_409_CONFLICT:
|
||||
code = "EMAIL_ALREADY_EXISTS" if "email" in message.lower() else "CONFLICT"
|
||||
elif exc.status_code == status.HTTP_400_BAD_REQUEST:
|
||||
code = "BAD_REQUEST"
|
||||
elif exc.status_code == status.HTTP_401_UNAUTHORIZED:
|
||||
code = "UNAUTHORIZED"
|
||||
elif exc.status_code == status.HTTP_423_LOCKED:
|
||||
code = "USER_INACTIVE"
|
||||
elif exc.status_code == status.HTTP_403_FORBIDDEN:
|
||||
code = "FORBIDDEN"
|
||||
elif exc.status_code == status.HTTP_404_NOT_FOUND:
|
||||
code = "NOT_FOUND"
|
||||
if message == "Not Found":
|
||||
message = "Ресурс не найден"
|
||||
elif exc.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY:
|
||||
code = "VALIDATION_FAILED"
|
||||
|
||||
content = {
|
||||
"code": code,
|
||||
"message": message,
|
||||
"traceId": trace_id,
|
||||
"timestamp": now_iso(),
|
||||
"path": request.url.path,
|
||||
}
|
||||
|
||||
if details:
|
||||
content["details"] = details
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
trace_id = getattr(request.state, "traceId", str(uuid.uuid4()))
|
||||
logger.exception("Unhandled exception on %s", request.url.path, exc_info=exc)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={
|
||||
"code": "INTERNAL_ERROR",
|
||||
"message": "Внутренняя ошибка сервера",
|
||||
"traceId": trace_id,
|
||||
"timestamp": now_iso(),
|
||||
"path": request.url.path,
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
import bcrypt
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
pwd_bytes = password.encode("utf-8")
|
||||
salt = bcrypt.gensalt()
|
||||
return bcrypt.hashpw(pwd_bytes, salt).decode("utf-8")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
try:
|
||||
pwd_bytes = plain_password.encode("utf-8")
|
||||
hashed_bytes = hashed_password.encode("utf-8")
|
||||
return bcrypt.checkpw(pwd_bytes, hashed_bytes)
|
||||
except Exception:
|
||||
return False
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import Any
|
||||
|
||||
|
||||
_trace_id_ctx: ContextVar[str | None] = ContextVar("trace_id", default=None)
|
||||
_path_ctx: ContextVar[str | None] = ContextVar("path", default=None)
|
||||
_method_ctx: ContextVar[str | None] = ContextVar("method", default=None)
|
||||
_user_id_ctx: ContextVar[str | None] = ContextVar("user_id", default=None)
|
||||
|
||||
|
||||
def set_request_context(*, trace_id: str | None, path: str | None, method: str | None) -> None:
|
||||
_trace_id_ctx.set(trace_id)
|
||||
_path_ctx.set(path)
|
||||
_method_ctx.set(method)
|
||||
|
||||
|
||||
def set_user_context(*, user_id: str | None) -> None:
|
||||
_user_id_ctx.set(user_id)
|
||||
|
||||
|
||||
def clear_log_context() -> None:
|
||||
_trace_id_ctx.set(None)
|
||||
_path_ctx.set(None)
|
||||
_method_ctx.set(None)
|
||||
_user_id_ctx.set(None)
|
||||
|
||||
|
||||
def get_log_context() -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {}
|
||||
|
||||
trace_id = _trace_id_ctx.get()
|
||||
if trace_id:
|
||||
payload["trace_id"] = trace_id
|
||||
|
||||
path = _path_ctx.get()
|
||||
if path:
|
||||
payload["path"] = path
|
||||
|
||||
method = _method_ctx.get()
|
||||
if method:
|
||||
payload["method"] = method
|
||||
|
||||
user_id = _user_id_ctx.get()
|
||||
if user_id:
|
||||
payload["user_id"] = user_id
|
||||
|
||||
return payload
|
||||
@@ -0,0 +1,287 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
def build_capability_from_action(action: Any) -> dict[str, Any]:
|
||||
llm_result = _call_ollama_json(
|
||||
system_prompt=(
|
||||
"You convert one API action into one capability. "
|
||||
"Return only valid JSON with keys: "
|
||||
"name, description, input_schema, output_schema, data_format."
|
||||
),
|
||||
user_prompt=_build_prompt(action),
|
||||
)
|
||||
if llm_result is not None:
|
||||
normalized = _normalize_capability_payload(llm_result, action)
|
||||
normalized["llm_payload"] = llm_result
|
||||
return normalized
|
||||
|
||||
fallback = _build_fallback_capability(action)
|
||||
fallback["llm_payload"] = {
|
||||
"source": "fallback",
|
||||
"reason": "ollama_unavailable_or_invalid_response",
|
||||
}
|
||||
return fallback
|
||||
|
||||
|
||||
def chat_json(system_prompt: str, user_prompt: str) -> dict[str, Any] | None:
|
||||
return _call_ollama_json(system_prompt=system_prompt, user_prompt=user_prompt)
|
||||
|
||||
|
||||
def reset_model_session() -> None:
|
||||
host = os.getenv("OLLAMA_HOST", "http://178.154.193.191:8067").strip()
|
||||
model = os.getenv("OLLAMA_MODEL", "qwen2.5-coder:7b")
|
||||
headers = _load_headers()
|
||||
|
||||
try:
|
||||
from ollama import Client
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
client = Client(host=host, headers=headers or None)
|
||||
_reset_model_session(client=client, model=model)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
async def summarize_dialog_text(messages: list[dict[str, Any]]) -> str | None:
|
||||
prompt = (
|
||||
"Кратко сожми историю диалога на русском. "
|
||||
"Сохрани цель пользователя, ограничения, недостающие данные и важные решения. "
|
||||
"Ответь только текстом без markdown.\n\n"
|
||||
f"История:\n{json.dumps(messages, ensure_ascii=False)}"
|
||||
)
|
||||
payload = _call_ollama_json(
|
||||
system_prompt="Ты помощник, который сжимает диалоговый контекст для дальнейшего планирования.",
|
||||
user_prompt=prompt,
|
||||
)
|
||||
if isinstance(payload, dict):
|
||||
summary = payload.get("summary")
|
||||
if isinstance(summary, str) and summary.strip():
|
||||
return summary.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _call_ollama_json(system_prompt: str, user_prompt: str) -> dict[str, Any] | None:
|
||||
host = os.getenv("OLLAMA_HOST", "http://178.154.193.191:8067").strip()
|
||||
model = os.getenv("OLLAMA_MODEL", "qwen2.5-coder:7b")
|
||||
headers = _load_headers()
|
||||
|
||||
try:
|
||||
from ollama import Client
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
client = Client(host=host, headers=headers or None)
|
||||
response = client.chat(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
},
|
||||
],
|
||||
options={"temperature": 0},
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
content = _extract_message_content(response)
|
||||
if not content:
|
||||
return None
|
||||
payload = _parse_json_payload(content)
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
return payload
|
||||
|
||||
|
||||
def _build_prompt(action: Any) -> str:
|
||||
payload = {
|
||||
"operation_id": getattr(action, "operation_id", None),
|
||||
"method": getattr(action, "method", None).value if getattr(action, "method", None) else None,
|
||||
"path": getattr(action, "path", None),
|
||||
"base_url": getattr(action, "base_url", None),
|
||||
"summary": getattr(action, "summary", None),
|
||||
"description": getattr(action, "description", None),
|
||||
"tags": getattr(action, "tags", None),
|
||||
"parameters_schema": getattr(action, "parameters_schema", None),
|
||||
"request_body_schema": getattr(action, "request_body_schema", None),
|
||||
"response_schema": getattr(action, "response_schema", None),
|
||||
}
|
||||
return json.dumps(payload, ensure_ascii=True, indent=2)
|
||||
|
||||
|
||||
def _extract_message_content(response: Any) -> str | None:
|
||||
if isinstance(response, dict):
|
||||
message = response.get("message")
|
||||
if isinstance(message, dict):
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
content = response.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return None
|
||||
|
||||
message = getattr(response, "message", None)
|
||||
if message is not None:
|
||||
content = getattr(message, "content", None)
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
content = getattr(response, "content", None)
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return None
|
||||
|
||||
|
||||
def _parse_json_payload(content: str) -> dict[str, Any] | None:
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
match = re.search(r"\{.*\}", content, re.DOTALL)
|
||||
if not match:
|
||||
return None
|
||||
try:
|
||||
return json.loads(match.group(0))
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_capability_payload(payload: dict[str, Any], action: Any) -> dict[str, Any]:
|
||||
fallback = _build_fallback_capability(action)
|
||||
return {
|
||||
"name": str(payload.get("name") or fallback["name"]),
|
||||
"description": str(payload.get("description") or fallback["description"]),
|
||||
"input_schema": _normalize_schema(payload.get("input_schema")) or fallback["input_schema"],
|
||||
"output_schema": _normalize_schema(payload.get("output_schema")) or fallback["output_schema"],
|
||||
"data_format": _normalize_data_format(payload.get("data_format")) or fallback["data_format"],
|
||||
}
|
||||
|
||||
|
||||
def _build_fallback_capability(action: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"name": _build_capability_name(action),
|
||||
"description": _build_capability_description(action),
|
||||
"input_schema": _build_input_schema(action),
|
||||
"output_schema": getattr(action, "response_schema", None),
|
||||
"data_format": _build_data_format(action),
|
||||
}
|
||||
|
||||
|
||||
def _build_capability_name(action: Any) -> str:
|
||||
operation_id = getattr(action, "operation_id", None)
|
||||
if operation_id:
|
||||
return str(operation_id)
|
||||
|
||||
method = getattr(action, "method", None)
|
||||
method_value = method.value.lower() if method is not None else "call"
|
||||
path = getattr(action, "path", "") or ""
|
||||
normalized_path = re.sub(r"[{}]", "", path).strip("/")
|
||||
normalized_path = re.sub(r"[^a-zA-Z0-9/]+", "_", normalized_path)
|
||||
normalized_path = normalized_path.replace("/", "_") or "root"
|
||||
return f"{method_value}_{normalized_path.lower()}"
|
||||
|
||||
|
||||
def _build_capability_description(action: Any) -> str:
|
||||
summary = getattr(action, "summary", None)
|
||||
description = getattr(action, "description", None)
|
||||
operation_id = getattr(action, "operation_id", None)
|
||||
return str(summary or description or operation_id or _build_capability_name(action))
|
||||
|
||||
|
||||
def _build_input_schema(action: Any) -> dict[str, Any] | None:
|
||||
parameters_schema = getattr(action, "parameters_schema", None)
|
||||
request_body_schema = getattr(action, "request_body_schema", None)
|
||||
|
||||
if parameters_schema and request_body_schema:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"parameters": parameters_schema,
|
||||
"request_body": request_body_schema,
|
||||
},
|
||||
}
|
||||
if parameters_schema:
|
||||
return parameters_schema
|
||||
if request_body_schema:
|
||||
return request_body_schema
|
||||
return None
|
||||
|
||||
|
||||
def _build_data_format(action: Any) -> dict[str, Any]:
|
||||
parameters_schema = getattr(action, "parameters_schema", None) or {}
|
||||
request_body_schema = getattr(action, "request_body_schema", None) or {}
|
||||
response_schema = getattr(action, "response_schema", None) or {}
|
||||
|
||||
parameter_locations: list[str] = []
|
||||
if isinstance(parameters_schema, dict):
|
||||
properties = parameters_schema.get("properties", {})
|
||||
if isinstance(properties, dict):
|
||||
for property_schema in properties.values():
|
||||
if not isinstance(property_schema, dict):
|
||||
continue
|
||||
location = property_schema.get("x-parameter-location")
|
||||
if isinstance(location, str) and location not in parameter_locations:
|
||||
parameter_locations.append(location)
|
||||
|
||||
request_content_type = request_body_schema.get("x-content-type") if isinstance(request_body_schema, dict) else None
|
||||
response_content_type = response_schema.get("x-content-type") if isinstance(response_schema, dict) else None
|
||||
|
||||
return {
|
||||
"parameter_locations": parameter_locations,
|
||||
"request_content_types": [request_content_type] if isinstance(request_content_type, str) else [],
|
||||
"request_schema_type": request_body_schema.get("type") if isinstance(request_body_schema, dict) else None,
|
||||
"response_content_types": [response_content_type] if isinstance(response_content_type, str) else [],
|
||||
"response_schema_types": [response_schema.get("type")] if isinstance(response_schema, dict) and isinstance(response_schema.get("type"), str) else [],
|
||||
}
|
||||
|
||||
|
||||
def _normalize_schema(value: Any) -> dict[str, Any] | None:
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_data_format(value: Any) -> dict[str, Any] | None:
|
||||
if not isinstance(value, dict):
|
||||
return None
|
||||
|
||||
return {
|
||||
"parameter_locations": _normalize_string_list(value.get("parameter_locations")),
|
||||
"request_content_types": _normalize_string_list(value.get("request_content_types")),
|
||||
"request_schema_type": value.get("request_schema_type"),
|
||||
"response_content_types": _normalize_string_list(value.get("response_content_types")),
|
||||
"response_schema_types": _normalize_string_list(value.get("response_schema_types")),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_string_list(value: Any) -> list[str]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return [str(item) for item in value if item is not None]
|
||||
return [str(value)]
|
||||
|
||||
|
||||
def _load_headers() -> dict[str, str]:
|
||||
headers_payload = os.getenv("OLLAMA_HEADERS_JSON")
|
||||
if not headers_payload:
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(headers_payload)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
if not isinstance(parsed, dict):
|
||||
return {}
|
||||
return {str(key): str(value) for key, value in parsed.items()}
|
||||
@@ -0,0 +1,99 @@
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database.session import get_session
|
||||
from app.models import User, UserRole
|
||||
from app.utils.log_context import set_user_context
|
||||
|
||||
try:
|
||||
from jose import JWTError, jwt
|
||||
except ModuleNotFoundError:
|
||||
JWTError = Exception
|
||||
jwt = None
|
||||
|
||||
|
||||
JWT_SECRET = os.environ.get("JWT_SECRET", "super_secret_key_123")
|
||||
JWT_ALG = "HS256"
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
def create_access_token(*, sub: str, role: str) -> tuple[str, int]:
|
||||
expires_in = 3600
|
||||
if jwt is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="JWT support is not installed",
|
||||
)
|
||||
|
||||
expire = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
|
||||
payload = {"sub": str(sub), "role": role, "exp": expire}
|
||||
token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALG)
|
||||
return token, expires_in
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
creds: HTTPAuthorizationCredentials | None = Depends(security),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> User:
|
||||
if creds is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if jwt is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="JWT support is not installed",
|
||||
)
|
||||
|
||||
token = creds.credentials
|
||||
auth_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALG])
|
||||
user_id_str: str | None = payload.get("sub")
|
||||
if user_id_str is None:
|
||||
raise auth_exception
|
||||
user_id = UUID(user_id_str)
|
||||
except (JWTError, ValueError):
|
||||
raise auth_exception
|
||||
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None:
|
||||
raise auth_exception
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_423_LOCKED,
|
||||
detail="User account is deactivated",
|
||||
)
|
||||
|
||||
set_user_context(user_id=str(user.id))
|
||||
return user
|
||||
|
||||
|
||||
def check_permissions(allowed_roles: List[UserRole]):
|
||||
async def role_checker(current_user: User = Depends(get_current_user)):
|
||||
if current_user.role not in allowed_roles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions",
|
||||
)
|
||||
return current_user
|
||||
|
||||
return role_checker
|
||||
Reference in New Issue
Block a user