Initial commit: hardened DeerFlow factory

Vendored deer-flow upstream (bytedance/deer-flow) plus prompt-injection
hardening:

- New deerflow.security package: content_delimiter, html_cleaner,
  sanitizer (8 layers — invisible chars, control chars, symbols, NFC,
  PUA, tag chars, horizontal whitespace collapse with newline/tab
  preservation, length cap)
- New deerflow.community.searx package: web_search, web_fetch,
  image_search backed by a private SearX instance, every external
  string sanitized and wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>>
  delimiters
- All native community web providers (ddg_search, tavily, exa,
  firecrawl, jina_ai, infoquest, image_search) replaced with hard-fail
  stubs that raise NativeWebToolDisabledError at import time, so a
  misconfigured tool.use path fails loud rather than silently falling
  back to unsanitized output
- Native client back-doors (jina_client.py, infoquest_client.py)
  stubbed too
- Native-tool tests quarantined under tests/_disabled_native/
  (collect_ignore_glob via local conftest.py)
- Sanitizer Layer 7 fix: only collapse horizontal whitespace, preserve
  newlines and tabs so list/table structure survives
- Hardened runtime config.yaml references only the searx-backed tools
- Factory overlay (backend/) kept in sync with deer-flow tree as a
  reference / source

See HARDENING.md for the full audit trail and verification steps.
This commit is contained in:
2026-04-12 14:23:57 +02:00
commit 6de0bf9f5b
889 changed files with 173052 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
from .app import app, create_app
from .config import GatewayConfig, get_gateway_config
__all__ = ["app", "create_app", "GatewayConfig", "get_gateway_config"]

View File

@@ -0,0 +1,221 @@
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.gateway.config import get_gateway_config
from app.gateway.deps import langgraph_runtime
from app.gateway.routers import (
agents,
artifacts,
assistants_compat,
channels,
mcp,
memory,
models,
runs,
skills,
suggestions,
thread_runs,
threads,
uploads,
)
from deerflow.config.app_config import get_app_config
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan handler."""
# Load config and check necessary environment variables at startup
try:
get_app_config()
logger.info("Configuration loaded successfully")
except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}"
logger.exception(error_msg)
raise RuntimeError(error_msg) from e
config = get_gateway_config()
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
async with langgraph_runtime(app):
logger.info("LangGraph runtime initialised")
# Start IM channel service if any channels are configured
try:
from app.channels.service import start_channel_service
channel_service = await start_channel_service()
logger.info("Channel service started: %s", channel_service.get_status())
except Exception:
logger.exception("No IM channels configured or channel service failed to start")
yield
# Stop channel service on shutdown
try:
from app.channels.service import stop_channel_service
await stop_channel_service()
except Exception:
logger.exception("Failed to stop channel service")
logger.info("Shutting down API Gateway")
def create_app() -> FastAPI:
"""Create and configure the FastAPI application.
Returns:
Configured FastAPI application instance.
"""
app = FastAPI(
title="DeerFlow API Gateway",
description="""
## DeerFlow API Gateway
API Gateway for DeerFlow - A LangGraph-based AI agent backend with sandbox execution capabilities.
### Features
- **Models Management**: Query and retrieve available AI models
- **MCP Configuration**: Manage Model Context Protocol (MCP) server configurations
- **Memory Management**: Access and manage global memory data for personalized conversations
- **Skills Management**: Query and manage skills and their enabled status
- **Artifacts**: Access thread artifacts and generated files
- **Health Monitoring**: System health check endpoints
### Architecture
LangGraph requests are handled by nginx reverse proxy.
This gateway provides custom endpoints for models, MCP configuration, skills, and artifacts.
""",
version="0.1.0",
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
openapi_tags=[
{
"name": "models",
"description": "Operations for querying available AI models and their configurations",
},
{
"name": "mcp",
"description": "Manage Model Context Protocol (MCP) server configurations",
},
{
"name": "memory",
"description": "Access and manage global memory data for personalized conversations",
},
{
"name": "skills",
"description": "Manage skills and their configurations",
},
{
"name": "artifacts",
"description": "Access and download thread artifacts and generated files",
},
{
"name": "uploads",
"description": "Upload and manage user files for threads",
},
{
"name": "threads",
"description": "Manage DeerFlow thread-local filesystem data",
},
{
"name": "agents",
"description": "Create and manage custom agents with per-agent config and prompts",
},
{
"name": "suggestions",
"description": "Generate follow-up question suggestions for conversations",
},
{
"name": "channels",
"description": "Manage IM channel integrations (Feishu, Slack, Telegram)",
},
{
"name": "assistants-compat",
"description": "LangGraph Platform-compatible assistants API (stub)",
},
{
"name": "runs",
"description": "LangGraph Platform-compatible runs lifecycle (create, stream, cancel)",
},
{
"name": "health",
"description": "Health check and system status endpoints",
},
],
)
# CORS is handled by nginx - no need for FastAPI middleware
# Include routers
# Models API is mounted at /api/models
app.include_router(models.router)
# MCP API is mounted at /api/mcp
app.include_router(mcp.router)
# Memory API is mounted at /api/memory
app.include_router(memory.router)
# Skills API is mounted at /api/skills
app.include_router(skills.router)
# Artifacts API is mounted at /api/threads/{thread_id}/artifacts
app.include_router(artifacts.router)
# Uploads API is mounted at /api/threads/{thread_id}/uploads
app.include_router(uploads.router)
# Thread cleanup API is mounted at /api/threads/{thread_id}
app.include_router(threads.router)
# Agents API is mounted at /api/agents
app.include_router(agents.router)
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
app.include_router(suggestions.router)
# Channels API is mounted at /api/channels
app.include_router(channels.router)
# Assistants compatibility API (LangGraph Platform stub)
app.include_router(assistants_compat.router)
# Thread Runs API (LangGraph Platform-compatible runs lifecycle)
app.include_router(thread_runs.router)
# Stateless Runs API (stream/wait without a pre-existing thread)
app.include_router(runs.router)
@app.get("/health", tags=["health"])
async def health_check() -> dict:
"""Health check endpoint.
Returns:
Service health status information.
"""
return {"status": "healthy", "service": "deer-flow-gateway"}
return app
# Create app instance for uvicorn
app = create_app()

View File

@@ -0,0 +1,27 @@
import os
from pydantic import BaseModel, Field
class GatewayConfig(BaseModel):
"""Configuration for the API Gateway."""
host: str = Field(default="0.0.0.0", description="Host to bind the gateway server")
port: int = Field(default=8001, description="Port to bind the gateway server")
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"], description="Allowed CORS origins")
_gateway_config: GatewayConfig | None = None
def get_gateway_config() -> GatewayConfig:
"""Get gateway config, loading from environment if available."""
global _gateway_config
if _gateway_config is None:
cors_origins_str = os.getenv("CORS_ORIGINS", "http://localhost:3000")
_gateway_config = GatewayConfig(
host=os.getenv("GATEWAY_HOST", "0.0.0.0"),
port=int(os.getenv("GATEWAY_PORT", "8001")),
cors_origins=cors_origins_str.split(","),
)
return _gateway_config

View File

@@ -0,0 +1,70 @@
"""Centralized accessors for singleton objects stored on ``app.state``.
**Getters** (used by routers): raise 503 when a required dependency is
missing, except ``get_store`` which returns ``None``.
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import AsyncExitStack, asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from deerflow.runtime import RunManager, StreamBridge
@asynccontextmanager
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
"""Bootstrap and tear down all LangGraph runtime singletons.
Usage in ``app.py``::
async with langgraph_runtime(app):
yield
"""
from deerflow.agents.checkpointer.async_provider import make_checkpointer
from deerflow.runtime import make_store, make_stream_bridge
async with AsyncExitStack() as stack:
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
app.state.store = await stack.enter_async_context(make_store())
app.state.run_manager = RunManager()
yield
# ---------------------------------------------------------------------------
# Getters called by routers per-request
# ---------------------------------------------------------------------------
def get_stream_bridge(request: Request) -> StreamBridge:
"""Return the global :class:`StreamBridge`, or 503."""
bridge = getattr(request.app.state, "stream_bridge", None)
if bridge is None:
raise HTTPException(status_code=503, detail="Stream bridge not available")
return bridge
def get_run_manager(request: Request) -> RunManager:
"""Return the global :class:`RunManager`, or 503."""
mgr = getattr(request.app.state, "run_manager", None)
if mgr is None:
raise HTTPException(status_code=503, detail="Run manager not available")
return mgr
def get_checkpointer(request: Request):
"""Return the global checkpointer, or 503."""
cp = getattr(request.app.state, "checkpointer", None)
if cp is None:
raise HTTPException(status_code=503, detail="Checkpointer not available")
return cp
def get_store(request: Request):
"""Return the global store (may be ``None`` if not configured)."""
return getattr(request.app.state, "store", None)

View File

@@ -0,0 +1,28 @@
"""Shared path resolution for thread virtual paths (e.g. mnt/user-data/outputs/...)."""
from pathlib import Path
from fastapi import HTTPException
from deerflow.config.paths import get_paths
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
"""Resolve a virtual path to the actual filesystem path under thread user-data.
Args:
thread_id: The thread ID.
virtual_path: The virtual path as seen inside the sandbox
(e.g., /mnt/user-data/outputs/file.txt).
Returns:
The resolved filesystem path.
Raises:
HTTPException: If the path is invalid or outside allowed directories.
"""
try:
return get_paths().resolve_virtual_path(thread_id, virtual_path)
except ValueError as e:
status = 403 if "traversal" in str(e) else 400
raise HTTPException(status_code=status, detail=str(e))

View File

@@ -0,0 +1,3 @@
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads
__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]

View File

