Промежуточный вариант: ужесточить planner recovery и fail-fast workflow.

Перевел планирование аргументов на строгий json_schema response_format, добавил сценарий с битыми полями для проверки восстановления и остановку workflow на первой ошибке шага. Сейчас используется Polza.ai.
This commit is contained in:
Barabashka
2026-04-22 17:44:35 +03:00
parent ad828885e3
commit 5ca49821ba
3 changed files with 342 additions and 86 deletions
+2 -1
View File
@@ -1,5 +1,6 @@
{ {
"scenarios": { "scenarios": {
"news_source_discovery_v1": "news_source_discovery/v1.json" "news_source_discovery_v1": "news_source_discovery/v1.json",
"news_source_discovery_v1_planner_repair": "news_source_discovery/v1_planner_repair.json"
} }
} }
@@ -0,0 +1,117 @@
{
"schema_version": "1",
"scenario_id": "news_source_discovery_v1_planner_repair",
"name": "News Source Discovery V1 Planner Repair",
"description": "Test scenario with intentionally wrong input paths repaired by planner.",
"input_schema": {
"type": "object",
"required": [
"url"
],
"properties": {
"url": {
"type": "string",
"description": "URL of source news article"
}
}
},
"steps": [
{
"name": "search_news_sources",
"type": "tool",
"tool": "search_news_sources",
"input": {
"url": {
"from": "input.url"
}
},
"required_input_fields": [
"url"
]
},
{
"name": "parse_articles_batch",
"type": "tool",
"tool": "parse_article",
"foreach": {
"from": "steps.search_news_sources.payload.items",
"as": "item"
},
"input": {
"url": {
"from": "item.link"
}
},
"required_input_fields": [
"url"
],
"collect": {
"url": {
"from": "tool.payload.url"
},
"title": {
"from": "tool.payload.title"
},
"text": {
"from": "tool.payload.text"
}
},
"collect_key": "items"
},
{
"name": "extract_publication_date_batch",
"type": "tool",
"tool": "extract_publication_date",
"foreach": {
"from": "steps.parse_articles_batch.payload.items",
"as": "item"
},
"input": {
"article_text": {
"from": "item.body"
}
},
"required_input_fields": [
"article_text"
],
"collect": {
"url": {
"from": "item.url"
},
"title": {
"from": "item.title"
},
"published_at": {
"from": "tool.payload.published_at"
}
},
"collect_key": "items"
},
{
"name": "rank_sources_by_date",
"type": "tool",
"tool": "rank_sources_by_date",
"input": {
"items": {
"from": "steps.extract_publication_date_batch.payload.items"
}
},
"required_input_fields": [
"items"
]
},
{
"name": "generate_summary",
"type": "tool",
"tool": "generate_summary",
"input": {
"items": {
"from": "steps.rank_sources_by_date.payload.items_ranked_typo"
}
},
"required_input_fields": [
"items"
]
}
]
}
+223 -85
View File
@@ -7,24 +7,15 @@ import json
import os import os
from typing import Any from typing import Any
from agno.agent import Agent
from agno.models.openai import OpenAIChat
from agno.workflow.step import Step, StepInput, StepOutput from agno.workflow.step import Step, StepInput, StepOutput
from agno.workflow.workflow import Workflow from agno.workflow.workflow import Workflow
from pydantic import BaseModel, Field from openai import AsyncOpenAI
from src.mcp_client import call_mcp_tool from src.mcp_client import call_mcp_tool
from src.schemas import RunError, ScenarioRunResponse, StepState from src.schemas import RunError, ScenarioRunResponse, StepState
from src.scenario_store import ScenarioStoreError, load_scenario_definition from src.scenario_store import ScenarioStoreError, load_scenario_definition
_planner_client: AsyncOpenAI | None = None
class McpArgumentsPlan(BaseModel):
"""Structured planner output for one MCP tool call."""
arguments: dict[str, Any] = Field(default_factory=dict)
_planner_agent: Agent | None = None
def _env_float(name: str, default: float) -> float: def _env_float(name: str, default: float) -> float:
@@ -34,47 +25,29 @@ def _env_float(name: str, default: float) -> float:
return float(value) return float(value)
def _env_int(name: str, default: int) -> int:
value = os.getenv(name)
if value is None:
return default
return int(value)
def _utc_now_iso() -> str: def _utc_now_iso() -> str:
return datetime.now(timezone.utc).isoformat() return datetime.now(timezone.utc).isoformat()
def get_shared_step_planner_agent() -> Agent: def get_shared_step_planner_client() -> AsyncOpenAI:
""" global _planner_client
Create one reusable planner agent for all workflow steps. if _planner_client is not None:
return _planner_client
This agent never calls MCP directly. It only prepares arguments
for a fixed MCP method selected by the workflow step.
"""
global _planner_agent
if _planner_agent is not None:
return _planner_agent
model_id = os.getenv("POLZA_MODEL_ID", "google/gemma-4-31b-it")
polza_base_url = os.getenv("POLZA_BASE_URL", "https://api.polza.ai/v1") polza_base_url = os.getenv("POLZA_BASE_URL", "https://api.polza.ai/v1")
polza_api_key = os.getenv("POLZA_API_KEY") or os.getenv("OPENAI_API_KEY") polza_api_key = os.getenv("POLZA_API_KEY") or os.getenv("OPENAI_API_KEY")
temperature = _env_float("POLZA_TEMPERATURE", 0.0) _planner_client = AsyncOpenAI(
llm = OpenAIChat(
id=model_id,
api_key=polza_api_key,
base_url=polza_base_url, base_url=polza_base_url,
temperature=temperature, api_key=polza_api_key,
) )
_planner_agent = Agent( return _planner_client
id="workflow-step-planner",
model=llm,
output_schema=McpArgumentsPlan,
markdown=False,
debug_mode=False,
instructions=[
"You are a strict tool-input planner.",
"You receive step metadata and current workflow context.",
"Return only arguments that should be sent to MCP tool.",
"Do not add extra keys that are unrelated to the tool.",
"Do not invent values if they are absent in context.",
],
)
return _planner_agent
def _resolve_path(scope: dict[str, Any], path: str) -> Any: def _resolve_path(scope: dict[str, Any], path: str) -> Any:
@@ -104,13 +77,84 @@ def _validate_required_fields(
required_fields: list[str], required_fields: list[str],
step_name: str, step_name: str,
) -> None: ) -> None:
missing_fields: list[str] = []
for field in required_fields: for field in required_fields:
value = arguments.get(field) value = arguments.get(field)
if isinstance(value, str) and value.strip(): if isinstance(value, str) and value.strip():
continue continue
if value not in (None, "", [], {}): if value not in (None, "", [], {}):
continue continue
raise ValueError(f"{step_name}: input.{field} is empty") missing_fields.append(field)
if missing_fields:
fields_str = ", ".join(missing_fields)
raise ValueError(f"{step_name}: missing required fields: {fields_str}")
def _missing_required_fields(arguments: dict[str, Any], required_fields: list[str]) -> list[str]:
missing_fields: list[str] = []
for field in required_fields:
value = arguments.get(field)
if isinstance(value, str) and value.strip():
continue
if value not in (None, "", [], {}):
continue
missing_fields.append(field)
return missing_fields
def _build_arguments_schema(required_fields: list[str]) -> dict[str, Any]:
properties = {field: {"type": "any"} for field in required_fields}
return {
"type": "object",
"required": required_fields,
"properties": properties,
}
def _build_polza_response_schema(required_fields: list[str]) -> dict[str, Any]:
value_schema: dict[str, Any] = {
"type": ["string", "number", "boolean", "array", "object", "null"]
}
arguments_properties = {field: value_schema for field in required_fields}
return {
"name": "mcp_arguments",
"strict": True,
"schema": {
"type": "object",
"properties": {
"arguments": {
"type": "object",
"properties": arguments_properties,
"required": required_fields,
"additionalProperties": True,
}
},
"required": ["arguments"],
"additionalProperties": False,
},
}
def _extract_planned_arguments(content: Any) -> dict[str, Any]:
candidate: Any = content
if isinstance(candidate, str):
text = candidate.strip()
if text.startswith("```"):
text = text.strip("`").strip()
if text.startswith("json"):
text = text[4:].strip()
try:
candidate = json.loads(text)
except json.JSONDecodeError:
return {}
if isinstance(candidate, dict):
if isinstance(candidate.get("arguments"), dict):
return candidate["arguments"]
# Some models return the arguments object directly.
return candidate
return {}
class McpWorkflowRunner: class McpWorkflowRunner:
@@ -122,9 +166,9 @@ class McpWorkflowRunner:
- request/response persisted in run context - request/response persisted in run context
""" """
def __init__(self, planner_agent: Agent | None = None) -> None: def __init__(self) -> None:
self._planner_agent = planner_agent or get_shared_step_planner_agent()
self._workflow_cache: dict[str, Workflow] = {} self._workflow_cache: dict[str, Workflow] = {}
self._planner_repair_attempts = _env_int("PLANNER_REPAIR_ATTEMPTS", 3)
self._run_state_ctx: ContextVar[dict[str, Any] | None] = ContextVar( self._run_state_ctx: ContextVar[dict[str, Any] | None] = ContextVar(
"mcp_workflow_run_state", "mcp_workflow_run_state",
default=None, default=None,
@@ -151,27 +195,81 @@ class McpWorkflowRunner:
base_arguments: dict[str, Any], base_arguments: dict[str, Any],
required_fields: list[str], required_fields: list[str],
scope: dict[str, Any], scope: dict[str, Any],
planner_cache: dict[str, dict[str, Any]] | None = None,
missing_fields: list[str] | None = None,
attempt_no: int = 1,
) -> dict[str, Any]: ) -> dict[str, Any]:
cache_key: str | None = None
if planner_cache is not None:
try:
cache_payload = {
"tool_name": tool_name,
"base_arguments": base_arguments,
"required_fields": required_fields,
"missing_fields": missing_fields or [],
"attempt_no": attempt_no,
}
cache_key = json.dumps(cache_payload, sort_keys=True, ensure_ascii=False)
except TypeError:
cache_key = None
if cache_key is not None and cache_key in planner_cache:
return deepcopy(planner_cache[cache_key])
planner_context = {
"input": scope.get("input", {}),
"steps": scope.get("steps", {}),
}
for key, value in scope.items():
if key in {"input", "steps"}:
continue
planner_context[key] = value
prompt = { prompt = {
"task": "Prepare MCP arguments for this step.", "task": "Prepare MCP arguments for this step.",
"step_name": step_name, "step_name": step_name,
"tool_name": tool_name, "tool_name": tool_name,
"required_fields": required_fields, "required_fields": required_fields,
"base_arguments": base_arguments, "base_arguments": base_arguments,
"context": { "missing_fields": missing_fields or [],
"input": scope.get("input", {}), "repair_attempt": attempt_no,
"steps": scope.get("steps", {}), "arguments_schema": _build_arguments_schema(required_fields),
"context": planner_context,
"response_contract": {
"must_return": {"arguments": "object"},
"must_include_fields": missing_fields or [],
"forbidden": "extra unrelated keys",
}, },
"output": "Return arguments object only.", "output": (
"Return only JSON object with key 'arguments'. "
"If missing_fields is not empty, fill every missing field from context."
),
} }
run_output = await self._planner_agent.arun(json.dumps(prompt, ensure_ascii=False)) prompt_json = json.dumps(prompt, ensure_ascii=False)
content = run_output.content if hasattr(run_output, "content") else {} planned: dict[str, Any] = {}
if isinstance(content, McpArgumentsPlan): # Primary path: strict structured output via Polza response_format/json_schema.
planned = content.arguments try:
elif isinstance(content, dict): completion = await get_shared_step_planner_client().chat.completions.create(
planned = content.get("arguments", {}) model=os.getenv("POLZA_MODEL_ID", "google/gemma-4-31b-it"),
else: messages=[
{
"role": "system",
"content": (
"You are a tool-input planner. "
"Return only JSON that matches the provided schema."
),
},
{"role": "user", "content": prompt_json},
],
response_format={
"type": "json_schema",
"json_schema": _build_polza_response_schema(required_fields),
},
temperature=_env_float("POLZA_TEMPERATURE", 0.0),
)
raw_content = completion.choices[0].message.content if completion.choices else ""
planned = _extract_planned_arguments(raw_content)
except Exception:
planned = {} planned = {}
if not isinstance(planned, dict): if not isinstance(planned, dict):
@@ -180,6 +278,8 @@ class McpWorkflowRunner:
# Allow planner to override/fill base arguments while keeping known defaults. # Allow planner to override/fill base arguments while keeping known defaults.
merged = deepcopy(base_arguments) merged = deepcopy(base_arguments)
merged.update(planned) merged.update(planned)
if planner_cache is not None and cache_key is not None:
planner_cache[cache_key] = deepcopy(merged)
return merged return merged
def _build_tool_step_executor(self, step_spec: dict[str, Any]): def _build_tool_step_executor(self, step_spec: dict[str, Any]):
@@ -206,6 +306,38 @@ class McpWorkflowRunner:
run_state = self._get_run_state() run_state = self._get_run_state()
scope = self._build_scope() scope = self._build_scope()
step_started_at = _utc_now_iso() step_started_at = _utc_now_iso()
planner_cache: dict[str, dict[str, Any]] = {}
async def _prepare_arguments(
*,
local_scope: dict[str, Any],
local_base_arguments: dict[str, Any],
) -> dict[str, Any]:
final_arguments = deepcopy(local_base_arguments)
for repair_attempt in range(1, self._planner_repair_attempts + 1):
missing_fields = _missing_required_fields(final_arguments, required_fields)
if not missing_fields:
break
final_arguments = await self._plan_arguments(
step_name=step_name,
tool_name=tool_name,
base_arguments=final_arguments,
required_fields=required_fields,
scope=local_scope,
planner_cache=planner_cache,
missing_fields=missing_fields,
attempt_no=repair_attempt,
)
_validate_required_fields(final_arguments, required_fields, step_name)
return final_arguments
async def _call_tool_with_repair(
*,
initial_arguments: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
final_arguments = deepcopy(initial_arguments)
tool_response = await call_mcp_tool(tool_name, final_arguments)
return tool_response, final_arguments
try: try:
tool_calls = run_state.setdefault("tool_calls", []) tool_calls = run_state.setdefault("tool_calls", [])
@@ -227,17 +359,13 @@ class McpWorkflowRunner:
resolved = _resolve_template(input_template, iteration_scope) resolved = _resolve_template(input_template, iteration_scope)
base_arguments = resolved if isinstance(resolved, dict) else {} base_arguments = resolved if isinstance(resolved, dict) else {}
final_arguments = await _prepare_arguments(
final_arguments = await self._plan_arguments( local_scope=iteration_scope,
step_name=step_name, local_base_arguments=base_arguments,
tool_name=tool_name, )
base_arguments=base_arguments, tool_response, final_arguments = await _call_tool_with_repair(
required_fields=required_fields, initial_arguments=final_arguments,
scope=iteration_scope,
) )
_validate_required_fields(final_arguments, required_fields, step_name)
tool_response = await call_mcp_tool(tool_name, final_arguments)
tool_calls.append( tool_calls.append(
{ {
"step_name": step_name, "step_name": step_name,
@@ -261,7 +389,7 @@ class McpWorkflowRunner:
step_payload = { step_payload = {
"ok": True, "ok": True,
"tool_name": step_name, "tool_name": tool_name,
"payload": {collect_key: collected_items}, "payload": {collect_key: collected_items},
"request": {"foreach_from": source_path, "count": len(iterable)}, "request": {"foreach_from": source_path, "count": len(iterable)},
"received_at": _utc_now_iso(), "received_at": _utc_now_iso(),
@@ -271,17 +399,13 @@ class McpWorkflowRunner:
else: else:
resolved = _resolve_template(input_template, scope) resolved = _resolve_template(input_template, scope)
base_arguments = resolved if isinstance(resolved, dict) else {} base_arguments = resolved if isinstance(resolved, dict) else {}
final_arguments = await _prepare_arguments(
final_arguments = await self._plan_arguments( local_scope=scope,
step_name=step_name, local_base_arguments=base_arguments,
tool_name=tool_name, )
base_arguments=base_arguments, tool_response, final_arguments = await _call_tool_with_repair(
required_fields=required_fields, initial_arguments=final_arguments,
scope=scope,
) )
_validate_required_fields(final_arguments, required_fields, step_name)
tool_response = await call_mcp_tool(tool_name, final_arguments)
step_payload = { step_payload = {
"ok": bool(tool_response.get("ok", True)), "ok": bool(tool_response.get("ok", True)),
"tool_name": tool_name, "tool_name": tool_name,
@@ -326,10 +450,7 @@ class McpWorkflowRunner:
"error": str(exc), "error": str(exc),
} }
) )
return StepOutput( raise RuntimeError(f"{step_name} failed: {exc}") from exc
content=json.dumps(error_payload, ensure_ascii=False),
success=False,
)
return _executor return _executor
@@ -360,6 +481,8 @@ class McpWorkflowRunner:
name=step_name, name=step_name,
description=str(raw_step.get("description", step_name)), description=str(raw_step.get("description", step_name)),
executor=executor, executor=executor,
max_retries=0,
on_error="fail",
) )
) )
@@ -383,8 +506,11 @@ class McpWorkflowRunner:
token = self._run_state_ctx.set(initial_state) token = self._run_state_ctx.set(initial_state)
run_state = initial_state run_state = initial_state
run_output: Any = None run_output: Any = None
workflow_error: str | None = None
try: try:
run_output = await workflow.arun(input=input_data) run_output = await workflow.arun(input=input_data)
except Exception as exc:
workflow_error = str(exc)
finally: finally:
captured = self._run_state_ctx.get() captured = self._run_state_ctx.get()
if isinstance(captured, dict): if isinstance(captured, dict):
@@ -397,13 +523,25 @@ class McpWorkflowRunner:
content = json.loads(content) content = json.loads(content)
except json.JSONDecodeError: except json.JSONDecodeError:
content = {"raw_content": content} content = {"raw_content": content}
if content is None:
step_payloads = run_state.get("steps", {})
if isinstance(step_payloads, dict):
for payload in reversed(list(step_payloads.values())):
if isinstance(payload, dict) and not bool(payload.get("ok", True)):
content = deepcopy(payload)
break
if content is None and workflow_error is not None:
content = {"error": workflow_error}
status = "success"
if workflow_error is not None:
status = "failed"
elif run_output is not None and not bool(getattr(run_output, "success", True)):
status = "failed"
return { return {
"scenario_id": scenario_id, "scenario_id": scenario_id,
"workflow_name": workflow.name, "workflow_name": workflow.name,
"status": "success" "status": status,
if getattr(run_output, "success", True)
else "failed",
"input": input_data, "input": input_data,
"final_result": content if isinstance(content, dict) else {"raw_content": content}, "final_result": content if isinstance(content, dict) else {"raw_content": content},
"steps": run_state.get("steps", {}), "steps": run_state.get("steps", {}),