Улучшить валидацию входа workflow и вынести схемы ответов.
Подключена pydantic-валидация input_schema для сценария, а модели успешного и ошибочного результата запуска вынесены в отдельный модуль для более явных boundary-контрактов.
This commit is contained in:
@@ -0,0 +1,31 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class RunError(BaseModel):
    """Machine-readable error attached to a failed scenario run."""

    # Short error identifier, e.g. "unknown_scenario" or "invalid_input".
    code: str
    # Human-readable description of what went wrong.
    message: str
|
||||||
|
|
||||||
|
|
||||||
|
class ScenarioRunBase(BaseModel):
    """Fields shared by every scenario run result, successful or failed."""

    # Identifier of the scenario that was requested.
    scenario_id: str
    # Outcome discriminator; only "success" and "failed" are accepted.
    status: Literal["success", "failed"]
    # Input payload of the run, kept as an arbitrary string-keyed mapping.
    input: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class ScenarioRunFailed(ScenarioRunBase):
    """Result of a scenario run that did not complete successfully."""

    # Narrowed to the single value "failed" and defaulted, so constructors
    # do not have to pass it explicitly.
    status: Literal["failed"] = "failed"
    # Display name of the scenario; stays None when the caller does not
    # provide one.
    scenario_name: str | None = None
    # Structured description of the failure (code + message).
    error: RunError
|
||||||
|
|
||||||
|
|
||||||
|
class ScenarioRunSuccess(ScenarioRunBase):
    """Result of a scenario run that completed successfully."""

    # Narrowed to the single value "success" and defaulted, so constructors
    # do not have to pass it explicitly.
    status: Literal["success"] = "success"
    # Name of the workflow that executed the scenario.
    workflow_name: str
    # Display name of the scenario.
    scenario_name: str
    # Workflow output as a string-keyed mapping.
    result: dict[str, Any]
    # Optional run identifier; None when not available.
    run_id: str | None = None
    # Optional session identifier; None when not available.
    session_id: str | None = None