@@ -0,0 +1,383 @@
"""CRUD API for custom agents."""
import logging
import re
import shutil
import yaml
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["agents"])
AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
class AgentResponse(BaseModel):
"""Response model for a custom agent."""
name: str = Field(..., description="Agent name (hyphen-case)")
description: str = Field(default="", description="Agent description")
model: str | None = Field(default=None, description="Optional model override")
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
soul: str | None = Field(default=None, description="SOUL.md content")
class AgentsListResponse(BaseModel):
"""Response model for listing all custom agents."""
agents: list[AgentResponse]
class AgentCreateRequest(BaseModel):
"""Request body for creating a custom agent."""
name: str = Field(..., description="Agent name (must match ^[A-Za-z0-9-]+$, stored as lowercase)")
description: str = Field(default="", description="Agent description")
model: str | None = Field(default=None, description="Optional model override")
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
soul: str = Field(default="", description="SOUL.md content — agent personality and behavioral guardrails")
class AgentUpdateRequest(BaseModel):
"""Request body for updating a custom agent."""
description: str | None = Field(default=None, description="Updated description")
model: str | None = Field(default=None, description="Updated model override")
tool_groups: list[str] | None = Field(default=None, description="Updated tool group whitelist")
soul: str | None = Field(default=None, description="Updated SOUL.md content")
def _validate_agent_name(name: str) -> None:
"""Validate agent name against allowed pattern.
Args:
name: The agent name to validate.
Raises:
HTTPException: 422 if the name is invalid.
"""
if not AGENT_NAME_PATTERN.match(name):
raise HTTPException(
status_code=422,
detail=f"Invalid agent name '{name}'. Must match ^[A-Za-z0-9-]+$ (letters, digits, and hyphens only).",
)
def _normalize_agent_name(name: str) -> str:
"""Normalize agent name to lowercase for filesystem storage."""
return name.lower()
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
"""Convert AgentConfig to AgentResponse."""
soul: str | None = None
if include_soul:
soul = load_agent_soul(agent_cfg.name) or ""
return AgentResponse(
name=agent_cfg.name,
description=agent_cfg.description,
model=agent_cfg.model,
tool_groups=agent_cfg.tool_groups,
soul=soul,
)
@router.get(
"/agents",
response_model=AgentsListResponse,
summary="List Custom Agents",
description="List all custom agents available in the agents directory, including their soul content.",
)
async def list_agents() -> AgentsListResponse:
"""List all custom agents.
Returns:
List of all custom agents with their metadata and soul content.
"""
try:
agents = list_custom_agents()
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
except Exception as e:
logger.error(f"Failed to list agents: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
@router.get(
"/agents/check",
summary="Check Agent Name",
description="Validate an agent name and check if it is available (case-insensitive).",
)
async def check_agent_name(name: str) -> dict:
"""Check whether an agent name is valid and not yet taken.
Args:
name: The agent name to check.
Returns:
``{"available": true/false, "name": "<normalized>"}``
Raises:
HTTPException: 422 if the name is invalid.
"""
_validate_agent_name(name)
normalized = _normalize_agent_name(name)
available = not get_paths().agent_dir(normalized).exists()
return {"available": available, "name": normalized}
@router.get(
"/agents/{name}",
response_model=AgentResponse,
summary="Get Custom Agent",
description="Retrieve details and SOUL.md content for a specific custom agent.",
)
async def get_agent(name: str) -> AgentResponse:
"""Get a specific custom agent by name.
Args:
name: The agent name.
Returns:
Agent details including SOUL.md content.
Raises:
HTTPException: 404 if agent not found.
"""
_validate_agent_name(name)
name = _normalize_agent_name(name)
try:
agent_cfg = load_agent_config(name)
return _agent_config_to_response(agent_cfg, include_soul=True)
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
except Exception as e:
logger.error(f"Failed to get agent '{name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to get agent: {str(e)}")
@router.post(
"/agents",
response_model=AgentResponse,
status_code=201,
summary="Create Custom Agent",
description="Create a new custom agent with its config and SOUL.md.",
)
async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
"""Create a new custom agent.
Args:
request: The agent creation request.
Returns:
The created agent details.
Raises:
HTTPException: 409 if agent already exists, 422 if name is invalid.
"""
_validate_agent_name(request.name)
normalized_name = _normalize_agent_name(request.name)
agent_dir = get_paths().agent_dir(normalized_name)
if agent_dir.exists():
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
try:
agent_dir.mkdir(parents=True, exist_ok=True)
# Write config.yaml
config_data: dict = {"name": normalized_name}
if request.description:
config_data["description"] = request.description
if request.model is not None:
config_data["model"] = request.model
if request.tool_groups is not None:
config_data["tool_groups"] = request.tool_groups
config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f:
yaml.dump(config_data, f, default_flow_style=False, allow_unicode=True)
# Write SOUL.md
soul_file = agent_dir / "SOUL.md"
soul_file.write_text(request.soul, encoding="utf-8")
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
agent_cfg = load_agent_config(normalized_name)
return _agent_config_to_response(agent_cfg, include_soul=True)
except HTTPException:
raise
except Exception as e:
# Clean up on failure
if agent_dir.exists():
shutil.rmtree(agent_dir)
logger.error(f"Failed to create agent '{request.name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to create agent: {str(e)}")
@router.put(
"/agents/{name}",
response_model=AgentResponse,
summary="Update Custom Agent",
description="Update an existing custom agent's config and/or SOUL.md.",
)
async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
"""Update an existing custom agent.
Args:
name: The agent name.
request: The update request (all fields optional).
Returns:
The updated agent details.
Raises:
HTTPException: 404 if agent not found.
"""
_validate_agent_name(name)
name = _normalize_agent_name(name)
try:
agent_cfg = load_agent_config(name)
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
agent_dir = get_paths().agent_dir(name)
try:
# Update config if any config fields changed
config_changed = any(v is not None for v in [request.description, request.model, request.tool_groups])
if config_changed:
updated: dict = {
"name": agent_cfg.name,
"description": request.description if request.description is not None else agent_cfg.description,
}
new_model = request.model if request.model is not None else agent_cfg.model
if new_model is not None:
updated["model"] = new_model
new_tool_groups = request.tool_groups if request.tool_groups is not None else agent_cfg.tool_groups
if new_tool_groups is not None:
updated["tool_groups"] = new_tool_groups
config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f:
yaml.dump(updated, f, default_flow_style=False, allow_unicode=True)
# Update SOUL.md if provided
if request.soul is not None:
soul_path = agent_dir / "SOUL.md"
soul_path.write_text(request.soul, encoding="utf-8")
logger.info(f"Updated agent '{name}'")
refreshed_cfg = load_agent_config(name)
return _agent_config_to_response(refreshed_cfg, include_soul=True)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to update agent '{name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to update agent: {str(e)}")
class UserProfileResponse(BaseModel):
"""Response model for the global user profile (USER.md)."""
content: str | None = Field(default=None, description="USER.md content, or null if not yet created")
class UserProfileUpdateRequest(BaseModel):
"""Request body for setting the global user profile."""
content: str = Field(default="", description="USER.md content — describes the user's background and preferences")
@router.get(
"/user-profile",
response_model=UserProfileResponse,
summary="Get User Profile",
description="Read the global USER.md file that is injected into all custom agents.",
)
async def get_user_profile() -> UserProfileResponse:
"""Return the current USER.md content.
Returns:
UserProfileResponse with content=None if USER.md does not exist yet.
"""
try:
user_md_path = get_paths().user_md_file
if not user_md_path.exists():
return UserProfileResponse(content=None)
raw = user_md_path.read_text(encoding="utf-8").strip()
return UserProfileResponse(content=raw or None)
except Exception as e:
logger.error(f"Failed to read user profile: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to read user profile: {str(e)}")
@router.put(
"/user-profile",
response_model=UserProfileResponse,
summary="Update User Profile",
description="Write the global USER.md file that is injected into all custom agents.",
)
async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileResponse:
"""Create or overwrite the global USER.md.
Args:
request: The update request with the new USER.md content.
Returns:
UserProfileResponse with the saved content.
"""
try:
paths = get_paths()
paths.base_dir.mkdir(parents=True, exist_ok=True)
paths.user_md_file.write_text(request.content, encoding="utf-8")
logger.info(f"Updated USER.md at {paths.user_md_file}")
return UserProfileResponse(content=request.content or None)
except Exception as e:
logger.error(f"Failed to update user profile: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to update user profile: {str(e)}")
@router.delete(
"/agents/{name}",
status_code=204,
summary="Delete Custom Agent",
description="Delete a custom agent and all its files (config, SOUL.md, memory).",
)
async def delete_agent(name: str) -> None:
"""Delete a custom agent.
Args:
name: The agent name.
Raises:
HTTPException: 404 if agent not found.
"""
_validate_agent_name(name)
name = _normalize_agent_name(name)
agent_dir = get_paths().agent_dir(name)
if not agent_dir.exists():
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
try:
shutil.rmtree(agent_dir)
logger.info(f"Deleted agent '{name}' from {agent_dir}")
except Exception as e:
logger.error(f"Failed to delete agent '{name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to delete agent: {str(e)}")

View File

@@ -0,0 +1,181 @@
import logging
import mimetypes
import zipfile
from pathlib import Path
from urllib.parse import quote
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import FileResponse, PlainTextResponse, Response
from app.gateway.path_utils import resolve_thread_virtual_path
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["artifacts"])
ACTIVE_CONTENT_MIME_TYPES = {
"text/html",
"application/xhtml+xml",
"image/svg+xml",
}
def _build_content_disposition(disposition_type: str, filename: str) -> str:
"""Build an RFC 5987 encoded Content-Disposition header value."""
return f"{disposition_type}; filename*=UTF-8''{quote(filename)}"
def _build_attachment_headers(filename: str, extra_headers: dict[str, str] | None = None) -> dict[str, str]:
headers = {"Content-Disposition": _build_content_disposition("attachment", filename)}
if extra_headers:
headers.update(extra_headers)
return headers
def is_text_file_by_content(path: Path, sample_size: int = 8192) -> bool:
"""Check if file is text by examining content for null bytes."""
try:
with open(path, "rb") as f:
chunk = f.read(sample_size)
# Text files shouldn't contain null bytes
return b"\x00" not in chunk
except Exception:
return False
def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None:
"""Extract a file from a .skill ZIP archive.
Args:
zip_path: Path to the .skill file (ZIP archive).
internal_path: Path to the file inside the archive (e.g., "SKILL.md").
Returns:
The file content as bytes, or None if not found.
"""
if not zipfile.is_zipfile(zip_path):
return None
try:
with zipfile.ZipFile(zip_path, "r") as zip_ref:
# List all files in the archive
namelist = zip_ref.namelist()
# Try direct path first
if internal_path in namelist:
return zip_ref.read(internal_path)
# Try with any top-level directory prefix (e.g., "skill-name/SKILL.md")
for name in namelist:
if name.endswith("/" + internal_path) or name == internal_path:
return zip_ref.read(name)
# Not found
return None
except (zipfile.BadZipFile, KeyError):
return None
@router.get(
"/threads/{thread_id}/artifacts/{path:path}",
summary="Get Artifact File",
description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
)
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
"""Get an artifact file by its path.
The endpoint automatically detects file types and returns appropriate content types.
Use the `download` query parameter to force file download for non-active content.
Args:
thread_id: The thread ID.
path: The artifact path with virtual prefix (e.g., mnt/user-data/outputs/file.txt).
request: FastAPI request object (automatically injected).
Returns:
The file content as a FileResponse with appropriate content type:
- Active content (HTML/XHTML/SVG): Served as download attachment
- Text files: Plain text with proper MIME type
- Binary files: Inline display with download option
Raises:
HTTPException:
- 400 if path is invalid or not a file
- 403 if access denied (path traversal detected)
- 404 if file not found
Query Parameters:
download (bool): If true, forces attachment download for file types that are
otherwise returned inline or as plain text. Active HTML/XHTML/SVG content
is always downloaded regardless of this flag.
Example:
- Get text file inline: `/api/threads/abc123/artifacts/mnt/user-data/outputs/notes.txt`
- Download file: `/api/threads/abc123/artifacts/mnt/user-data/outputs/data.csv?download=true`
- Active web content such as `.html`, `.xhtml`, and `.svg` artifacts is always downloaded
"""
# Check if this is a request for a file inside a .skill archive (e.g., xxx.skill/SKILL.md)
if ".skill/" in path:
# Split the path at ".skill/" to get the ZIP file path and internal path
skill_marker = ".skill/"
marker_pos = path.find(skill_marker)
skill_file_path = path[: marker_pos + len(".skill")] # e.g., "mnt/user-data/outputs/my-skill.skill"
internal_path = path[marker_pos + len(skill_marker) :] # e.g., "SKILL.md"
actual_skill_path = resolve_thread_virtual_path(thread_id, skill_file_path)
if not actual_skill_path.exists():
raise HTTPException(status_code=404, detail=f"Skill file not found: {skill_file_path}")
if not actual_skill_path.is_file():
raise HTTPException(status_code=400, detail=f"Path is not a file: {skill_file_path}")
# Extract the file from the .skill archive
content = _extract_file_from_skill_archive(actual_skill_path, internal_path)
if content is None:
raise HTTPException(status_code=404, detail=f"File '{internal_path}' not found in skill archive")
# Determine MIME type based on the internal file
mime_type, _ = mimetypes.guess_type(internal_path)
# Add cache headers to avoid repeated ZIP extraction (cache for 5 minutes)
cache_headers = {"Cache-Control": "private, max-age=300"}
download_name = Path(internal_path).name or actual_skill_path.stem
if download or mime_type in ACTIVE_CONTENT_MIME_TYPES:
return Response(content=content, media_type=mime_type or "application/octet-stream", headers=_build_attachment_headers(download_name, cache_headers))
if mime_type and mime_type.startswith("text/"):
return PlainTextResponse(content=content.decode("utf-8"), media_type=mime_type, headers=cache_headers)
# Default to plain text for unknown types that look like text
try:
return PlainTextResponse(content=content.decode("utf-8"), media_type="text/plain", headers=cache_headers)
except UnicodeDecodeError:
return Response(content=content, media_type=mime_type or "application/octet-stream", headers=cache_headers)
actual_path = resolve_thread_virtual_path(thread_id, path)
logger.info(f"Resolving artifact path: thread_id={thread_id}, requested_path={path}, actual_path={actual_path}")
if not actual_path.exists():
raise HTTPException(status_code=404, detail=f"Artifact not found: {path}")
if not actual_path.is_file():
raise HTTPException(status_code=400, detail=f"Path is not a file: {path}")
mime_type, _ = mimetypes.guess_type(actual_path)
if download:
return FileResponse(path=actual_path, filename=actual_path.name, media_type=mime_type, headers=_build_attachment_headers(actual_path.name))
# Always force download for active content types to prevent script execution
# in the application origin when users open generated artifacts.
if mime_type in ACTIVE_CONTENT_MIME_TYPES:
return FileResponse(path=actual_path, filename=actual_path.name, media_type=mime_type, headers=_build_attachment_headers(actual_path.name))
if mime_type and mime_type.startswith("text/"):
return PlainTextResponse(content=actual_path.read_text(encoding="utf-8"), media_type=mime_type)
if is_text_file_by_content(actual_path):
return PlainTextResponse(content=actual_path.read_text(encoding="utf-8"), media_type=mime_type)
return Response(content=actual_path.read_bytes(), media_type=mime_type, headers={"Content-Disposition": _build_content_disposition("inline", actual_path.name)})

