"""REST routes for scenario execution, catalogs and live run events. Runs are executed asynchronously: ``POST /api/runs`` schedules a background task and returns immediately with a ``run_id``. Clients consume progress via ``GET /api/runs/{run_id}/events`` (SSE) or poll ``GET /api/runs/{run_id}`` for a snapshot. """ from __future__ import annotations import asyncio import json from typing import Any, AsyncIterator from fastapi import APIRouter, HTTPException from fastapi.responses import StreamingResponse from loguru import logger from src.mcp_client import list_mcp_tools from src.mcp_workflow_runner import run_scenario_async from src.run_registry import RunRecord, get_registry from src.scenario_store import ( ScenarioStoreError, list_scenario_summaries, load_scenario_definition, ) from src.schemas import ( RunSubmitResponse, ScenarioRunRequest, ScenarioRunResponse, ScenarioSummary, StepState, ToolSummary, ) router = APIRouter(prefix="/api", tags=["workflow"]) # --------------------------------------------------------------------------- # Runs # --------------------------------------------------------------------------- @router.post( "/runs", response_model=RunSubmitResponse, status_code=202, summary="Schedule a scenario run", description=( "Creates a run record and schedules execution in the background. " "Returns immediately with a `run_id`; poll `GET /api/runs/{run_id}` " "or subscribe to `GET /api/runs/{run_id}/events` for progress." ), ) async def post_run(request: ScenarioRunRequest) -> RunSubmitResponse: registry = get_registry() record = registry.create( scenario_id=request.scenario_id, input_data=request.input, ) record.task = asyncio.create_task(run_scenario_async(record)) return RunSubmitResponse( run_id=record.run_id, scenario_id=record.scenario_id, status=record.status, input=record.input, started_at=record.started_at, ) @router.get( "/runs/{run_id}", response_model=ScenarioRunResponse, summary="Get run snapshot", description=( "Returns the current state of a run. For running runs the `steps` " "list reflects progress so far; for terminal runs it is complete." ), responses={404: {"description": "Unknown run_id"}}, ) async def get_run(run_id: str) -> ScenarioRunResponse: record = _require_run(run_id) if record.response is not None: return record.response return _snapshot_from_record(record) @router.get( "/runs/{run_id}/events", summary="Live run progress (SSE)", description=( "Server-Sent Events stream. Late subscribers receive a replay of " "buffered events first, then tail new events until `run_finished`.\n\n" "Event types: `run_started`, `step_started`, `step_finished`, " "`run_finished`. Each event is JSON in the SSE `data:` field." ), responses={ 200: { "description": "SSE stream of run events", "content": { "text/event-stream": { "example": ( "event: run_started\n" 'data: {"type":"run_started","run_id":"76d6903c-f520-4a40-b0fc-8fed3f7955d2",' '"scenario_id":"news_source_discovery_v1","started_at":"2026-04-24T09:27:59.873+00:00"}\n\n' "event: step_started\n" 'data: {"type":"step_started","run_id":"76d6903c-...","step_name":"search_news_sources",' '"index":0,"started_at":"2026-04-24T09:27:59.875+00:00"}\n\n' "event: step_finished\n" 'data: {"type":"step_finished","run_id":"76d6903c-...","step_name":"search_news_sources",' '"index":0,"status":"success","started_at":"2026-04-24T09:27:59.875+00:00",' '"finished_at":"2026-04-24T09:28:00.028+00:00","message":""}\n\n' "event: run_finished\n" 'data: {"type":"run_finished","run_id":"76d6903c-...","status":"success",' '"finished_at":"2026-04-24T09:28:01.750+00:00","message":""}\n\n' ) } }, }, 404: {"description": "Unknown run_id"}, }, ) async def get_run_events(run_id: str) -> StreamingResponse: record = _require_run(run_id) return StreamingResponse( _event_stream(record), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "X-Accel-Buffering": "no", }, ) # --------------------------------------------------------------------------- # Scenario catalog # --------------------------------------------------------------------------- @router.get( "/scenarios", response_model=list[ScenarioSummary], summary="List available scenarios", description="Returns metadata (id, name, description, input schema) for every scenario in the index.", ) async def get_scenarios() -> list[ScenarioSummary]: return [ScenarioSummary(**s) for s in list_scenario_summaries()] _SCENARIO_DEFINITION_EXAMPLE: dict[str, Any] = { "schema_version": "1", "scenario_id": "news_source_discovery_v1", "name": "News Source Discovery V1", "description": "Find earliest news source using sequential MCP tools.", "input_schema": { "type": "object", "required": ["url"], "properties": { "url": {"type": "string", "description": "URL of source news article"} }, }, "steps": [ { "name": "search_news_sources", "type": "tool", "tool": "search_news_sources", "input": {"url": {"from": "input.url"}}, "required_input_fields": ["url"], } ], } @router.get( "/scenarios/{scenario_id}", summary="Get full scenario definition", description="Returns the raw scenario JSON (including the `steps` graph) for UI visualization.", responses={ 200: { "description": "Scenario definition", "content": {"application/json": {"example": _SCENARIO_DEFINITION_EXAMPLE}}, }, 404: {"description": "Unknown scenario_id"}, }, ) async def get_scenario(scenario_id: str) -> dict[str, Any]: try: return load_scenario_definition(scenario_id) except ScenarioStoreError as exc: raise HTTPException(status_code=404, detail=str(exc)) from exc # --------------------------------------------------------------------------- # Tool catalog # --------------------------------------------------------------------------- @router.get( "/tools", response_model=list[ToolSummary], summary="List MCP tools", description="Proxies MCP `list_tools()` and returns name, description, and input schema for each tool.", responses={502: {"description": "MCP transport error"}}, ) async def get_tools() -> list[ToolSummary]: try: tools = await list_mcp_tools() except RuntimeError as exc: logger.warning("Failed to fetch MCP tools: {}", exc) raise HTTPException(status_code=502, detail=str(exc)) from exc return [ToolSummary(**t) for t in tools] # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _require_run(run_id: str) -> RunRecord: record = get_registry().get(run_id) if record is None: raise HTTPException(status_code=404, detail=f"Unknown run_id: {run_id}") return record def _snapshot_from_record(record: RunRecord) -> ScenarioRunResponse: """Build a partial ScenarioRunResponse for a still-running or pre-start run.""" steps: list[StepState] = [] for event in record.events: if event.get("type") != "step_finished": continue steps.append( StepState( node_id=str(event.get("step_name", "")), status=event.get("status", "failed"), started_at=event.get("started_at"), finished_at=event.get("finished_at"), message=str(event.get("message", "")), ) ) return ScenarioRunResponse( scenario_id=record.scenario_id, status=record.status, message=record.message, input=record.input, steps=steps, run_id=record.run_id, ) async def _event_stream(record: RunRecord) -> AsyncIterator[bytes]: """Replay buffered events, then tail a fresh subscriber queue. The snapshot/subscribe pair runs without any intervening ``await``, so no emitted event can slip between the replay cutoff and the subscription. Events emitted during replay land in the queue and are drained afterwards. """ queue: asyncio.Queue = asyncio.Queue() buffered = list(record.events) record.subscribers.append(queue) try: for event in buffered: yield _format_sse(event) if record.is_terminal(): return while True: event = await queue.get() if event is None: return yield _format_sse(event) finally: if queue in record.subscribers: record.subscribers.remove(queue) def _format_sse(event: dict[str, Any]) -> bytes: event_type = str(event.get("type", "message")) payload = json.dumps(event, ensure_ascii=False) return f"event: {event_type}\ndata: {payload}\n\n".encode("utf-8")