"""Minimal MCP-backed scenario workflow runner.

Scenario steps run in a fixed order as agno ``Step`` objects. For each step an
LLM planner (reached through an OpenAI-compatible ``AsyncOpenAI`` client whose
base URL defaults to Polza) is asked to fill required MCP tool arguments that
the step's input template could not resolve; the MCP tool itself is always
invoked by code. Requests and responses are recorded in a per-run state dict
held in a ``ContextVar``.
"""

from __future__ import annotations

import json
import logging
import os
from contextvars import ContextVar
from copy import deepcopy
from datetime import datetime, timezone
from typing import Any

from agno.workflow.step import Step, StepInput, StepOutput
from agno.workflow.workflow import Workflow
from openai import AsyncOpenAI

from src.mcp_client import call_mcp_tool
from src.schemas import RunError, ScenarioRunResponse, StepState
from src.scenario_store import ScenarioStoreError, load_scenario_definition

logger = logging.getLogger(__name__)

# Lazily created, process-wide planner client (see get_shared_step_planner_client).
_planner_client: AsyncOpenAI | None = None


def _env_float(name: str, default: float) -> float:
    """Return env var ``name`` parsed as a float, or ``default`` when unset or blank.

    Raises:
        ValueError: if the variable is set to a non-numeric value.
    """
    value = os.getenv(name)
    # Blank values (e.g. ``FOO=`` in a compose file) mean "not configured".
    if value is None or not value.strip():
        return default
    return float(value)


def _env_int(name: str, default: int) -> int:
    """Return env var ``name`` parsed as an int, or ``default`` when unset or blank.

    Raises:
        ValueError: if the variable is set to a non-integer value.
    """
    value = os.getenv(name)
    if value is None or not value.strip():
        return default
    return int(value)


