Улучшить валидацию входа workflow и вынести схемы ответов.
Подключена pydantic-валидация input_schema для сценария, а модели успешного и ошибочного результата запуска вынесены в отдельный модуль для более явных boundary-контрактов.
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RunError(BaseModel):
|
||||
code: str
|
||||
message: str
|
||||
|
||||
|
||||
class ScenarioRunBase(BaseModel):
|
||||
scenario_id: str
|
||||
status: Literal["success", "failed"]
|
||||
input: dict[str, Any]
|
||||
|
||||
|
||||
class ScenarioRunFailed(ScenarioRunBase):
|
||||
status: Literal["failed"] = "failed"
|
||||
scenario_name: str | None = None
|
||||
error: RunError
|
||||
|
||||
|
||||
class ScenarioRunSuccess(ScenarioRunBase):
|
||||
status: Literal["success"] = "success"
|
||||
workflow_name: str
|
||||
scenario_name: str
|
||||
result: dict[str, Any]
|
||||
run_id: str | None = None
|
||||
session_id: str | None = None
|
||||
+78
-35
@@ -5,6 +5,8 @@ from typing import Any
|
||||
|
||||
from agno.workflow.step import Step, StepInput, StepOutput
|
||||
from agno.workflow.workflow import Workflow
|
||||
from pydantic import BaseModel, ValidationError, create_model
|
||||
from src.schemas import RunError, ScenarioRunFailed, ScenarioRunSuccess
|
||||
from src.scenario_store import ScenarioStoreError, load_scenario_definition
|
||||
from src.stub_tools import (
|
||||
stub_extract_publication_date,
|
||||
@@ -15,6 +17,7 @@ from src.stub_tools import (
|
||||
)
|
||||
|
||||
_workflow_cache: dict[str, Workflow] = {}
|
||||
_workflow_input_schemas: dict[str, type[BaseModel]] = {}
|
||||
|
||||
|
||||
def _json_loads(raw: str | None) -> dict[str, Any]:
|
||||
@@ -33,8 +36,43 @@ def _as_json_step_output(payload: dict[str, Any]) -> StepOutput:
|
||||
return StepOutput(content=json.dumps(payload, ensure_ascii=False))
|
||||
|
||||
|
||||
def _extract_input_url(step_input_value: Any) -> str:
|
||||
if isinstance(step_input_value, dict):
|
||||
return str(step_input_value.get("url", "")).strip()
|
||||
return str(step_input_value).strip()
|
||||
|
||||
|
||||
def _build_input_schema_model(scenario: dict[str, Any]) -> type[BaseModel] | None:
|
||||
input_schema = scenario.get("input_schema")
|
||||
if not isinstance(input_schema, dict):
|
||||
return None
|
||||
|
||||
properties = input_schema.get("properties")
|
||||
if not isinstance(properties, dict):
|
||||
return None
|
||||
|
||||
required_raw = input_schema.get("required", [])
|
||||
required_fields = set(required_raw) if isinstance(required_raw, list) else set()
|
||||
fields: dict[str, tuple[type[str], Any]] = {}
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
if not isinstance(field_name, str) or not isinstance(field_schema, dict):
|
||||
continue
|
||||
if field_schema.get("type") != "string":
|
||||
continue
|
||||
default_value = ... if field_name in required_fields else ""
|
||||
fields[field_name] = (str, default_value)
|
||||
|
||||
if not fields:
|
||||
return None
|
||||
|
||||
return create_model(f"{scenario.get('scenario_id', 'Scenario')}Input", **fields)
|
||||
|
||||
|
||||
async def _search_news_sources_executor(step_input: StepInput) -> StepOutput:
|
||||
input_url = str(step_input.input)
|
||||
input_url = _extract_input_url(step_input.input)
|
||||
if not input_url:
|
||||
return StepOutput(content="search_news_sources failed: input.url is empty", success=False)
|
||||
search_result = await stub_search_news_sources(url=input_url)
|
||||
return _as_json_step_output(search_result)
|
||||
|
||||
@@ -140,11 +178,15 @@ def get_workflow_for_scenario(scenario_id: str, scenario: dict[str, Any]) -> Wor
|
||||
)
|
||||
)
|
||||
|
||||
input_schema_model = _build_input_schema_model(scenario)
|
||||
workflow = Workflow(
|
||||
name=scenario_id,
|
||||
description=str(scenario.get("description", "")),
|
||||
steps=workflow_steps,
|
||||
input_schema=input_schema_model,
|
||||
)
|
||||
if input_schema_model is not None:
|
||||
_workflow_input_schemas[scenario_id] = input_schema_model
|
||||
_workflow_cache[scenario_id] = workflow
|
||||
return workflow
|
||||
|
||||
@@ -156,31 +198,29 @@ async def run_scenario_workflow(
|
||||
try:
|
||||
scenario = load_scenario_definition(scenario_id)
|
||||
except ScenarioStoreError as exc:
|
||||
return {
|
||||
"scenario_id": scenario_id,
|
||||
"status": "failed",
|
||||
"input": input_data,
|
||||
"error": {
|
||||
"code": "unknown_scenario",
|
||||
"message": str(exc),
|
||||
},
|
||||
}
|
||||
|
||||
input_url = str(input_data.get("url", "")).strip()
|
||||
if not input_url:
|
||||
return {
|
||||
"scenario_id": scenario_id,
|
||||
"status": "failed",
|
||||
"scenario_name": str(scenario.get("name", scenario_id)),
|
||||
"input": input_data,
|
||||
"error": {
|
||||
"code": "invalid_input",
|
||||
"message": "Current scenario expects input.url as non-empty string.",
|
||||
},
|
||||
}
|
||||
return ScenarioRunFailed(
|
||||
scenario_id=scenario_id,
|
||||
input=input_data,
|
||||
error=RunError(code="unknown_scenario", message=str(exc)),
|
||||
).model_dump()
|
||||
|
||||
workflow = get_workflow_for_scenario(scenario_id=scenario_id, scenario=scenario)
|
||||
run_output = await workflow.arun(input=input_url)
|
||||
input_schema_model = _workflow_input_schemas.get(scenario_id)
|
||||
if input_schema_model is not None:
|
||||
try:
|
||||
input_schema_model.model_validate(input_data)
|
||||
except ValidationError as exc:
|
||||
return ScenarioRunFailed(
|
||||
scenario_id=scenario_id,
|
||||
scenario_name=str(scenario.get("name", scenario_id)),
|
||||
input=input_data,
|
||||
error=RunError(
|
||||
code="invalid_input",
|
||||
message=f"Input does not match scenario input_schema: {exc}",
|
||||
),
|
||||
).model_dump()
|
||||
|
||||
run_output = await workflow.arun(input=input_data)
|
||||
|
||||
content: Any = run_output.content if hasattr(run_output, "content") else {}
|
||||
if isinstance(content, str):
|
||||
@@ -189,16 +229,19 @@ async def run_scenario_workflow(
|
||||
except json.JSONDecodeError:
|
||||
content = {"raw_content": content}
|
||||
|
||||
response: dict[str, Any] = {
|
||||
"scenario_id": scenario_id,
|
||||
"workflow_name": workflow.name,
|
||||
"scenario_name": str(scenario.get("name", scenario_id)),
|
||||
"status": "success",
|
||||
"input": input_data,
|
||||
"result": content,
|
||||
}
|
||||
run_id: str | None = None
|
||||
session_id: str | None = None
|
||||
if hasattr(run_output, "run_id"):
|
||||
response["run_id"] = str(getattr(run_output, "run_id"))
|
||||
run_id = str(getattr(run_output, "run_id"))
|
||||
if hasattr(run_output, "session_id"):
|
||||
response["session_id"] = str(getattr(run_output, "session_id"))
|
||||
return response
|
||||
session_id = str(getattr(run_output, "session_id"))
|
||||
|
||||
return ScenarioRunSuccess(
|
||||
scenario_id=scenario_id,
|
||||
workflow_name=workflow.name,
|
||||
scenario_name=str(scenario.get("name", scenario_id)),
|
||||
input=input_data,
|
||||
result=content if isinstance(content, dict) else {"raw_content": str(content)},
|
||||
run_id=run_id,
|
||||
session_id=session_id,
|
||||
).model_dump(exclude_none=True)
|
||||
|
||||
Reference in New Issue
Block a user