|
||||||
+78
-35
@@ -5,6 +5,8 @@ from typing import Any
|
|||||||
|
|
||||||
from agno.workflow.step import Step, StepInput, StepOutput
|
from agno.workflow.step import Step, StepInput, StepOutput
|
||||||
from agno.workflow.workflow import Workflow
|
from agno.workflow.workflow import Workflow
|
||||||
|
from pydantic import BaseModel, ValidationError, create_model
|
||||||
|
from src.schemas import RunError, ScenarioRunFailed, ScenarioRunSuccess
|
||||||
from src.scenario_store import ScenarioStoreError, load_scenario_definition
|
from src.scenario_store import ScenarioStoreError, load_scenario_definition
|
||||||
from src.stub_tools import (
|
from src.stub_tools import (
|
||||||
stub_extract_publication_date,
|
stub_extract_publication_date,
|
||||||
@@ -15,6 +17,7 @@ from src.stub_tools import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
_workflow_cache: dict[str, Workflow] = {}
|
_workflow_cache: dict[str, Workflow] = {}
|
||||||
|
_workflow_input_schemas: dict[str, type[BaseModel]] = {}
|
||||||
|
|
||||||
|
|
||||||
def _json_loads(raw: str | None) -> dict[str, Any]:
|
def _json_loads(raw: str | None) -> dict[str, Any]:
|
||||||
@@ -33,8 +36,43 @@ def _as_json_step_output(payload: dict[str, Any]) -> StepOutput:
|
|||||||
return StepOutput(content=json.dumps(payload, ensure_ascii=False))
|
return StepOutput(content=json.dumps(payload, ensure_ascii=False))
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_input_url(step_input_value: Any) -> str:
|
||||||
|
if isinstance(step_input_value, dict):
|
||||||
|
return str(step_input_value.get("url", "")).strip()
|
||||||
|
return str(step_input_value).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _build_input_schema_model(scenario: dict[str, Any]) -> type[BaseModel] | None:
|
||||||
|
input_schema = scenario.get("input_schema")
|
||||||
|
if not isinstance(input_schema, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
properties = input_schema.get("properties")
|
||||||
|
if not isinstance(properties, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
required_raw = input_schema.get("required", [])
|
||||||
|
required_fields = set(required_raw) if isinstance(required_raw, list) else set()
|
||||||
|
fields: dict[str, tuple[type[str], Any]] = {}
|
||||||
|
|
||||||
|
for field_name, field_schema in properties.items():
|
||||||
|
if not isinstance(field_name, str) or not isinstance(field_schema, dict):
|
||||||
|
continue
|
||||||
|
if field_schema.get("type") != "string":
|
||||||
|
continue
|
||||||
|
default_value = ... if field_name in required_fields else ""
|
||||||
|
fields[field_name] = (str, default_value)
|
||||||
|
|
||||||
|
if not fields:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return create_model(f"{scenario.get('scenario_id', 'Scenario')}Input", **fields)
|
||||||
|
|
||||||
|
|
||||||
async def _search_news_sources_executor(step_input: StepInput) -> StepOutput:
    """Run the stub news-source search for the URL found in the step input.

    Returns a failed StepOutput when no URL can be extracted; otherwise the
    search result serialized as a JSON StepOutput.
    """
    url = _extract_input_url(step_input.input)
    if url:
        found = await stub_search_news_sources(url=url)
        return _as_json_step_output(found)
    return StepOutput(content="search_news_sources failed: input.url is empty", success=False)
|
||||||
|
|
||||||
@@ -140,11 +178,15 @@ def get_workflow_for_scenario(scenario_id: str, scenario: dict[str, Any]) -> Wor
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
input_schema_model = _build_input_schema_model(scenario)
|
||||||
workflow = Workflow(
|
workflow = Workflow(
|
||||||
name=scenario_id,
|
name=scenario_id,
|
||||||
description=str(scenario.get("description", "")),
|
description=str(scenario.get("description", "")),
|
||||||
steps=workflow_steps,
|
steps=workflow_steps,
|
||||||
|
input_schema=input_schema_model,
|
||||||
)
|
)
|
||||||
|
if input_schema_model is not None:
|
||||||
|
_workflow_input_schemas[scenario_id] = input_schema_model
|
||||||
_workflow_cache[scenario_id] = workflow
|
_workflow_cache[scenario_id] = workflow
|
||||||
return workflow
|
return workflow
|
||||||
|
|
||||||
@@ -156,31 +198,29 @@ async def run_scenario_workflow(
|
|||||||
try:
|
try:
|
||||||
scenario = load_scenario_definition(scenario_id)
|
scenario = load_scenario_definition(scenario_id)
|
||||||
except ScenarioStoreError as exc:
|
except ScenarioStoreError as exc:
|
||||||
return {
|
return ScenarioRunFailed(
|
||||||
"scenario_id": scenario_id,
|
scenario_id=scenario_id,
|
||||||
"status": "failed",
|
input=input_data,
|
||||||
"input": input_data,
|
error=RunError(code="unknown_scenario", message=str(exc)),
|
||||||
"error": {
|
).model_dump()
|
||||||
"code": "unknown_scenario",
|
|
||||||
"message": str(exc),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
input_url = str(input_data.get("url", "")).strip()
|
|
||||||
if not input_url:
|
|
||||||
return {
|
|
||||||
"scenario_id": scenario_id,
|
|
||||||
"status": "failed",
|
|
||||||
"scenario_name": str(scenario.get("name", scenario_id)),
|
|
||||||
"input": input_data,
|
|
||||||
"error": {
|
|
||||||
"code": "invalid_input",
|
|
||||||
"message": "Current scenario expects input.url as non-empty string.",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
workflow = get_workflow_for_scenario(scenario_id=scenario_id, scenario=scenario)
|
workflow = get_workflow_for_scenario(scenario_id=scenario_id, scenario=scenario)
|
||||||
run_output = await workflow.arun(input=input_url)
|
input_schema_model = _workflow_input_schemas.get(scenario_id)
|
||||||
|
if input_schema_model is not None:
|
||||||
|
try:
|
||||||
|
input_schema_model.model_validate(input_data)
|
||||||
|
except ValidationError as exc:
|
||||||
|
return ScenarioRunFailed(
|
||||||
|
scenario_id=scenario_id,
|
||||||
|
scenario_name=str(scenario.get("name", scenario_id)),
|
||||||
|
input=input_data,
|
||||||
|
error=RunError(
|
||||||
|
code="invalid_input",
|
||||||
|
message=f"Input does not match scenario input_schema: {exc}",
|
||||||
|
),
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
run_output = await workflow.arun(input=input_data)
|
||||||
|
|
||||||
content: Any = run_output.content if hasattr(run_output, "content") else {}
|
content: Any = run_output.content if hasattr(run_output, "content") else {}
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
@@ -189,16 +229,19 @@ async def run_scenario_workflow(
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
content = {"raw_content": content}
|
content = {"raw_content": content}
|
||||||
|
|
||||||
response: dict[str, Any] = {
|
run_id: str | None = None
|
||||||
"scenario_id": scenario_id,
|
session_id: str | None = None
|
||||||
"workflow_name": workflow.name,
|
|
||||||
"scenario_name": str(scenario.get("name", scenario_id)),
|
|
||||||
"status": "success",
|
|
||||||
"input": input_data,
|
|
||||||
"result": content,
|
|
||||||
}
|
|
||||||
if hasattr(run_output, "run_id"):
|
if hasattr(run_output, "run_id"):
|
||||||
response["run_id"] = str(getattr(run_output, "run_id"))
|
run_id = str(getattr(run_output, "run_id"))
|
||||||
if hasattr(run_output, "session_id"):
|
if hasattr(run_output, "session_id"):
|
||||||
response["session_id"] = str(getattr(run_output, "session_id"))
|
session_id = str(getattr(run_output, "session_id"))
|
||||||
return response
|
|
||||||
|
return ScenarioRunSuccess(
|
||||||
|
scenario_id=scenario_id,
|
||||||
|
workflow_name=workflow.name,
|
||||||
|
scenario_name=str(scenario.get("name", scenario_id)),
|
||||||
|
input=input_data,
|
||||||
|
result=content if isinstance(content, dict) else {"raw_content": str(content)},
|
||||||
|
run_id=run_id,
|
||||||
|
session_id=session_id,
|
||||||
|
).model_dump(exclude_none=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user