"""Runs endpoints — create, stream, wait, cancel.

Implements the LangGraph Platform runs API on top of
:class:`deerflow.agents.runs.RunManager` and
:class:`deerflow.agents.stream_bridge.StreamBridge`.

SSE format is aligned with the LangGraph Platform protocol so that
the ``useStream`` React hook from ``@langchain/langgraph-sdk/react``
works without modification.
"""

from __future__ import annotations

import asyncio
import logging
from typing import Any, Literal

from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field

from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/threads", tags=["runs"])

# Headers sent with every SSE response. ``X-Accel-Buffering: no`` tells nginx
# not to buffer the stream so events reach the client as they are produced.
_SSE_HEADERS: dict[str, str] = {
    "Cache-Control": "no-cache",
    "Connection": "keep-alive",
    "X-Accel-Buffering": "no",
}


# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------


class RunCreateRequest(BaseModel):
    """Body accepted by the run-creating endpoints (create / stream / wait)."""

    assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
    input: dict[str, Any] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
    command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
    metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
    config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
    context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
    webhook: str | None = Field(default=None, description="Completion callback URL")
    checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
    checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
    interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
    interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
    stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
    stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
    stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
    on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
    on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
    multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
    after_seconds: float | None = Field(default=None, description="Delayed execution")
    if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
    feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")


class RunResponse(BaseModel):
    """Serialized view of a :class:`RunRecord` returned by the API."""

    run_id: str
    thread_id: str
    assistant_id: str | None = None
    status: str
    metadata: dict[str, Any] = Field(default_factory=dict)
    kwargs: dict[str, Any] = Field(default_factory=dict)
    multitask_strategy: str = "reject"
    created_at: str = ""
    updated_at: str = ""


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _record_to_response(record: RunRecord) -> RunResponse:
    """Convert an in-memory run record into its API response model."""
    return RunResponse(
        run_id=record.run_id,
        thread_id=record.thread_id,
        assistant_id=record.assistant_id,
        status=record.status.value,
        metadata=record.metadata,
        kwargs=record.kwargs,
        multitask_strategy=record.multitask_strategy,
        created_at=record.created_at,
        updated_at=record.updated_at,
    )


def _get_thread_run_or_404(run_mgr: Any, thread_id: str, run_id: str) -> RunRecord:
    """Return the run ``run_id`` if it exists and belongs to ``thread_id``.

    Raises:
        HTTPException: 404 when the run is unknown or belongs to another thread
            (the two cases are deliberately indistinguishable to the client).
    """
    record = run_mgr.get(run_id)
    if record is None or record.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
    return record


async def _await_run_task(record: RunRecord) -> None:
    """Block until ``record``'s background task finishes, whatever its outcome.

    Cancellation is the expected outcome after a cancel request and is ignored.
    Any other exception raised by the task is logged and not re-raised, so
    endpoints that merely wait for a run to stop do not turn a failed run
    into an HTTP 500. Does nothing when the record has no task.
    """
    task = record.task
    if task is None:
        return
    try:
        await task
    except asyncio.CancelledError:
        pass
    except Exception:
        logger.debug("Run %s task finished with an error", record.run_id, exc_info=True)


def _sse_response(
    bridge: Any,
    record: RunRecord,
    request: Request,
    run_mgr: Any,
    extra_headers: dict[str, str] | None = None,
) -> StreamingResponse:
    """Build the ``text/event-stream`` response that relays ``record``'s events.

    ``extra_headers`` are appended after the shared SSE headers.
    """
    headers = dict(_SSE_HEADERS)
    if extra_headers:
        headers.update(extra_headers)
    return StreamingResponse(
        sse_consumer(bridge, record, request, run_mgr),
        media_type="text/event-stream",
        headers=headers,
    )


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------


@router.post("/{thread_id}/runs", response_model=RunResponse)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
    """Create a background run (returns immediately)."""
    record = await start_run(body, thread_id, request)
    return _record_to_response(record)


@router.post("/{thread_id}/runs/stream")
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
    """Create a run and stream events via SSE.

    The response includes a ``Content-Location`` header with the run's
    resource URL, matching the LangGraph Platform protocol. The
    ``useStream`` React hook uses this to extract run metadata.
    """
    # Resolve dependencies before starting the run so a misconfigured
    # gateway fails without leaving an orphaned run behind.
    bridge = get_stream_bridge(request)
    run_mgr = get_run_manager(request)
    record = await start_run(body, thread_id, request)

    return _sse_response(
        bridge,
        record,
        request,
        run_mgr,
        extra_headers={
            # LangGraph Platform includes run metadata in this header.
            # The SDK uses a greedy regex to extract the run id from this path,
            # so it must point at the canonical run resource without extra suffixes.
            "Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
        },
    )