View File

@@ -0,0 +1,149 @@
"""Assistants compatibility endpoints.
Provides LangGraph Platform-compatible assistants API backed by the
``langgraph.json`` graph registry and ``config.yaml`` agent definitions.
This is a minimal stub that satisfies the ``useStream`` React hook's
initialization requirements (``assistants.search()`` and ``assistants.get()``).
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime
from typing import Any
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/assistants", tags=["assistants-compat"])
class AssistantResponse(BaseModel):
assistant_id: str
graph_id: str
name: str
config: dict[str, Any] = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
description: str | None = None
created_at: str = ""
updated_at: str = ""
version: int = 1
class AssistantSearchRequest(BaseModel):
graph_id: str | None = None
name: str | None = None
metadata: dict[str, Any] | None = None
limit: int = 10
offset: int = 0
def _get_default_assistant() -> AssistantResponse:
"""Return the default lead_agent assistant."""
now = datetime.now(UTC).isoformat()
return AssistantResponse(
assistant_id="lead_agent",
graph_id="lead_agent",
name="lead_agent",
config={},
metadata={"created_by": "system"},
description="DeerFlow lead agent",
created_at=now,
updated_at=now,
version=1,
)
def _list_assistants() -> list[AssistantResponse]:
"""List all available assistants from config."""
assistants = [_get_default_assistant()]
# Also include custom agents from config.yaml agents directory
try:
from deerflow.config.agents_config import list_custom_agents
for agent_cfg in list_custom_agents():
now = datetime.now(UTC).isoformat()
assistants.append(
AssistantResponse(
assistant_id=agent_cfg.name,
graph_id="lead_agent", # All agents use the same graph
name=agent_cfg.name,
config={},
metadata={"created_by": "user"},
description=agent_cfg.description or "",
created_at=now,
updated_at=now,
version=1,
)
)
except Exception:
logger.debug("Could not load custom agents for assistants list")
return assistants
@router.post("/search", response_model=list[AssistantResponse])
async def search_assistants(body: AssistantSearchRequest | None = None) -> list[AssistantResponse]:
"""Search assistants.
Returns all registered assistants (lead_agent + custom agents from config).
"""
assistants = _list_assistants()
if body and body.graph_id:
assistants = [a for a in assistants if a.graph_id == body.graph_id]
if body and body.name:
assistants = [a for a in assistants if body.name.lower() in a.name.lower()]
offset = body.offset if body else 0
limit = body.limit if body else 10
return assistants[offset : offset + limit]
@router.get("/{assistant_id}", response_model=AssistantResponse)
async def get_assistant_compat(assistant_id: str) -> AssistantResponse:
"""Get an assistant by ID."""
for a in _list_assistants():
if a.assistant_id == assistant_id:
return a
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
@router.get("/{assistant_id}/graph")
async def get_assistant_graph(assistant_id: str) -> dict:
"""Get the graph structure for an assistant.
Returns a minimal graph description. Full graph introspection is
not supported in the Gateway — this stub satisfies SDK validation.
"""
found = any(a.assistant_id == assistant_id for a in _list_assistants())
if not found:
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
return {
"graph_id": "lead_agent",
"nodes": [],
"edges": [],
}
@router.get("/{assistant_id}/schemas")
async def get_assistant_schemas(assistant_id: str) -> dict:
"""Get JSON schemas for an assistant's input/output/state.
Returns empty schemas — full introspection not supported in Gateway.
"""
found = any(a.assistant_id == assistant_id for a in _list_assistants())
if not found:
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
return {
"graph_id": "lead_agent",
"input_schema": {},
"output_schema": {},
"state_schema": {},
"config_schema": {},
}

View File

@@ -0,0 +1,52 @@
"""Gateway router for IM channel management."""
from __future__ import annotations
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/channels", tags=["channels"])
class ChannelStatusResponse(BaseModel):
service_running: bool
channels: dict[str, dict]
class ChannelRestartResponse(BaseModel):
success: bool
message: str
@router.get("/", response_model=ChannelStatusResponse)
async def get_channels_status() -> ChannelStatusResponse:
"""Get the status of all IM channels."""
from app.channels.service import get_channel_service
service = get_channel_service()
if service is None:
return ChannelStatusResponse(service_running=False, channels={})
status = service.get_status()
return ChannelStatusResponse(**status)
@router.post("/{name}/restart", response_model=ChannelRestartResponse)
async def restart_channel(name: str) -> ChannelRestartResponse:
"""Restart a specific IM channel."""
from app.channels.service import get_channel_service
service = get_channel_service()
if service is None:
raise HTTPException(status_code=503, detail="Channel service is not running")
success = await service.restart_channel(name)
if success:
logger.info("Channel %s restarted successfully", name)
return ChannelRestartResponse(success=True, message=f"Channel {name} restarted successfully")
else:
logger.warning("Failed to restart channel %s", name)
return ChannelRestartResponse(success=False, message=f"Failed to restart channel {name}")

View File

@@ -0,0 +1,169 @@
import json
import logging
from pathlib import Path
from typing import Literal
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.config.extensions_config import ExtensionsConfig, get_extensions_config, reload_extensions_config
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["mcp"])
class McpOAuthConfigResponse(BaseModel):
"""OAuth configuration for an MCP server."""
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
token_url: str = Field(default="", description="OAuth token endpoint URL")
grant_type: Literal["client_credentials", "refresh_token"] = Field(default="client_credentials", description="OAuth grant type")
client_id: str | None = Field(default=None, description="OAuth client ID")
client_secret: str | None = Field(default=None, description="OAuth client secret")
refresh_token: str | None = Field(default=None, description="OAuth refresh token")
scope: str | None = Field(default=None, description="OAuth scope")
audience: str | None = Field(default=None, description="OAuth audience")
token_field: str = Field(default="access_token", description="Token response field containing access token")
token_type_field: str = Field(default="token_type", description="Token response field containing token type")
expires_in_field: str = Field(default="expires_in", description="Token response field containing expires-in seconds")
default_token_type: str = Field(default="Bearer", description="Default token type when response omits token_type")
refresh_skew_seconds: int = Field(default=60, description="Refresh this many seconds before expiry")
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
class McpServerConfigResponse(BaseModel):
"""Response model for MCP server configuration."""
enabled: bool = Field(default=True, description="Whether this MCP server is enabled")
type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'")
command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)")
args: list[str] = Field(default_factory=list, description="Arguments to pass to the command (for stdio type)")
env: dict[str, str] = Field(default_factory=dict, description="Environment variables for the MCP server")
url: str | None = Field(default=None, description="URL of the MCP server (for sse or http type)")
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
oauth: McpOAuthConfigResponse | None = Field(default=None, description="OAuth configuration for MCP HTTP/SSE servers")
description: str = Field(default="", description="Human-readable description of what this MCP server provides")
class McpConfigResponse(BaseModel):
"""Response model for MCP configuration."""
mcp_servers: dict[str, McpServerConfigResponse] = Field(
default_factory=dict,
description="Map of MCP server name to configuration",
)
class McpConfigUpdateRequest(BaseModel):
"""Request model for updating MCP configuration."""
mcp_servers: dict[str, McpServerConfigResponse] = Field(
...,
description="Map of MCP server name to configuration",
)
@router.get(
"/mcp/config",
response_model=McpConfigResponse,
summary="Get MCP Configuration",
description="Retrieve the current Model Context Protocol (MCP) server configurations.",
)
async def get_mcp_configuration() -> McpConfigResponse:
"""Get the current MCP configuration.
Returns:
The current MCP configuration with all servers.
Example:
```json
{
"mcp_servers": {
"github": {
"enabled": true,
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {"GITHUB_TOKEN": "ghp_xxx"},
"description": "GitHub MCP server for repository operations"
}
}
}
```
"""
config = get_extensions_config()
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in config.mcp_servers.items()})
@router.put(
"/mcp/config",
response_model=McpConfigResponse,
summary="Update MCP Configuration",
description="Update Model Context Protocol (MCP) server configurations and save to file.",
)
async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfigResponse:
"""Update the MCP configuration.
This will:
1. Save the new configuration to the mcp_config.json file
2. Reload the configuration cache
3. Reset MCP tools cache to trigger reinitialization
Args:
request: The new MCP configuration to save.
Returns:
The updated MCP configuration.
Raises:
HTTPException: 500 if the configuration file cannot be written.
Example Request:
```json
{
"mcp_servers": {
"github": {
"enabled": true,
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {"GITHUB_TOKEN": "$GITHUB_TOKEN"},
"description": "GitHub MCP server for repository operations"
}
}
}
```
"""
try:
# Get the current config path (or determine where to save it)
config_path = ExtensionsConfig.resolve_config_path()
# If no config file exists, create one in the parent directory (project root)
if config_path is None:
config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
# Load current config to preserve skills configuration
current_config = get_extensions_config()
# Convert request to dict format for JSON serialization
config_data = {
"mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
}
# Write the configuration to file
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
logger.info(f"MCP configuration updated and saved to: {config_path}")
# NOTE: No need to reload/reset cache here - LangGraph Server (separate process)
# will detect config file changes via mtime and reinitialize MCP tools automatically
# Reload the configuration and update the global cache
reloaded_config = reload_extensions_config()
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_config.mcp_servers.items()})
except Exception as e:
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to update MCP configuration: {str(e)}")

View File

@@ -0,0 +1,353 @@
"""Memory API router for retrieving and managing global memory data."""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.agents.memory.updater import (
clear_memory_data,
create_memory_fact,
delete_memory_fact,
get_memory_data,
import_memory_data,
reload_memory_data,
update_memory_fact,
)
from deerflow.config.memory_config import get_memory_config
router = APIRouter(prefix="/api", tags=["memory"])
class ContextSection(BaseModel):
"""Model for context sections (user and history)."""
summary: str = Field(default="", description="Summary content")
updatedAt: str = Field(default="", description="Last update timestamp")
class UserContext(BaseModel):
"""Model for user context."""
workContext: ContextSection = Field(default_factory=ContextSection)
personalContext: ContextSection = Field(default_factory=ContextSection)
topOfMind: ContextSection = Field(default_factory=ContextSection)
class HistoryContext(BaseModel):
"""Model for history context."""
recentMonths: ContextSection = Field(default_factory=ContextSection)
earlierContext: ContextSection = Field(default_factory=ContextSection)
longTermBackground: ContextSection = Field(default_factory=ContextSection)
class Fact(BaseModel):
"""Model for a memory fact."""
id: str = Field(..., description="Unique identifier for the fact")
content: str = Field(..., description="Fact content")
category: str = Field(default="context", description="Fact category")
confidence: float = Field(default=0.5, description="Confidence score (0-1)")
createdAt: str = Field(default="", description="Creation timestamp")
source: str = Field(default="unknown", description="Source thread ID")
sourceError: str | None = Field(default=None, description="Optional description of the prior mistake or wrong approach")
class MemoryResponse(BaseModel):
"""Response model for memory data."""
version: str = Field(default="1.0", description="Memory schema version")
lastUpdated: str = Field(default="", description="Last update timestamp")
user: UserContext = Field(default_factory=UserContext)
history: HistoryContext = Field(default_factory=HistoryContext)
facts: list[Fact] = Field(default_factory=list)
def _map_memory_fact_value_error(exc: ValueError) -> HTTPException:
"""Convert updater validation errors into stable API responses."""
if exc.args and exc.args[0] == "confidence":
detail = "Invalid confidence value; must be between 0 and 1."
else:
detail = "Memory fact content cannot be empty."
return HTTPException(status_code=400, detail=detail)
class FactCreateRequest(BaseModel):
"""Request model for creating a memory fact."""
content: str = Field(..., min_length=1, description="Fact content")
category: str = Field(default="context", description="Fact category")
confidence: float = Field(default=0.5, ge=0.0, le=1.0, description="Confidence score (0-1)")
class FactPatchRequest(BaseModel):
"""PATCH request model that preserves existing values for omitted fields."""
content: str | None = Field(default=None, min_length=1, description="Fact content")
category: str | None = Field(default=None, description="Fact category")
confidence: float | None = Field(default=None, ge=0.0, le=1.0, description="Confidence score (0-1)")
class MemoryConfigResponse(BaseModel):
"""Response model for memory configuration."""
enabled: bool = Field(..., description="Whether memory is enabled")
storage_path: str = Field(..., description="Path to memory storage file")
debounce_seconds: int = Field(..., description="Debounce time for memory updates")
max_facts: int = Field(..., description="Maximum number of facts to store")
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
class MemoryStatusResponse(BaseModel):
"""Response model for memory status."""
config: MemoryConfigResponse
data: MemoryResponse
@router.get(
"/memory",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Get Memory Data",
description="Retrieve the current global memory data including user context, history, and facts.",
)
async def get_memory() -> MemoryResponse:
"""Get the current global memory data.
Returns:
The current memory data with user context, history, and facts.
Example Response:
```json
{
"version": "1.0",
"lastUpdated": "2024-01-15T10:30:00Z",
"user": {
"workContext": {"summary": "Working on DeerFlow project", "updatedAt": "..."},
"personalContext": {"summary": "Prefers concise responses", "updatedAt": "..."},
"topOfMind": {"summary": "Building memory API", "updatedAt": "..."}
},
"history": {
"recentMonths": {"summary": "Recent development activities", "updatedAt": "..."},
"earlierContext": {"summary": "", "updatedAt": ""},
"longTermBackground": {"summary": "", "updatedAt": ""}
},
"facts": [
{
"id": "fact_abc123",
"content": "User prefers TypeScript over JavaScript",
"category": "preference",
"confidence": 0.9,
"createdAt": "2024-01-15T10:30:00Z",
"source": "thread_xyz"
}
]
}
```
"""
memory_data = get_memory_data()
return MemoryResponse(**memory_data)
@router.post(
"/memory/reload",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Reload Memory Data",
description="Reload memory data from the storage file, refreshing the in-memory cache.",
)
async def reload_memory() -> MemoryResponse:
"""Reload memory data from file.
This forces a reload of the memory data from the storage file,
useful when the file has been modified externally.
Returns:
The reloaded memory data.
"""
memory_data = reload_memory_data()
return MemoryResponse(**memory_data)
@router.delete(
"/memory",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Clear All Memory Data",
description="Delete all saved memory data and reset the memory structure to an empty state.",
)
async def clear_memory() -> MemoryResponse:
"""Clear all persisted memory data."""
try:
memory_data = clear_memory_data()
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
return MemoryResponse(**memory_data)
@router.post(
"/memory/facts",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Create Memory Fact",
description="Create a single saved memory fact manually.",
)
async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryResponse:
"""Create a single fact manually."""
try:
memory_data = create_memory_fact(
content=request.content,
category=request.category,
confidence=request.confidence,
)
except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
return MemoryResponse(**memory_data)
@router.delete(
"/memory/facts/{fact_id}",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Delete Memory Fact",
description="Delete a single saved memory fact by its fact id.",
)
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
"""Delete a single fact from memory by fact id."""
try:
memory_data = delete_memory_fact(fact_id)
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
return MemoryResponse(**memory_data)
@router.patch(
"/memory/facts/{fact_id}",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Patch Memory Fact",
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
)
async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -> MemoryResponse:
"""Partially update a single fact manually."""
try:
memory_data = update_memory_fact(
fact_id=fact_id,
content=request.content,
category=request.category,
confidence=request.confidence,
)
except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
return MemoryResponse(**memory_data)
@router.get(
"/memory/export",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Export Memory Data",
description="Export the current global memory data as JSON for backup or transfer.",
)
async def export_memory() -> MemoryResponse:
"""Export the current memory data."""
memory_data = get_memory_data()
return MemoryResponse(**memory_data)
@router.post(
"/memory/import",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Import Memory Data",
description="Import and overwrite the current global memory data from a JSON payload.",
)
async def import_memory(request: MemoryResponse) -> MemoryResponse:
"""Import and persist memory data."""
try:
memory_data = import_memory_data(request.model_dump())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
return MemoryResponse(**memory_data)
@router.get(
"/memory/config",
response_model=MemoryConfigResponse,
summary="Get Memory Configuration",
description="Retrieve the current memory system configuration.",
)
async def get_memory_config_endpoint() -> MemoryConfigResponse:
"""Get the memory system configuration.
Returns:
The current memory configuration settings.
Example Response:
```json
{
"enabled": true,
"storage_path": ".deer-flow/memory.json",
"debounce_seconds": 30,
"max_facts": 100,
"fact_confidence_threshold": 0.7,
"injection_enabled": true,
"max_injection_tokens": 2000
}
```
"""
config = get_memory_config()
return MemoryConfigResponse(
enabled=config.enabled,
storage_path=config.storage_path,
debounce_seconds=config.debounce_seconds,
max_facts=config.max_facts,
fact_confidence_threshold=config.fact_confidence_threshold,
injection_enabled=config.injection_enabled,
max_injection_tokens=config.max_injection_tokens,
)
@router.get(
"/memory/status",
response_model=MemoryStatusResponse,
response_model_exclude_none=True,
summary="Get Memory Status",
description="Retrieve both memory configuration and current data in a single request.",
)
async def get_memory_status() -> MemoryStatusResponse:
"""Get the memory system status including configuration and data.
Returns:
Combined memory configuration and current data.
"""
config = get_memory_config()
memory_data = get_memory_data()
return MemoryStatusResponse(
config=MemoryConfigResponse(
enabled=config.enabled,
storage_path=config.storage_path,
debounce_seconds=config.debounce_seconds,
max_facts=config.max_facts,
fact_confidence_threshold=config.fact_confidence_threshold,
injection_enabled=config.injection_enabled,
max_injection_tokens=config.max_injection_tokens,
),
data=MemoryResponse(**memory_data),
)

View File

@@ -0,0 +1,116 @@
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.config import get_app_config
router = APIRouter(prefix="/api", tags=["models"])
class ModelResponse(BaseModel):
"""Response model for model information."""
name: str = Field(..., description="Unique identifier for the model")
model: str = Field(..., description="Actual provider model identifier")
display_name: str | None = Field(None, description="Human-readable name")
description: str | None = Field(None, description="Model description")
supports_thinking: bool = Field(default=False, description="Whether model supports thinking mode")
supports_reasoning_effort: bool = Field(default=False, description="Whether model supports reasoning effort")
class ModelsListResponse(BaseModel):
"""Response model for listing all models."""
models: list[ModelResponse]
@router.get(
"/models",
response_model=ModelsListResponse,
summary="List All Models",
description="Retrieve a list of all available AI models configured in the system.",
)
async def list_models() -> ModelsListResponse:
"""List all available models from configuration.
Returns model information suitable for frontend display,
excluding sensitive fields like API keys and internal configuration.
Returns:
A list of all configured models with their metadata.
Example Response:
```json
{
"models": [
{
"name": "gpt-4",
"display_name": "GPT-4",
"description": "OpenAI GPT-4 model",
"supports_thinking": false
},
{
"name": "claude-3-opus",
"display_name": "Claude 3 Opus",
"description": "Anthropic Claude 3 Opus model",
"supports_thinking": true
}
]
}
```
"""
config = get_app_config()
models = [
ModelResponse(
name=model.name,
model=model.model,
display_name=model.display_name,
description=model.description,
supports_thinking=model.supports_thinking,
supports_reasoning_effort=model.supports_reasoning_effort,
)
for model in config.models
]
return ModelsListResponse(models=models)
@router.get(
"/models/{model_name}",
response_model=ModelResponse,
summary="Get Model Details",
description="Retrieve detailed information about a specific AI model by its name.",
)
async def get_model(model_name: str) -> ModelResponse:
"""Get a specific model by name.
Args:
model_name: The unique name of the model to retrieve.
Returns:
Model information if found.
Raises:
HTTPException: 404 if model not found.
Example Response:
```json
{
"name": "gpt-4",
"display_name": "GPT-4",
"description": "OpenAI GPT-4 model",
"supports_thinking": false
}
```
"""
config = get_app_config()
model = config.get_model_config(model_name)
if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
return ModelResponse(
name=model.name,
model=model.model,
display_name=model.display_name,
description=model.description,
supports_thinking=model.supports_thinking,
supports_reasoning_effort=model.supports_reasoning_effort,
)

View File

@@ -0,0 +1,87 @@
"""Stateless runs endpoints -- stream and wait without a pre-existing thread.
These endpoints auto-create a temporary thread when no ``thread_id`` is
supplied in the request body. When a ``thread_id`` **is** provided, it
is reused so that conversation history is preserved across calls.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.routers.thread_runs import RunCreateRequest
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import serialize_channel_values
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/runs", tags=["runs"])
def _resolve_thread_id(body: RunCreateRequest) -> str:
"""Return the thread_id from the request body, or generate a new one."""
thread_id = (body.config or {}).get("configurable", {}).get("thread_id")
if thread_id:
return str(thread_id)
return str(uuid.uuid4())
@router.post("/stream")
async def stateless_stream(body: RunCreateRequest, request: Request) -> StreamingResponse:
"""Create a run and stream events via SSE.
If ``config.configurable.thread_id`` is provided, the run is created
on the given thread so that conversation history is preserved.
Otherwise a new temporary thread is created.
"""
thread_id = _resolve_thread_id(body)
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = await start_run(body, thread_id, request)
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
},
)
@router.post("/wait", response_model=dict)
async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until completion.
If ``config.configurable.thread_id`` is provided, the run is created
on the given thread so that conversation history is preserved.
Otherwise a new temporary thread is created.
"""
thread_id = _resolve_thread_id(body)
record = await start_run(body, thread_id, request)
if record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}

View File

@@ -0,0 +1,356 @@
import json
import logging
import shutil
from pathlib import Path
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from app.gateway.path_utils import resolve_thread_virtual_path
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
from deerflow.skills.manager import (
append_history,
atomic_write,
custom_skill_exists,
ensure_custom_skill_is_editable,
get_custom_skill_dir,
get_custom_skill_file,
get_skill_history_file,
read_custom_skill_content,
read_history,
validate_skill_markdown_content,
)
from deerflow.skills.security_scanner import scan_skill_content
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["skills"])
class SkillResponse(BaseModel):
"""Response model for skill information."""
name: str = Field(..., description="Name of the skill")
description: str = Field(..., description="Description of what the skill does")
license: str | None = Field(None, description="License information")
category: str = Field(..., description="Category of the skill (public or custom)")
enabled: bool = Field(default=True, description="Whether this skill is enabled")
class SkillsListResponse(BaseModel):
"""Response model for listing all skills."""
skills: list[SkillResponse]
class SkillUpdateRequest(BaseModel):
"""Request model for updating a skill."""
enabled: bool = Field(..., description="Whether to enable or disable the skill")
class SkillInstallRequest(BaseModel):
"""Request model for installing a skill from a .skill file."""
thread_id: str = Field(..., description="The thread ID where the .skill file is located")
path: str = Field(..., description="Virtual path to the .skill file (e.g., mnt/user-data/outputs/my-skill.skill)")
class SkillInstallResponse(BaseModel):
"""Response model for skill installation."""
success: bool = Field(..., description="Whether the installation was successful")
skill_name: str = Field(..., description="Name of the installed skill")
message: str = Field(..., description="Installation result message")
class CustomSkillContentResponse(SkillResponse):
content: str = Field(..., description="Raw SKILL.md content")
class CustomSkillUpdateRequest(BaseModel):
content: str = Field(..., description="Replacement SKILL.md content")
class CustomSkillHistoryResponse(BaseModel):
history: list[dict]
class SkillRollbackRequest(BaseModel):
history_index: int = Field(default=-1, description="History entry index to restore from, defaulting to the latest change.")
def _skill_to_response(skill: Skill) -> SkillResponse:
"""Convert a Skill object to a SkillResponse."""
return SkillResponse(
name=skill.name,
description=skill.description,
license=skill.license,
category=skill.category,
enabled=skill.enabled,
)
@router.get(
"/skills",
response_model=SkillsListResponse,
summary="List All Skills",
description="Retrieve a list of all available skills from both public and custom directories.",
)
async def list_skills() -> SkillsListResponse:
try:
skills = load_skills(enabled_only=False)
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error(f"Failed to load skills: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to load skills: {str(e)}")
@router.post(
"/skills/install",
response_model=SkillInstallResponse,
summary="Install Skill",
description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.",
)
async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
try:
skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
result = install_skill_from_archive(skill_file_path)
await refresh_skills_system_prompt_cache_async()
return SkillInstallResponse(**result)
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except SkillAlreadyExistsError as e:
raise HTTPException(status_code=409, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to install skill: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to install skill: {str(e)}")
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
async def list_custom_skills() -> SkillsListResponse:
try:
skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error("Failed to list custom skills: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to list custom skills: {str(e)}")
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
try:
skills = load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
if skill is None:
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name))
except HTTPException:
raise
except Exception as e:
logger.error("Failed to get custom skill %s: %s", skill_name, e, exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to get custom skill: {str(e)}")
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse:
try:
ensure_custom_skill_is_editable(skill_name)
validate_skill_markdown_content(skill_name, request.content)
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md")
if scan.decision == "block":
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
skill_file = get_custom_skill_dir(skill_name) / "SKILL.md"
prev_content = skill_file.read_text(encoding="utf-8")
atomic_write(skill_file, request.content)
append_history(
skill_name,
{
"action": "human_edit",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"prev_content": prev_content,
"new_content": request.content,
"scanner": {"decision": scan.decision, "reason": scan.reason},
},
)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
except HTTPException:
raise
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error("Failed to update custom skill %s: %s", skill_name, e, exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to update custom skill: {str(e)}")
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
try:
ensure_custom_skill_is_editable(skill_name)
skill_dir = get_custom_skill_dir(skill_name)
prev_content = read_custom_skill_content(skill_name)
append_history(
skill_name,
{
"action": "human_delete",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"prev_content": prev_content,
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
)
shutil.rmtree(skill_dir)
await refresh_skills_system_prompt_cache_async()
return {"success": True}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error("Failed to delete custom skill %s: %s", skill_name, e, exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to delete custom skill: {str(e)}")
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillHistoryResponse(history=read_history(skill_name))
except HTTPException:
raise
except Exception as e:
logger.error("Failed to read history for %s: %s", skill_name, e, exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to read history: {str(e)}")
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
history = read_history(skill_name)
if not history:
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
record = history[request.history_index]
target_content = record.get("prev_content")
if target_content is None:
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
validate_skill_markdown_content(skill_name, target_content)
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md")
skill_file = get_custom_skill_file(skill_name)
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
history_entry = {
"action": "rollback",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"prev_content": current_content,
"new_content": target_content,
"rollback_from_ts": record.get("ts"),
"scanner": {"decision": scan.decision, "reason": scan.reason},
}
if scan.decision == "block":
append_history(skill_name, history_entry)
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
atomic_write(skill_file, target_content)
append_history(skill_name, history_entry)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
except HTTPException:
raise
except IndexError:
raise HTTPException(status_code=400, detail="history_index is out of range")
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error("Failed to roll back custom skill %s: %s", skill_name, e, exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to roll back custom skill: {str(e)}")
@router.get(
"/skills/{skill_name}",
response_model=SkillResponse,
summary="Get Skill Details",
description="Retrieve detailed information about a specific skill by its name.",
)
async def get_skill(skill_name: str) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
return _skill_to_response(skill)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get skill {skill_name}: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to get skill: {str(e)}")
@router.put(
"/skills/{skill_name}",
response_model=SkillResponse,
summary="Update Skill",
description="Update a skill's enabled status by modifying the extensions_config.json file.",
)
async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
config_path = ExtensionsConfig.resolve_config_path()
if config_path is None:
config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
extensions_config = get_extensions_config()
extensions_config.skills[skill_name] = SkillStateConfig(enabled=request.enabled)
config_data = {
"mcpServers": {name: server.model_dump() for name, server in extensions_config.mcp_servers.items()},
"skills": {name: {"enabled": skill_config.enabled} for name, skill_config in extensions_config.skills.items()},
}
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
logger.info(f"Skills configuration updated and saved to: {config_path}")
reload_extensions_config()
await refresh_skills_system_prompt_cache_async()
skills = load_skills(enabled_only=False)
updated_skill = next((s for s in skills if s.name == skill_name), None)
if updated_skill is None:
raise HTTPException(status_code=500, detail=f"Failed to reload skill '{skill_name}' after update")
logger.info(f"Skill '{skill_name}' enabled status updated to {request.enabled}")
return _skill_to_response(updated_skill)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to update skill {skill_name}: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to update skill: {str(e)}")

View File

@@ -0,0 +1,132 @@
import json
import logging
from fastapi import APIRouter
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["suggestions"])
class SuggestionMessage(BaseModel):
role: str = Field(..., description="Message role: user|assistant")
content: str = Field(..., description="Message content as plain text")
class SuggestionsRequest(BaseModel):
messages: list[SuggestionMessage] = Field(..., description="Recent conversation messages")
n: int = Field(default=3, ge=1, le=5, description="Number of suggestions to generate")
model_name: str | None = Field(default=None, description="Optional model override")
class SuggestionsResponse(BaseModel):
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
def _strip_markdown_code_fence(text: str) -> str:
stripped = text.strip()
if not stripped.startswith("```"):
return stripped
lines = stripped.splitlines()
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
return "\n".join(lines[1:-1]).strip()
return stripped
def _parse_json_string_list(text: str) -> list[str] | None:
candidate = _strip_markdown_code_fence(text)
start = candidate.find("[")
end = candidate.rfind("]")
if start == -1 or end == -1 or end <= start:
return None
candidate = candidate[start : end + 1]
try:
data = json.loads(candidate)
except Exception:
return None
if not isinstance(data, list):
return None
out: list[str] = []
for item in data:
if not isinstance(item, str):
continue
s = item.strip()
if not s:
continue
out.append(s)
return out
def _extract_response_text(content: object) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for block in content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, dict) and block.get("type") in {"text", "output_text"}:
text = block.get("text")
if isinstance(text, str):
parts.append(text)
return "\n".join(parts) if parts else ""
if content is None:
return ""
return str(content)
def _format_conversation(messages: list[SuggestionMessage]) -> str:
parts: list[str] = []
for m in messages:
role = m.role.strip().lower()
if role in ("user", "human"):
parts.append(f"User: {m.content.strip()}")
elif role in ("assistant", "ai"):
parts.append(f"Assistant: {m.content.strip()}")
else:
parts.append(f"{m.role}: {m.content.strip()}")
return "\n".join(parts).strip()
@router.post(
"/threads/{thread_id}/suggestions",
response_model=SuggestionsResponse,
summary="Generate Follow-up Questions",
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
if not request.messages:
return SuggestionsResponse(suggestions=[])
n = request.n
conversation = _format_conversation(request.messages)
if not conversation:
return SuggestionsResponse(suggestions=[])
system_instruction = (
"You are generating follow-up questions to help the user continue the conversation.\n"
f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
"Requirements:\n"
"- Questions must be relevant to the preceding conversation.\n"
"- Questions must be written in the same language as the user.\n"
"- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
"- Do NOT include numbering, markdown, or any extra text.\n"
"- Output MUST be a JSON array of strings only.\n"
)
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try:
model = create_chat_model(name=request.model_name, thinking_enabled=False)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or []
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
cleaned = cleaned[:n]
return SuggestionsResponse(suggestions=cleaned)
except Exception as exc:
logger.exception("Failed to generate suggestions: thread_id=%s err=%s", thread_id, exc)
return SuggestionsResponse(suggestions=[])

View File

@@ -0,0 +1,267 @@
"""Runs endpoints — create, stream, wait, cancel.
Implements the LangGraph Platform runs API on top of
:class:`deerflow.agents.runs.RunManager` and
:class:`deerflow.agents.stream_bridge.StreamBridge`.
SSE format is aligned with the LangGraph Platform protocol so that
the ``useStream`` React hook from ``@langchain/langgraph-sdk/react``
works without modification.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Literal
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["runs"])
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class RunCreateRequest(BaseModel):
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
input: dict[str, Any] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
webhook: str | None = Field(default=None, description="Completion callback URL")
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
after_seconds: float | None = Field(default=None, description="Delayed execution")
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
class RunResponse(BaseModel):
run_id: str
thread_id: str
assistant_id: str | None = None
status: str
metadata: dict[str, Any] = Field(default_factory=dict)
kwargs: dict[str, Any] = Field(default_factory=dict)
multitask_strategy: str = "reject"
created_at: str = ""
updated_at: str = ""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _record_to_response(record: RunRecord) -> RunResponse:
return RunResponse(
run_id=record.run_id,
thread_id=record.thread_id,
assistant_id=record.assistant_id,
status=record.status.value,
metadata=record.metadata,
kwargs=record.kwargs,
multitask_strategy=record.multitask_strategy,
created_at=record.created_at,
updated_at=record.updated_at,
)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/{thread_id}/runs", response_model=RunResponse)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
"""Create a background run (returns immediately)."""
record = await start_run(body, thread_id, request)
return _record_to_response(record)
@router.post("/{thread_id}/runs/stream")
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
"""Create a run and stream events via SSE.
The response includes a ``Content-Location`` header with the run's
resource URL, matching the LangGraph Platform protocol. The
``useStream`` React hook uses this to extract run metadata.
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = await start_run(body, thread_id, request)
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
# LangGraph Platform includes run metadata in this header.
# The SDK uses a greedy regex to extract the run id from this path,
# so it must point at the canonical run resource without extra suffixes.
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
},
)
@router.post("/{thread_id}/runs/wait", response_model=dict)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until it completes, returning the final state."""
record = await start_run(body, thread_id, request)
if record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
"""List all runs for a thread."""
run_mgr = get_run_manager(request)
records = await run_mgr.list_by_thread(thread_id)
return [_record_to_response(r) for r in records]
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
"""Get details of a specific run."""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
return _record_to_response(record)
@router.post("/{thread_id}/runs/{run_id}/cancel")
async def cancel_run(
thread_id: str,
run_id: str,
request: Request,
wait: bool = Query(default=False, description="Block until run completes after cancel"),
action: Literal["interrupt", "rollback"] = Query(default="interrupt", description="Cancel action"),
) -> Response:
"""Cancel a running or pending run.
- action=interrupt: Stop execution, keep current checkpoint (can be resumed)
- action=rollback: Stop execution, revert to pre-run checkpoint state
- wait=true: Block until the run fully stops, return 204
- wait=false: Return immediately with 202
"""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
cancelled = await run_mgr.cancel(run_id, action=action)
if not cancelled:
raise HTTPException(
status_code=409,
detail=f"Run {run_id} is not cancellable (status: {record.status.value})",
)
if wait and record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
return Response(status_code=204)
return Response(status_code=202)
@router.get("/{thread_id}/runs/{run_id}/join")
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
"""Join an existing run's SSE stream."""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
async def stream_existing_run(
thread_id: str,
run_id: str,
request: Request,
action: Literal["interrupt", "rollback"] | None = Query(default=None, description="Cancel action"),
wait: int = Query(default=0, description="Block until cancelled (1) or return immediately (0)"),
):
"""Join an existing run's SSE stream (GET), or cancel-then-stream (POST).
The LangGraph SDK's ``joinStream`` and ``useStream`` stop button both use
``POST`` to this endpoint. When ``action=interrupt`` or ``action=rollback``
is present the run is cancelled first; the response then streams any
remaining buffered events so the client observes a clean shutdown.
"""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
# Cancel if an action was requested (stop-button / interrupt flow)
if action is not None:
cancelled = await run_mgr.cancel(run_id, action=action)
if cancelled and wait and record.task is not None:
try:
await record.task
except (asyncio.CancelledError, Exception):
pass
return Response(status_code=204)
bridge = get_stream_bridge(request)
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)