def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string with offset."""
    return datetime.now(timezone.utc).isoformat()


def _optional_str(value: Any) -> str | None:
    """Return ``str(value)``, or None when ``value`` is None or stringifies to ''."""
    if value is None:
        return None
    return str(value) or None


def get_shared_step_planner_client() -> AsyncOpenAI:
    """Return the shared planner client, creating it on first use.

    Base URL comes from ``POLZA_BASE_URL`` (default ``https://api.polza.ai/v1``);
    the key from ``POLZA_API_KEY``, falling back to ``OPENAI_API_KEY``.
    """
    global _planner_client
    if _planner_client is not None:
        return _planner_client
    polza_base_url = os.getenv("POLZA_BASE_URL", "https://api.polza.ai/v1")
    polza_api_key = os.getenv("POLZA_API_KEY") or os.getenv("OPENAI_API_KEY")
    _planner_client = AsyncOpenAI(
        base_url=polza_base_url,
        api_key=polza_api_key,
    )
    return _planner_client


def _resolve_path(scope: dict[str, Any], path: str) -> Any:
    """Look up a dot-separated ``path`` through nested dicts in ``scope``.

    Blank segments are skipped. Returns None when a non-dict value is reached
    before the path is exhausted or when a key is absent. The result is a deep
    copy, so callers may mutate it without touching ``scope``.
    """
    value: Any = scope
    for segment in path.split("."):
        key = segment.strip()
        if not key:
            continue
        if not isinstance(value, dict):
            return None
        value = value.get(key)
    return deepcopy(value)


def _resolve_template(template: Any, scope: dict[str, Any]) -> Any:
    """Recursively resolve ``{"from": "<path>"}`` references in ``template``.

    A dict whose only key is ``"from"`` is replaced by ``_resolve_path(scope,
    path)``; other dicts and lists are resolved element-wise; every other value
    is deep-copied unchanged.
    """
    if isinstance(template, dict):
        if set(template.keys()) == {"from"}:
            return _resolve_path(scope, str(template["from"]))
        return {key: _resolve_template(value, scope) for key, value in template.items()}
    if isinstance(template, list):
        return [_resolve_template(item, scope) for item in template]
    return deepcopy(template)


def _is_missing_value(value: Any) -> bool:
    """Return True when ``value`` does not count as a filled argument.

    ``None``, empty or whitespace-only strings, and empty lists/dicts are
    missing; everything else (including ``0`` and ``False``) is filled.
    """
    if isinstance(value, str):
        return not value.strip()
    return value in (None, "", [], {})


def _missing_required_fields(arguments: dict[str, Any], required_fields: list[str]) -> list[str]:
    """Return the names in ``required_fields`` whose values in ``arguments`` are missing."""
    return [field for field in required_fields if _is_missing_value(arguments.get(field))]


def _validate_required_fields(
    arguments: dict[str, Any],
    required_fields: list[str],
    step_name: str,
) -> None:
    """Raise if any of ``required_fields`` is missing from ``arguments``.

    Raises:
        ValueError: ``"<step_name>: missing required fields: a, b"``.
    """
    missing_fields = _missing_required_fields(arguments, required_fields)
    if missing_fields:
        fields_str = ", ".join(missing_fields)
        raise ValueError(f"{step_name}: missing required fields: {fields_str}")


def _build_arguments_schema(required_fields: list[str]) -> dict[str, Any]:
    """Build the informal arguments schema shown to the planner in its prompt.

    NOTE(review): ``"any"`` is not a JSON Schema type; this dict is only
    embedded in the prompt text, not used for validation.
    """
    properties = {field: {"type": "any"} for field in required_fields}
    return {
        "type": "object",
        "required": required_fields,
        "properties": properties,
    }


def _build_polza_response_schema(required_fields: list[str]) -> dict[str, Any]:
    """Build the ``json_schema`` payload for the planner's ``response_format``.

    The response must be ``{"arguments": {...}}`` containing every required
    field, each allowed to be any JSON type.

    NOTE(review): OpenAI's strict structured-output mode rejects
    ``additionalProperties: true`` and untyped arrays/objects; whether the
    Polza-routed model accepts this schema is unverified. If it is rejected,
    the planner call fails and ``_plan_arguments`` falls back to the base
    arguments.
    """
    value_schema: dict[str, Any] = {
        "type": ["string", "number", "boolean", "array", "object", "null"]
    }
    arguments_properties = {field: value_schema for field in required_fields}
    return {
        "name": "mcp_arguments",
        "strict": True,
        "schema": {
            "type": "object",
            "properties": {
                "arguments": {
                    "type": "object",
                    "properties": arguments_properties,
                    "required": required_fields,
                    "additionalProperties": True,
                }
            },
            "required": ["arguments"],
            "additionalProperties": False,
        },
    }


def _extract_planned_arguments(content: Any) -> dict[str, Any]:
    """Parse planner output into an arguments dict.

    ``content`` may be a JSON string, optionally wrapped in a Markdown code
    fence (with or without a ``json`` tag), or an already-decoded object.
    Returns the ``"arguments"`` object when present, otherwise the decoded dict
    itself, and ``{}`` for anything unparseable or not a dict.
    """
    candidate: Any = content
    if isinstance(candidate, str):
        text = candidate.strip()
        if text.startswith("```"):
            text = text.strip("`").strip()
        if text.startswith("json"):
            text = text[4:].strip()
        try:
            candidate = json.loads(text)
        except json.JSONDecodeError:
            return {}
    if isinstance(candidate, dict):
        if isinstance(candidate.get("arguments"), dict):
            return candidate["arguments"]
        # Some models return the arguments object directly.
        return candidate
    return {}


class McpWorkflowRunner:
    """
    Minimal workflow runner:
    - fixed step order from scenario
    - same planner agent in every step
    - MCP call executed by code, not by the agent
    - request/response persisted in run context
    """

    def __init__(self) -> None:
        # Built workflows keyed by scenario_id; never invalidated.
        self._workflow_cache: dict[str, Workflow] = {}
        # Max planner calls per tool invocation to fill missing required fields.
        self._planner_repair_attempts = _env_int("PLANNER_REPAIR_ATTEMPTS", 3)
        # Per-run state ({"input", "steps", "tool_calls"}) set by run() and read
        # by step executors.
        self._run_state_ctx: ContextVar[dict[str, Any] | None] = ContextVar(
            "mcp_workflow_run_state",
            default=None,
        )

    def _get_run_state(self) -> dict[str, Any]:
        """Return the current run state.

        Raises:
            RuntimeError: when called outside ``run()``.
        """
        run_state = self._run_state_ctx.get()
        if run_state is None:
            raise RuntimeError("run state is not initialized")
        return run_state

    def _build_scope(self) -> dict[str, Any]:
        """Return the template scope ``{"input": ..., "steps": ...}`` (live references)."""
        run_state = self._get_run_state()
        return {
            "input": run_state.get("input", {}),
            "steps": run_state.get("steps", {}),
        }

    async def _plan_arguments(
        self,
        *,
        step_name: str,
        tool_name: str,
        base_arguments: dict[str, Any],
        required_fields: list[str],
        scope: dict[str, Any],
        planner_cache: dict[str, dict[str, Any]] | None = None,
        missing_fields: list[str] | None = None,
        attempt_no: int = 1,
    ) -> dict[str, Any]:
        """Ask the LLM planner to fill tool arguments and merge them into ``base_arguments``.

        Best-effort: if the planner call or parsing fails, the base arguments
        are returned unchanged (deep-copied). Planner values override base
        values, except that an empty planner value never replaces a filled
        base value.

        Args:
            step_name: Scenario step name (prompt context and log messages).
            tool_name: MCP tool the arguments are for.
            base_arguments: Arguments resolved so far; not mutated.
            required_fields: Fields the tool call needs.
            scope: Template scope; every key is passed to the planner as context.
            planner_cache: Optional memo keyed by the full request, including
                the context.
            missing_fields: Fields the planner must fill on this attempt.
            attempt_no: 1-based repair attempt number.

        Returns:
            A new dict with the merged arguments.
        """
        planner_context: dict[str, Any] = {
            "input": scope.get("input", {}),
            "steps": scope.get("steps", {}),
        }
        for key, value in scope.items():
            if key not in {"input", "steps"}:
                planner_context[key] = value

        cache_key: str | None = None
        if planner_cache is not None:
            try:
                cache_payload = {
                    "tool_name": tool_name,
                    "base_arguments": base_arguments,
                    "required_fields": required_fields,
                    "missing_fields": missing_fields or [],
                    "attempt_no": attempt_no,
                    # The planner sees the whole context (e.g. the current
                    # foreach item), so it must be part of the key; otherwise
                    # foreach iterations with identical base arguments would
                    # reuse the first iteration's plan.
                    "context": planner_context,
                }
                cache_key = json.dumps(cache_payload, sort_keys=True, ensure_ascii=False)
            except (TypeError, ValueError):
                # Unserialisable or circular context: skip caching.
                cache_key = None
            if cache_key is not None and cache_key in planner_cache:
                return deepcopy(planner_cache[cache_key])

        prompt = {
            "task": "Prepare MCP arguments for this step.",
            "step_name": step_name,
            "tool_name": tool_name,
            "required_fields": required_fields,
            "base_arguments": base_arguments,
            "missing_fields": missing_fields or [],
            "repair_attempt": attempt_no,
            "arguments_schema": _build_arguments_schema(required_fields),
            "context": planner_context,
            "response_contract": {
                "must_return": {"arguments": "object"},
                "must_include_fields": missing_fields or [],
                "forbidden": "extra unrelated keys",
            },
            "output": (
                "Return only JSON object with key 'arguments'. "
                "If missing_fields is not empty, fill every missing field from context."
            ),
        }
        prompt_json = json.dumps(prompt, ensure_ascii=False)

        planned: dict[str, Any] = {}
        # Primary path: strict structured output via Polza response_format/json_schema.
        try:
            completion = await get_shared_step_planner_client().chat.completions.create(
                model=os.getenv("POLZA_MODEL_ID", "google/gemma-4-31b-it"),
                messages=[
                    {
                        "role": "system",
                        "content": (
                            "You are a tool-input planner. "
                            "Return only JSON that matches the provided schema."
                        ),
                    },
                    {"role": "user", "content": prompt_json},
                ],
                response_format={
                    "type": "json_schema",
                    "json_schema": _build_polza_response_schema(required_fields),
                },
                temperature=_env_float("POLZA_TEMPERATURE", 0.0),
            )
            raw_content = completion.choices[0].message.content if completion.choices else ""
            planned = _extract_planned_arguments(raw_content)
        except Exception:
            # Planning is best-effort: fall back to the base arguments and let
            # the caller's required-field validation decide whether to fail.
            logger.warning(
                "planner call failed for step %s (tool %s, attempt %d)",
                step_name,
                tool_name,
                attempt_no,
                exc_info=True,
            )
            planned = {}
        if not isinstance(planned, dict):
            planned = {}

        # Allow planner to override/fill base arguments while keeping known defaults.
        merged = deepcopy(base_arguments)
        for key, value in planned.items():
            # An empty planner value (e.g. a null the strict schema forced the
            # model to emit for an already-filled field) must not erase a value
            # the template already resolved.
            if _is_missing_value(value) and not _is_missing_value(merged.get(key)):
                continue
            merged[key] = value
        if planner_cache is not None and cache_key is not None:
            planner_cache[cache_key] = deepcopy(merged)
        return merged

    def _build_tool_step_executor(self, step_spec: dict[str, Any]):
        """Build the async agno step executor for one ``type: tool`` scenario step.

        Recognised ``step_spec`` keys: ``name``; ``tool`` (defaults to
        ``name``); ``input`` (argument template); ``required_input_fields``;
        ``foreach`` (a path string or ``{"from": path, "as": alias}``);
        ``collect`` (a per-item template evaluated with the tool response bound
        to ``tool``); ``collect_key`` (defaults to ``"items"``).

        The executor stores its result under ``run_state["steps"][name]``,
        appends to ``run_state["tool_calls"]``, and on any exception records an
        error payload and raises ``RuntimeError``.
        """
        # Normalise exactly like get_workflow() so the Step name, the key in
        # run_state["steps"], and the scenario step name all agree, and so a
        # step without an explicit "tool" falls back to its name instead of
        # raising KeyError.
        step_name = str(step_spec.get("name", "")).strip()
        tool_name = str(step_spec.get("tool", step_name)).strip()
        input_template = step_spec.get("input", {})
        foreach_spec = step_spec.get("foreach")
        collect_template = step_spec.get("collect")
        collect_key = str(step_spec.get("collect_key", "items")).strip() or "items"
        required_fields_raw = step_spec.get("required_input_fields", [])
        required_fields = (
            [field for field in required_fields_raw if isinstance(field, str)]
            if isinstance(required_fields_raw, list)
            else []
        )
        if isinstance(foreach_spec, dict):
            source_path = str(foreach_spec.get("from", "")).strip()
            item_alias = str(foreach_spec.get("as", "item")).strip() or "item"
        else:
            source_path = str(foreach_spec).strip() if isinstance(foreach_spec, str) else ""
            item_alias = "item"

        async def _executor(_step_input: StepInput) -> StepOutput:
            run_state = self._get_run_state()
            scope = self._build_scope()
            step_started_at = _utc_now_iso()
            planner_cache: dict[str, dict[str, Any]] = {}

            async def _prepare_arguments(
                *,
                local_scope: dict[str, Any],
                local_base_arguments: dict[str, Any],
            ) -> dict[str, Any]:
                """Fill missing required fields via the planner, then validate."""
                final_arguments = deepcopy(local_base_arguments)
                for repair_attempt in range(1, self._planner_repair_attempts + 1):
                    missing_fields = _missing_required_fields(final_arguments, required_fields)
                    if not missing_fields:
                        break
                    final_arguments = await self._plan_arguments(
                        step_name=step_name,
                        tool_name=tool_name,
                        base_arguments=final_arguments,
                        required_fields=required_fields,
                        scope=local_scope,
                        planner_cache=planner_cache,
                        missing_fields=missing_fields,
                        attempt_no=repair_attempt,
                    )
                _validate_required_fields(final_arguments, required_fields, step_name)
                return final_arguments

            async def _call_tool_with_repair(
                *,
                initial_arguments: dict[str, Any],
            ) -> tuple[dict[str, Any], dict[str, Any]]:
                """Call the MCP tool once; returns (response, arguments used).

                Despite the name, no post-call repair is performed.
                """
                final_arguments = deepcopy(initial_arguments)
                tool_response = await call_mcp_tool(tool_name, final_arguments)
                return tool_response, final_arguments

            try:
                tool_calls = run_state.setdefault("tool_calls", [])
                if not isinstance(tool_calls, list):
                    tool_calls = []
                    run_state["tool_calls"] = tool_calls

                if source_path:
                    iterable = _resolve_path(scope, source_path)
                    if not isinstance(iterable, list):
                        raise ValueError(f"{step_name}: foreach source is not list")
                    collected_items: list[Any] = []
                    for index, item in enumerate(iterable):
                        iteration_scope = dict(scope)
                        iteration_scope[item_alias] = item
                        iteration_scope["item"] = item
                        iteration_scope["index"] = index
                        resolved = _resolve_template(input_template, iteration_scope)
                        base_arguments = resolved if isinstance(resolved, dict) else {}
                        final_arguments = await _prepare_arguments(
                            local_scope=iteration_scope,
                            local_base_arguments=base_arguments,
                        )
                        tool_response, final_arguments = await _call_tool_with_repair(
                            initial_arguments=final_arguments,
                        )
                        tool_calls.append(
                            {
                                "step_name": step_name,
                                "tool_name": tool_name,
                                "attempt": index + 1,
                                "request": final_arguments,
                                # Record what the tool actually reported.
                                "ok": bool(tool_response.get("ok", True)),
                                "response": tool_response,
                            }
                        )
                        if collect_template is None:
                            collected_items.append(tool_response.get("payload", {}))
                        else:
                            collected_items.append(
                                _resolve_template(
                                    collect_template,
                                    {**iteration_scope, "tool": tool_response},
                                )
                            )
                    # NOTE(review): per-item tool failures do not mark the
                    # foreach step as failed; confirm this is intended.
                    step_payload = {
                        "ok": True,
                        "tool_name": tool_name,
                        "payload": {collect_key: collected_items},
                        "request": {"foreach_from": source_path, "count": len(iterable)},
                        "received_at": _utc_now_iso(),
                        "started_at": step_started_at,
                        "finished_at": _utc_now_iso(),
                    }
                else:
                    resolved = _resolve_template(input_template, scope)
                    base_arguments = resolved if isinstance(resolved, dict) else {}
                    final_arguments = await _prepare_arguments(
                        local_scope=scope,
                        local_base_arguments=base_arguments,
                    )
                    tool_response, final_arguments = await _call_tool_with_repair(
                        initial_arguments=final_arguments,
                    )
                    tool_ok = bool(tool_response.get("ok", True))
                    step_payload = {
                        "ok": tool_ok,
                        "tool_name": tool_name,
                        "payload": tool_response.get("payload", {}),
                        "request": final_arguments,
                        "response": tool_response,
                        "received_at": tool_response.get("received_at"),
                        "started_at": step_started_at,
                        "finished_at": _utc_now_iso(),
                    }
                    tool_calls.append(
                        {
                            "step_name": step_name,
                            "tool_name": tool_name,
                            "request": final_arguments,
                            "ok": tool_ok,
                            "response": tool_response,
                        }
                    )

                run_state.setdefault("steps", {})[step_name] = step_payload
                return StepOutput(
                    content=json.dumps(step_payload, ensure_ascii=False),
                    success=True,
                )
            except Exception as exc:
                error_payload = {
                    "ok": False,
                    "tool_name": tool_name,
                    "request": {},
                    "error": str(exc),
                    "started_at": step_started_at,
                    "finished_at": _utc_now_iso(),
                }
                run_state.setdefault("steps", {})[step_name] = error_payload
                run_state.setdefault("tool_calls", []).append(
                    {
                        "step_name": step_name,
                        "tool_name": tool_name,
                        "request": {},
                        "ok": False,
                        "error": str(exc),
                    }
                )
                raise RuntimeError(f"{step_name} failed: {exc}") from exc

        return _executor

    def get_workflow(self, scenario_id: str, scenario: dict[str, Any]) -> Workflow:
        """Return the (cached) agno Workflow for ``scenario_id``.

        The cache is keyed only by ``scenario_id``; a later change to the
        scenario definition is not picked up by an already-built workflow.

        Raises:
            ScenarioStoreError: when ``steps`` is missing or empty, or a step is
                not a dict, not ``type: tool``, or lacks a name/tool.
        """
        cached = self._workflow_cache.get(scenario_id)
        if cached is not None:
            return cached

        raw_steps = scenario.get("steps")
        if not isinstance(raw_steps, list) or not raw_steps:
            raise ScenarioStoreError("Scenario must contain non-empty steps list")

        workflow_steps: list[Step] = []
        for raw_step in raw_steps:
            if not isinstance(raw_step, dict):
                raise ScenarioStoreError("Each scenario step must be object")
            if raw_step.get("type") != "tool":
                raise ScenarioStoreError("This minimal runner supports only tool steps")
            step_name = str(raw_step.get("name", "")).strip()
            tool_name = str(raw_step.get("tool", step_name)).strip()
            if not step_name or not tool_name:
                raise ScenarioStoreError("Each tool step must contain non-empty name and tool")
            executor = self._build_tool_step_executor(raw_step)
            workflow_steps.append(
                Step(
                    name=step_name,
                    description=str(raw_step.get("description", step_name)),
                    executor=executor,
                    max_retries=0,
                    on_error="fail",
                )
            )

        workflow = Workflow(
            name=scenario_id,
            description=str(scenario.get("description", "")),
            steps=workflow_steps,
        )
        self._workflow_cache[scenario_id] = workflow
        return workflow

    async def run(self, *, scenario_id: str, input_data: dict[str, Any]) -> dict[str, Any]:
        """Run the scenario workflow and return a plain-dict summary.

        Exceptions from the workflow itself are caught and reported via
        ``status="failed"``. Exceptions from loading the scenario or building
        the workflow propagate.

        Returns:
            Dict with ``scenario_id``, ``workflow_name``, ``status``
            (``"success"``/``"failed"``), ``input``, ``final_result`` (always a
            dict), ``steps``, ``tool_calls``, ``run_id`` and ``session_id``
            (None when unavailable).
        """
        scenario = load_scenario_definition(scenario_id)
        workflow = self.get_workflow(scenario_id, scenario)
        initial_state: dict[str, Any] = {
            "input": deepcopy(input_data),
            "steps": {},
            "tool_calls": [],
        }
        token = self._run_state_ctx.set(initial_state)
        run_state = initial_state
        run_output: Any = None
        workflow_error: str | None = None
        try:
            run_output = await workflow.arun(input=input_data)
        except Exception as exc:
            logger.warning("workflow %s failed", scenario_id, exc_info=True)
            workflow_error = str(exc)
        finally:
            captured = self._run_state_ctx.get()
            if isinstance(captured, dict):
                run_state = deepcopy(captured)
            self._run_state_ctx.reset(token)

        content = run_output.content if hasattr(run_output, "content") else None
        if isinstance(content, str):
            try:
                content = json.loads(content)
            except json.JSONDecodeError:
                content = {"raw_content": content}

        if content is None:
            # No workflow output: surface the most recent failed step, if any.
            step_payloads = run_state.get("steps", {})
            if isinstance(step_payloads, dict):
                for payload in reversed(list(step_payloads.values())):
                    if isinstance(payload, dict) and not bool(payload.get("ok", True)):
                        content = deepcopy(payload)
                        break
        if content is None and workflow_error is not None:
            content = {"error": workflow_error}

        status = "success"
        if workflow_error is not None:
            status = "failed"
        elif run_output is not None and not bool(getattr(run_output, "success", True)):
            status = "failed"

        return {
            "scenario_id": scenario_id,
            "workflow_name": workflow.name,
            "status": status,
            "input": input_data,
            "final_result": content if isinstance(content, dict) else {"raw_content": content},
            "steps": run_state.get("steps", {}),
            "tool_calls": run_state.get("tool_calls", []),
            # getattr may return None for a present-but-unset attribute;
            # str(None) would otherwise leak the literal "None".
            "run_id": _optional_str(getattr(run_output, "run_id", None)),
            "session_id": _optional_str(getattr(run_output, "session_id", None)),
        }


_default_runner: McpWorkflowRunner | None = None


def get_mcp_workflow_runner() -> McpWorkflowRunner:
    """Return the process-wide McpWorkflowRunner, creating it on first use."""
    global _default_runner
    if _default_runner is not None:
        return _default_runner
    _default_runner = McpWorkflowRunner()
    return _default_runner


def _extract_output_summary(content: Any) -> str | None:
    """Return a non-empty ``summary`` string from ``content`` or ``content["payload"]``."""
    if not isinstance(content, dict):
        return None
    summary = content.get("summary")
    if isinstance(summary, str) and summary:
        return summary
    payload = content.get("payload")
    if isinstance(payload, dict):
        payload_summary = payload.get("summary")
        if isinstance(payload_summary, str) and payload_summary:
            return payload_summary
    return None


def _build_step_states_from_minimal(
    *,
    scenario: dict[str, Any],
    minimal_steps: dict[str, Any],
) -> list[StepState]:
    """Map runner step payloads to ``StepState`` objects in scenario order.

    Steps without a recorded payload are reported as ``"queued"``; others as
    ``"success"`` or ``"failed"`` according to the payload's ``ok`` flag.
    """
    raw_steps = scenario.get("steps")
    if not isinstance(raw_steps, list):
        return []

    step_states: list[StepState] = []
    for raw_step in raw_steps:
        if not isinstance(raw_step, dict):
            continue
        step_name = str(raw_step.get("name", "")).strip()
        if not step_name:
            continue
        payload = minimal_steps.get(step_name)
        if not isinstance(payload, dict):
            step_states.append(StepState(node_id=step_name, status="queued"))
            continue
        ok = bool(payload.get("ok", False))
        step_states.append(
            StepState(
                node_id=step_name,
                status="success" if ok else "failed",
                started_at=str(payload.get("started_at") or "") or None,
                finished_at=str(payload.get("finished_at") or "") or None,
                error=RunError(
                    code="tool_error",
                    message=str(payload.get("error", f"{step_name} failed")),
                )
                if not ok
                else None,
            )
        )
    return step_states


async def run_scenario_workflow(
    input_data: dict[str, Any],
    scenario_id: str = "news_source_discovery_v1",
) -> dict[str, Any]:
    """Run a scenario end to end and return a dumped ``ScenarioRunResponse``.

    Unknown scenarios and exceptions raised by the runner are reported as
    ``status="failed"`` responses with ``unknown_scenario`` /
    ``workflow_error`` error codes. The run is ``"failed"`` when the runner
    reports failure or any recorded step payload is not ok.
    """
    try:
        scenario = load_scenario_definition(scenario_id)
    except ScenarioStoreError as exc:
        return ScenarioRunResponse(
            scenario_id=scenario_id,
            status="failed",
            input=input_data,
            steps=[],
            error=RunError(code="unknown_scenario", message=str(exc)),
        ).model_dump()

    runner = get_mcp_workflow_runner()
    scenario_name = str(scenario.get("name", scenario_id))
    try:
        minimal_result = await runner.run(
            scenario_id=scenario_id,
            input_data=input_data,
        )
    except Exception as exc:
        return ScenarioRunResponse(
            scenario_id=scenario_id,
            status="failed",
            input=input_data,
            scenario_name=scenario_name,
            steps=[],
            error=RunError(code="workflow_error", message=str(exc)),
        ).model_dump()

    minimal_steps = minimal_result.get("steps", {})
    steps = minimal_steps if isinstance(minimal_steps, dict) else {}
    step_states = _build_step_states_from_minimal(
        scenario=scenario,
        minimal_steps=steps,
    )
    final_result = minimal_result.get("final_result")
    normalized_result = (
        final_result if isinstance(final_result, dict) else {"raw_content": str(final_result)}
    )

    # A workflow-level failure may leave no failed step record, so honour the
    # runner's own status as well as the per-step flags.
    runner_failed = minimal_result.get("status") == "failed"
    step_failed = any(
        isinstance(payload, dict) and not bool(payload.get("ok", False))
        for payload in steps.values()
    )
    status = "failed" if runner_failed or step_failed else "success"

    return ScenarioRunResponse(
        scenario_id=scenario_id,
        status=status,
        input=input_data,
        steps=step_states,
        output_summary=_extract_output_summary(normalized_result),
        scenario_name=scenario_name,
        workflow_name=str(minimal_result.get("workflow_name") or scenario_id),
        result=normalized_result,
        error=None
        if status == "success"
        else RunError(code="workflow_failed", message="Workflow finished with failed status."),
        run_id=minimal_result.get("run_id"),
        session_id=minimal_result.get("session_id"),
    ).model_dump()