@router.post("/{thread_id}/runs/wait", response_model=dict)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
    """Create a run and block until it completes, returning the final state.

    Returns the serialized ``channel_values`` of the thread's latest
    checkpoint. When there is no checkpoint, or it cannot be read, returns
    ``{"status": <run status>, "error": <run error>}`` from the run record.
    """
    record = await start_run(body, thread_id, request)
    # A run that fails is reported through the body below (final state or
    # status/error) rather than escaping this handler as an HTTP 500.
    await _await_run_task(record)

    checkpointer = get_checkpointer(request)
    config = {"configurable": {"thread_id": thread_id}}
    try:
        checkpoint_tuple = await checkpointer.aget_tuple(config)
        if checkpoint_tuple is not None:
            checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
            channel_values = checkpoint.get("channel_values", {})
            return serialize_channel_values(channel_values)
    except Exception:
        logger.exception("Failed to fetch final state for run %s", record.run_id)

    return {"status": record.status.value, "error": record.error}


@router.get("/{thread_id}/runs", response_model=list[RunResponse])
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
    """List all runs for a thread."""
    run_mgr = get_run_manager(request)
    records = await run_mgr.list_by_thread(thread_id)
    return [_record_to_response(r) for r in records]


@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
    """Get details of a specific run (404 if it is not on this thread)."""
    run_mgr = get_run_manager(request)
    record = _get_thread_run_or_404(run_mgr, thread_id, run_id)
    return _record_to_response(record)


@router.post("/{thread_id}/runs/{run_id}/cancel")
async def cancel_run(
    thread_id: str,
    run_id: str,
    request: Request,
    wait: bool = Query(default=False, description="Block until run completes after cancel"),
    action: Literal["interrupt", "rollback"] = Query(default="interrupt", description="Cancel action"),
) -> Response:
    """Cancel a running or pending run.

    - action=interrupt: Stop execution, keep current checkpoint (can be resumed)
    - action=rollback: Stop execution, revert to pre-run checkpoint state
    - wait=true: Block until the run fully stops, return 204
    - wait=false: Return immediately with 202

    Returns 404 if the run is not on this thread and 409 if the run manager
    refuses the cancellation.
    """
    run_mgr = get_run_manager(request)
    record = _get_thread_run_or_404(run_mgr, thread_id, run_id)

    cancelled = await run_mgr.cancel(run_id, action=action)
    if not cancelled:
        raise HTTPException(
            status_code=409,
            detail=f"Run {run_id} is not cancellable (status: {record.status.value})",
        )

    if wait and record.task is not None:
        # The run has stopped once its task finishes, whether it ended by
        # cancellation or by an error raised while shutting down.
        await _await_run_task(record)
        return Response(status_code=204)

    return Response(status_code=202)


@router.get("/{thread_id}/runs/{run_id}/join")
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
    """Join an existing run's SSE stream."""
    bridge = get_stream_bridge(request)
    run_mgr = get_run_manager(request)
    record = _get_thread_run_or_404(run_mgr, thread_id, run_id)
    return _sse_response(bridge, record, request, run_mgr)


@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
async def stream_existing_run(
    thread_id: str,
    run_id: str,
    request: Request,
    action: Literal["interrupt", "rollback"] | None = Query(default=None, description="Cancel action"),
    wait: int = Query(default=0, description="Block until cancelled (1) or return immediately (0)"),
):
    """Join an existing run's SSE stream (GET), or cancel-then-stream (POST).

    The LangGraph SDK's ``joinStream`` and ``useStream`` stop button both use
    ``POST`` to this endpoint. When ``action=interrupt`` or ``action=rollback``
    is present the run is cancelled first; the response then streams any
    remaining buffered events so the client observes a clean shutdown.

    If the cancellation succeeds and ``wait`` is non-zero, the handler instead
    blocks until the run task finishes and returns 204 without a stream.
    """
    run_mgr = get_run_manager(request)
    record = _get_thread_run_or_404(run_mgr, thread_id, run_id)

    # Cancel if an action was requested (stop-button / interrupt flow).
    if action is not None:
        cancelled = await run_mgr.cancel(run_id, action=action)
        if cancelled and wait and record.task is not None:
            await _await_run_task(record)
            return Response(status_code=204)

    bridge = get_stream_bridge(request)
    return _sse_response(bridge, record, request, run_mgr)