View File

@@ -0,0 +1,682 @@
"""Thread CRUD, state, and history endpoints.
Combines the existing thread-local filesystem cleanup with LangGraph
Platform-compatible thread management backed by the checkpointer.
Channel values returned in state responses are serialized through
:func:`deerflow.runtime.serialization.serialize_channel_values` to
ensure LangChain message objects are converted to JSON-safe dicts
matching the LangGraph Platform wire format expected by the
``useStream`` React hook.
"""
from __future__ import annotations
import logging
import time
import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from app.gateway.deps import get_checkpointer, get_store
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
# ---------------------------------------------------------------------------
# Store namespace
# ---------------------------------------------------------------------------
THREADS_NS: tuple[str, ...] = ("threads",)
"""Namespace used by the Store for thread metadata records."""
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["threads"])
# ---------------------------------------------------------------------------
# Response / request models
# ---------------------------------------------------------------------------
class ThreadDeleteResponse(BaseModel):
"""Response model for thread cleanup."""
success: bool
message: str
class ThreadResponse(BaseModel):
"""Response model for a single thread."""
thread_id: str = Field(description="Unique thread identifier")
status: str = Field(default="idle", description="Thread status: idle, busy, interrupted, error")
created_at: str = Field(default="", description="ISO timestamp")
updated_at: str = Field(default="", description="ISO timestamp")
metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
values: dict[str, Any] = Field(default_factory=dict, description="Current state channel values")
interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
class ThreadCreateRequest(BaseModel):
"""Request body for creating a thread."""
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
class ThreadSearchRequest(BaseModel):
"""Request body for searching threads."""
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
offset: int = Field(default=0, ge=0, description="Pagination offset")
status: str | None = Field(default=None, description="Filter by thread status")
class ThreadStateResponse(BaseModel):
"""Response model for thread state."""
values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
next: list[str] = Field(default_factory=list, description="Next tasks to execute")
metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
created_at: str | None = Field(default=None, description="Checkpoint timestamp")
tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
class ThreadPatchRequest(BaseModel):
"""Request body for patching thread metadata."""
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge")
class ThreadStateUpdateRequest(BaseModel):
"""Request body for updating thread state (human-in-the-loop resume)."""
values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
as_node: str | None = Field(default=None, description="Node identity for the update")
class HistoryEntry(BaseModel):
"""Single checkpoint history entry."""
checkpoint_id: str
parent_checkpoint_id: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
values: dict[str, Any] = Field(default_factory=dict)
created_at: str | None = None
next: list[str] = Field(default_factory=list)
class ThreadHistoryRequest(BaseModel):
"""Request body for checkpoint history."""
limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
before: str | None = Field(default=None, description="Cursor for pagination")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread."""
path_manager = paths or get_paths()
try:
path_manager.delete_thread_dir(thread_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
except FileNotFoundError:
# Not critical — thread data may not exist on disk
logger.debug("No local thread data to delete for %s", thread_id)
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
except Exception as exc:
logger.exception("Failed to delete thread data for %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
logger.info("Deleted local thread data for %s", thread_id)
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
async def _store_get(store, thread_id: str) -> dict | None:
"""Fetch a thread record from the Store; returns ``None`` if absent."""
item = await store.aget(THREADS_NS, thread_id)
return item.value if item is not None else None
async def _store_put(store, record: dict) -> None:
"""Write a thread record to the Store."""
await store.aput(THREADS_NS, record["thread_id"], record)
async def _store_upsert(store, thread_id: str, *, metadata: dict | None = None, values: dict | None = None) -> None:
"""Create or refresh a thread record in the Store.
On creation the record is written with ``status="idle"``. On update only
``updated_at`` (and optionally ``metadata`` / ``values``) are changed so
that existing fields are preserved.
``values`` carries the agent-state snapshot exposed to the frontend
(currently just ``{"title": "..."}``).
"""
now = time.time()
existing = await _store_get(store, thread_id)
if existing is None:
await _store_put(
store,
{
"thread_id": thread_id,
"status": "idle",
"created_at": now,
"updated_at": now,
"metadata": metadata or {},
"values": values or {},
},
)
else:
val = dict(existing)
val["updated_at"] = now
if metadata:
val.setdefault("metadata", {}).update(metadata)
if values:
val.setdefault("values", {}).update(values)
await _store_put(store, val)
def _derive_thread_status(checkpoint_tuple) -> str:
"""Derive thread status from checkpoint metadata."""
if checkpoint_tuple is None:
return "idle"
pending_writes = getattr(checkpoint_tuple, "pending_writes", None) or []
# Check for error in pending writes
for pw in pending_writes:
if len(pw) >= 2 and pw[1] == "__error__":
return "error"
# Check for pending next tasks (indicates interrupt)
tasks = getattr(checkpoint_tuple, "tasks", None)
if tasks:
return "interrupted"
return "idle"
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread.
Cleans DeerFlow-managed thread directories, removes checkpoint data,
and removes the thread record from the Store.
"""
# Clean local filesystem
response = _delete_thread_data(thread_id)
# Remove from Store (best-effort)
store = get_store(request)
if store is not None:
try:
await store.adelete(THREADS_NS, thread_id)
except Exception:
logger.debug("Could not delete store record for thread %s (not critical)", thread_id)
# Remove checkpoints (best-effort)
checkpointer = getattr(request.app.state, "checkpointer", None)
if checkpointer is not None:
try:
if hasattr(checkpointer, "adelete_thread"):
await checkpointer.adelete_thread(thread_id)
except Exception:
logger.debug("Could not delete checkpoints for thread %s (not critical)", thread_id)
return response
@router.post("", response_model=ThreadResponse)
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
"""Create a new thread.
The thread record is written to the Store (for fast listing) and an
empty checkpoint is written to the checkpointer (for state reads).
Idempotent: returns the existing record when ``thread_id`` already exists.
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
thread_id = body.thread_id or str(uuid.uuid4())
now = time.time()
# Idempotency: return existing record from Store when already present
if store is not None:
existing_record = await _store_get(store, thread_id)
if existing_record is not None:
return ThreadResponse(
thread_id=thread_id,
status=existing_record.get("status", "idle"),
created_at=str(existing_record.get("created_at", "")),
updated_at=str(existing_record.get("updated_at", "")),
metadata=existing_record.get("metadata", {}),
)
# Write thread record to Store
if store is not None:
try:
await _store_put(
store,
{
"thread_id": thread_id,
"status": "idle",
"created_at": now,
"updated_at": now,
"metadata": body.metadata,
},
)
except Exception:
logger.exception("Failed to write thread %s to store", thread_id)
raise HTTPException(status_code=500, detail="Failed to create thread")
# Write an empty checkpoint so state endpoints work immediately
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
from langgraph.checkpoint.base import empty_checkpoint
ckpt_metadata = {
"step": -1,
"source": "input",
"writes": None,
"parents": {},
**body.metadata,
"created_at": now,
}
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
except Exception:
logger.exception("Failed to create checkpoint for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to create thread")
logger.info("Thread created: %s", thread_id)
return ThreadResponse(
thread_id=thread_id,
status="idle",
created_at=str(now),
updated_at=str(now),
metadata=body.metadata,
)
@router.post("/search", response_model=list[ThreadResponse])
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
"""Search and list threads.
Two-phase approach:
**Phase 1 — Store (fast path, O(threads))**: returns threads that were
created or run through this Gateway. Store records are tiny metadata
dicts so fetching all of them at once is cheap.
**Phase 2 — Checkpointer supplement (lazy migration)**: threads that
were created directly by LangGraph Server (and therefore absent from the
Store) are discovered here by iterating the shared checkpointer. Any
newly found thread is immediately written to the Store so that the next
search skips Phase 2 for that thread — the Store converges to a full
index over time without a one-shot migration job.
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
# -----------------------------------------------------------------------
# Phase 1: Store
# -----------------------------------------------------------------------
merged: dict[str, ThreadResponse] = {}
if store is not None:
try:
items = await store.asearch(THREADS_NS, limit=10_000)
except Exception:
logger.warning("Store search failed — falling back to checkpointer only", exc_info=True)
items = []
for item in items:
val = item.value
merged[val["thread_id"]] = ThreadResponse(
thread_id=val["thread_id"],
status=val.get("status", "idle"),
created_at=str(val.get("created_at", "")),
updated_at=str(val.get("updated_at", "")),
metadata=val.get("metadata", {}),
values=val.get("values", {}),
)
# -----------------------------------------------------------------------
# Phase 2: Checkpointer supplement
# Discovers threads not yet in the Store (e.g. created by LangGraph
# Server) and lazily migrates them so future searches skip this phase.
# -----------------------------------------------------------------------
try:
async for checkpoint_tuple in checkpointer.alist(None):
cfg = getattr(checkpoint_tuple, "config", {})
thread_id = cfg.get("configurable", {}).get("thread_id")
if not thread_id or thread_id in merged:
continue
# Skip sub-graph checkpoints (checkpoint_ns is non-empty for those)
if cfg.get("configurable", {}).get("checkpoint_ns", ""):
continue
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
# Strip LangGraph internal keys from the user-visible metadata dict
user_meta = {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
# Extract state values (title) from the checkpoint's channel_values
checkpoint_data = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint_data.get("channel_values", {})
ckpt_values = {}
if title := channel_values.get("title"):
ckpt_values["title"] = title
thread_resp = ThreadResponse(
thread_id=thread_id,
status=_derive_thread_status(checkpoint_tuple),
created_at=str(ckpt_meta.get("created_at", "")),
updated_at=str(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
metadata=user_meta,
values=ckpt_values,
)
merged[thread_id] = thread_resp
# Lazy migration — write to Store so the next search finds it there
if store is not None:
try:
await _store_upsert(store, thread_id, metadata=user_meta, values=ckpt_values or None)
except Exception:
logger.debug("Failed to migrate thread %s to store (non-fatal)", thread_id)
except Exception:
logger.exception("Checkpointer scan failed during thread search")
# Don't raise — return whatever was collected from Store + partial scan
# -----------------------------------------------------------------------
# Phase 3: Filter → sort → paginate
# -----------------------------------------------------------------------
results = list(merged.values())
if body.metadata:
results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]
if body.status:
results = [r for r in results if r.status == body.status]
results.sort(key=lambda r: r.updated_at, reverse=True)
return results[body.offset : body.offset + body.limit]
@router.patch("/{thread_id}", response_model=ThreadResponse)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
"""Merge metadata into a thread record."""
store = get_store(request)
if store is None:
raise HTTPException(status_code=503, detail="Store not available")
record = await _store_get(store, thread_id)
if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
now = time.time()
updated = dict(record)
updated.setdefault("metadata", {}).update(body.metadata)
updated["updated_at"] = now
try:
await _store_put(store, updated)
except Exception:
logger.exception("Failed to patch thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to update thread")
return ThreadResponse(
thread_id=thread_id,
status=updated.get("status", "idle"),
created_at=str(updated.get("created_at", "")),
updated_at=str(now),
metadata=updated.get("metadata", {}),
)
@router.get("/{thread_id}", response_model=ThreadResponse)
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
"""Get thread info.
Reads metadata from the Store and derives the accurate execution
status from the checkpointer. Falls back to the checkpointer alone
for threads that pre-date Store adoption (backward compat).
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
record: dict | None = None
if store is not None:
record = await _store_get(store, thread_id)
# Derive accurate status from the checkpointer
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception:
logger.exception("Failed to get checkpoint for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to get thread")
if record is None and checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# If the thread exists in the checkpointer but not the store (e.g. legacy
# data), synthesize a minimal store record from the checkpoint metadata.
if record is None and checkpoint_tuple is not None:
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
record = {
"thread_id": thread_id,
"status": "idle",
"created_at": ckpt_meta.get("created_at", ""),
"updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
}
if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {}
channel_values = checkpoint.get("channel_values", {})
return ThreadResponse(
thread_id=thread_id,
status=status,
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
values=serialize_channel_values(channel_values),
)
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
"""Get the latest state snapshot for a thread.
Channel values are serialized to ensure LangChain message objects
are converted to JSON-safe dicts.
"""
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception:
logger.exception("Failed to get state for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
checkpoint_id = None
ckpt_config = getattr(checkpoint_tuple, "config", {})
if ckpt_config:
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id")
channel_values = checkpoint.get("channel_values", {})
parent_config = getattr(checkpoint_tuple, "parent_config", None)
parent_checkpoint_id = None
if parent_config:
parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id")
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
return ThreadStateResponse(
values=serialize_channel_values(channel_values),
next=next_tasks,
metadata=metadata,
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_checkpoint_id,
created_at=str(metadata.get("created_at", "")),
tasks=tasks,
)
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
"""Update thread state (e.g. for human-in-the-loop resume or title rename).
Writes a new checkpoint that merges *body.values* into the latest
channel values, then syncs any updated ``title`` field back to the Store
so that ``/threads/search`` reflects the change immediately.
"""
checkpointer = get_checkpointer(request)
store = get_store(request)
# checkpoint_ns must be present in the config for aput — default to ""
# (the root graph namespace). checkpoint_id is optional; omitting it
# fetches the latest checkpoint for the thread.
read_config: dict[str, Any] = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": "",
}
}
if body.checkpoint_id:
read_config["configurable"]["checkpoint_id"] = body.checkpoint_id
try:
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
except Exception:
logger.exception("Failed to get state for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# Work on mutable copies so we don't accidentally mutate cached objects.
checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
if body.values:
channel_values.update(body.values)
checkpoint["channel_values"] = channel_values
metadata["updated_at"] = time.time()
if body.as_node:
metadata["source"] = "update"
metadata["step"] = metadata.get("step", 0) + 1
metadata["writes"] = {body.as_node: body.values}
# aput requires checkpoint_ns in the config — use the same config used for the
# read (which always includes checkpoint_ns=""). Do NOT include checkpoint_id
# so that aput generates a fresh checkpoint ID for the new snapshot.
write_config: dict[str, Any] = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": "",
}
}
try:
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
except Exception:
logger.exception("Failed to update state for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to update thread state")
new_checkpoint_id: str | None = None
if isinstance(new_config, dict):
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
# Sync title changes to the Store so /threads/search reflects them immediately.
if store is not None and body.values and "title" in body.values:
try:
await _store_upsert(store, thread_id, values={"title": body.values["title"]})
except Exception:
logger.debug("Failed to sync title to store for thread %s (non-fatal)", thread_id)
return ThreadStateResponse(
values=serialize_channel_values(channel_values),
next=[],
metadata=metadata,
checkpoint_id=new_checkpoint_id,
created_at=str(metadata.get("created_at", "")),
)
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
"""Get checkpoint history for a thread."""
checkpointer = get_checkpointer(request)
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
if body.before:
config["configurable"]["checkpoint_id"] = body.before
entries: list[HistoryEntry] = []
try:
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
ckpt_config = getattr(checkpoint_tuple, "config", {})
parent_config = getattr(checkpoint_tuple, "parent_config", None)
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
parent_id = None
if parent_config:
parent_id = parent_config.get("configurable", {}).get("checkpoint_id")
channel_values = checkpoint.get("channel_values", {})
# Derive next tasks
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
entries.append(
HistoryEntry(
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_id,
metadata=metadata,
values=serialize_channel_values(channel_values),
created_at=str(metadata.get("created_at", "")),
next=next_tasks,
)
)
except Exception:
logger.exception("Failed to get history for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to get thread history")
return entries

View File

@@ -0,0 +1,168 @@
"""Upload router for handling file uploads."""
import logging
import os
import stat
from fastapi import APIRouter, File, HTTPException, UploadFile
from pydantic import BaseModel
from deerflow.config.paths import get_paths
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.uploads.manager import (
PathTraversalError,
delete_file_safe,
enrich_file_listing,
ensure_uploads_dir,
get_uploads_dir,
list_files_in_dir,
normalize_filename,
upload_artifact_url,
upload_virtual_path,
)
from deerflow.utils.file_conversion import CONVERTIBLE_EXTENSIONS, convert_file_to_markdown
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads/{thread_id}/uploads", tags=["uploads"])
class UploadResponse(BaseModel):
"""Response model for file upload."""
success: bool
files: list[dict[str, str]]
message: str
def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
"""Ensure uploaded files remain writable when mounted into non-local sandboxes.
In AIO sandbox mode, the gateway writes the authoritative host-side file
first, then the sandbox runtime may rewrite the same mounted path. Granting
world-writable access here prevents permission mismatches between the
gateway user and the sandbox runtime user.
"""
file_stat = os.lstat(file_path)
if stat.S_ISLNK(file_stat.st_mode):
logger.warning("Skipping sandbox chmod for symlinked upload path: %s", file_path)
return
writable_mode = stat.S_IMODE(file_stat.st_mode) | stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH
chmod_kwargs = {"follow_symlinks": False} if os.chmod in os.supports_follow_symlinks else {}
os.chmod(file_path, writable_mode, **chmod_kwargs)
@router.post("", response_model=UploadResponse)
async def upload_files(
thread_id: str,
files: list[UploadFile] = File(...),
) -> UploadResponse:
"""Upload multiple files to a thread's uploads directory."""
if not files:
raise HTTPException(status_code=400, detail="No files provided")
try:
uploads_dir = ensure_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
uploaded_files = []
sandbox_provider = get_sandbox_provider()
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
for file in files:
if not file.filename:
continue
try:
safe_filename = normalize_filename(file.filename)
except ValueError:
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
continue
try:
content = await file.read()
file_path = uploads_dir / safe_filename
file_path.write_bytes(content)
virtual_path = upload_virtual_path(safe_filename)
if sandbox_id != "local":
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, content)
file_info = {
"filename": safe_filename,
"size": str(len(content)),
"path": str(sandbox_uploads / safe_filename),
"virtual_path": virtual_path,
"artifact_url": upload_artifact_url(thread_id, safe_filename),
}
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
file_ext = file_path.suffix.lower()
if file_ext in CONVERTIBLE_EXTENSIONS:
md_path = await convert_file_to_markdown(file_path)
if md_path:
md_virtual_path = upload_virtual_path(md_path.name)
if sandbox_id != "local":
_make_file_sandbox_writable(md_path)
sandbox.update_file(md_virtual_path, md_path.read_bytes())
file_info["markdown_file"] = md_path.name
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
file_info["markdown_virtual_path"] = md_virtual_path
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
uploaded_files.append(file_info)
except Exception as e:
logger.error(f"Failed to upload {file.filename}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
return UploadResponse(
success=True,
files=uploaded_files,
message=f"Successfully uploaded {len(uploaded_files)} file(s)",
)
@router.get("/list", response_model=dict)
async def list_uploaded_files(thread_id: str) -> dict:
"""List all files in a thread's uploads directory."""
try:
uploads_dir = get_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
result = list_files_in_dir(uploads_dir)
enrich_file_listing(result, thread_id)
# Gateway additionally includes the sandbox-relative path.
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
for f in result["files"]:
f["path"] = str(sandbox_uploads / f["filename"])
return result
@router.delete("/{filename}")
async def delete_uploaded_file(thread_id: str, filename: str) -> dict:
"""Delete a file from a thread's uploads directory."""
try:
uploads_dir = get_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
try:
return delete_file_safe(uploads_dir, filename, convertible_extensions=CONVERTIBLE_EXTENSIONS)
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"File not found: {filename}")
except PathTraversalError:
raise HTTPException(status_code=400, detail="Invalid path")
except Exception as e:
logger.error(f"Failed to delete {filename}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to delete {filename}: {str(e)}")

View File

@@ -0,0 +1,367 @@
"""Run lifecycle service layer.
Centralizes the business logic for creating runs, formatting SSE
frames, and consuming stream bridge events. Router modules
(``thread_runs``, ``runs``) are thin HTTP handlers that delegate here.
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import time
from typing import Any
from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage
from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
from deerflow.runtime import (
END_SENTINEL,
HEARTBEAT_SENTINEL,
ConflictError,
DisconnectMode,
RunManager,
RunRecord,
RunStatus,
StreamBridge,
UnsupportedStrategyError,
run_agent,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# SSE formatting
# ---------------------------------------------------------------------------
def format_sse(event: str, data: Any, *, event_id: str | None = None) -> str:
"""Format a single SSE frame.
Field order: ``event:`` -> ``data:`` -> ``id:`` (optional) -> blank line.
This matches the LangGraph Platform wire format consumed by the
``useStream`` React hook and the Python ``langgraph-sdk`` SSE decoder.
"""
payload = json.dumps(data, default=str, ensure_ascii=False)
parts = [f"event: {event}", f"data: {payload}"]
if event_id:
parts.append(f"id: {event_id}")
parts.append("")
parts.append("")
return "\n".join(parts)
# ---------------------------------------------------------------------------
# Input / config helpers
# ---------------------------------------------------------------------------
def normalize_stream_modes(raw: list[str] | str | None) -> list[str]:
"""Normalize the stream_mode parameter to a list.
Default matches what ``useStream`` expects: values + messages-tuple.
"""
if raw is None:
return ["values"]
if isinstance(raw, str):
return [raw]
return raw if raw else ["values"]
def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
"""Convert LangGraph Platform input format to LangChain state dict."""
if raw_input is None:
return {}
messages = raw_input.get("messages")
if messages and isinstance(messages, list):
converted = []
for msg in messages:
if isinstance(msg, dict):
role = msg.get("role", msg.get("type", "user"))
content = msg.get("content", "")
if role in ("user", "human"):
converted.append(HumanMessage(content=content))
else:
# TODO: handle other message types (system, ai, tool)
converted.append(HumanMessage(content=content))
else:
converted.append(msg)
return {**raw_input, "messages": converted}
return raw_input
_DEFAULT_ASSISTANT_ID = "lead_agent"
def resolve_agent_factory(assistant_id: str | None):
"""Resolve the agent factory callable from config.
Custom agents are implemented as ``lead_agent`` + an ``agent_name``
injected into ``configurable`` — see :func:`build_run_config`. All
``assistant_id`` values therefore map to the same factory; the routing
happens inside ``make_lead_agent`` when it reads ``cfg["agent_name"]``.
"""
from deerflow.agents.lead_agent.agent import make_lead_agent
return make_lead_agent
def build_run_config(
thread_id: str,
request_config: dict[str, Any] | None,
metadata: dict[str, Any] | None,
*,
assistant_id: str | None = None,
) -> dict[str, Any]:
"""Build a RunnableConfig dict for the agent.
When *assistant_id* refers to a custom agent (anything other than
``"lead_agent"`` / ``None``), the name is forwarded as
``configurable["agent_name"]``. ``make_lead_agent`` reads this key to
load the matching ``agents/<name>/SOUL.md`` and per-agent config —
without it the agent silently runs as the default lead agent.
This mirrors the channel manager's ``_resolve_run_params`` logic so that
the LangGraph Platform-compatible HTTP API and the IM channel path behave
identically.
"""
config: dict[str, Any] = {"recursion_limit": 100}
if request_config:
# LangGraph >= 0.6.0 introduced ``context`` as the preferred way to
# pass thread-level data and rejects requests that include both
# ``configurable`` and ``context``. If the caller already sends
# ``context``, honour it and skip our own ``configurable`` dict.
if "context" in request_config:
if "configurable" in request_config:
logger.warning(
"build_run_config: client sent both 'context' and 'configurable'; preferring 'context' (LangGraph >= 0.6.0). thread_id=%s, caller_configurable keys=%s",
thread_id,
list(request_config.get("configurable", {}).keys()),
)
config["context"] = request_config["context"]
else:
configurable = {"thread_id": thread_id}
configurable.update(request_config.get("configurable", {}))
config["configurable"] = configurable
for k, v in request_config.items():
if k not in ("configurable", "context"):
config[k] = v
else:
config["configurable"] = {"thread_id": thread_id}
# Inject custom agent name when the caller specified a non-default assistant.
# Honour an explicit configurable["agent_name"] in the request if already set.
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "configurable" in config:
if "agent_name" not in config["configurable"]:
normalized = assistant_id.strip().lower().replace("_", "-")
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
config["configurable"]["agent_name"] = normalized
if metadata:
config.setdefault("metadata", {}).update(metadata)
return config
# ---------------------------------------------------------------------------
# Run lifecycle
# ---------------------------------------------------------------------------
async def _upsert_thread_in_store(store, thread_id: str, metadata: dict | None) -> None:
"""Create or refresh the thread record in the Store.
Called from :func:`start_run` so that threads created via the stateless
``/runs/stream`` endpoint (which never calls ``POST /threads``) still
appear in ``/threads/search`` results.
"""
# Deferred import to avoid circular import with the threads router module.
from app.gateway.routers.threads import _store_upsert
try:
await _store_upsert(store, thread_id, metadata=metadata)
except Exception:
logger.warning("Failed to upsert thread %s in store (non-fatal)", thread_id)
async def _sync_thread_title_after_run(
run_task: asyncio.Task,
thread_id: str,
checkpointer: Any,
store: Any,
) -> None:
"""Wait for *run_task* to finish, then persist the generated title to the Store.
TitleMiddleware writes the generated title to the LangGraph agent state
(checkpointer) but the Gateway's Store record is not updated automatically.
This coroutine closes that gap by reading the final checkpoint after the
run completes and syncing ``values.title`` into the Store record so that
subsequent ``/threads/search`` responses include the correct title.
Runs as a fire-and-forget :func:`asyncio.create_task`; failures are
logged at DEBUG level and never propagate.
"""
# Wait for the background run task to complete (any outcome).
# asyncio.wait does not propagate task exceptions — it just returns
# when the task is done, cancelled, or failed.
await asyncio.wait({run_task})
# Deferred import to avoid circular import with the threads router module.
from app.gateway.routers.threads import _store_get, _store_put
try:
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
if ckpt_tuple is None:
return
channel_values = ckpt_tuple.checkpoint.get("channel_values", {})
title = channel_values.get("title")
if not title:
return
existing = await _store_get(store, thread_id)
if existing is None:
return
updated = dict(existing)
updated.setdefault("values", {})["title"] = title
updated["updated_at"] = time.time()
await _store_put(store, updated)
logger.debug("Synced title %r for thread %s", title, thread_id)
except Exception:
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id, exc_info=True)
async def start_run(
body: Any,
thread_id: str,
request: Request,
) -> RunRecord:
"""Create a RunRecord and launch the background agent task.
Parameters
----------
body : RunCreateRequest
The validated request body (typed as Any to avoid circular import
with the router module that defines the Pydantic model).
thread_id : str
Target thread.
request : Request
FastAPI request — used to retrieve singletons from ``app.state``.
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
checkpointer = get_checkpointer(request)
store = get_store(request)
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
try:
record = await run_mgr.create_or_reject(
thread_id,
body.assistant_id,
on_disconnect=disconnect,
metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
)
except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
except UnsupportedStrategyError as exc:
raise HTTPException(status_code=501, detail=str(exc)) from exc
# Ensure the thread is visible in /threads/search, even for threads that
# were never explicitly created via POST /threads (e.g. stateless runs).
store = get_store(request)
if store is not None:
await _upsert_thread_in_store(store, thread_id, body.metadata)
agent_factory = resolve_agent_factory(body.assistant_id)
graph_input = normalize_input(body.input)
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
# Merge DeerFlow-specific context overrides into configurable.
# The ``context`` field is a custom extension for the langgraph-compat layer
# that carries agent configuration (model_name, thinking_enabled, etc.).
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
context = getattr(body, "context", None)
if context:
_CONTEXT_CONFIGURABLE_KEYS = {
"model_name",
"mode",
"thinking_enabled",
"reasoning_effort",
"is_plan_mode",
"subagent_enabled",
"max_concurrent_subagents",
}
configurable = config.setdefault("configurable", {})
for key in _CONTEXT_CONFIGURABLE_KEYS:
if key in context:
configurable.setdefault(key, context[key])
stream_modes = normalize_stream_modes(body.stream_mode)
task = asyncio.create_task(
run_agent(
bridge,
run_mgr,
record,
checkpointer=checkpointer,
store=store,
agent_factory=agent_factory,
graph_input=graph_input,
config=config,
stream_modes=stream_modes,
stream_subgraphs=body.stream_subgraphs,
interrupt_before=body.interrupt_before,
interrupt_after=body.interrupt_after,
)
)
record.task = task
# After the run completes, sync the title generated by TitleMiddleware from
# the checkpointer into the Store record so that /threads/search returns the
# correct title instead of an empty values dict.
if store is not None:
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store))
return record
async def sse_consumer(
bridge: StreamBridge,
record: RunRecord,
request: Request,
run_mgr: RunManager,
):
"""Async generator that yields SSE frames from the bridge.
The ``finally`` block implements ``on_disconnect`` semantics:
- ``cancel``: abort the background task on client disconnect.
- ``continue``: let the task run; events are discarded.
"""
last_event_id = request.headers.get("Last-Event-ID")
try:
async for entry in bridge.subscribe(record.run_id, last_event_id=last_event_id):
if await request.is_disconnected():
break
if entry is HEARTBEAT_SENTINEL:
yield ": heartbeat\n\n"
continue
if entry is END_SENTINEL:
yield format_sse("end", None, event_id=entry.id or None)
return
yield format_sse(entry.event, entry.data, event_id=entry.id or None)
finally:
if record.status in (RunStatus.pending, RunStatus.running):
if record.on_disconnect == DisconnectMode.cancel:
await run_mgr.cancel(record.run_id)