Initial commit: hardened DeerFlow factory
Vendored deer-flow upstream (bytedance/deer-flow) plus prompt-injection hardening: - New deerflow.security package: content_delimiter, html_cleaner, sanitizer (8 layers — invisible chars, control chars, symbols, NFC, PUA, tag chars, horizontal whitespace collapse with newline/tab preservation, length cap) - New deerflow.community.searx package: web_search, web_fetch, image_search backed by a private SearX instance, every external string sanitized and wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>> delimiters - All native community web providers (ddg_search, tavily, exa, firecrawl, jina_ai, infoquest, image_search) replaced with hard-fail stubs that raise NativeWebToolDisabledError at import time, so a misconfigured tool.use path fails loud rather than silently falling back to unsanitized output - Native client back-doors (jina_client.py, infoquest_client.py) stubbed too - Native-tool tests quarantined under tests/_disabled_native/ (collect_ignore_glob via local conftest.py) - Sanitizer Layer 7 fix: only collapse horizontal whitespace, preserve newlines and tabs so list/table structure survives - Hardened runtime config.yaml references only the searx-backed tools - Factory overlay (backend/) kept in sync with deer-flow tree as a reference / source See HARDENING.md for the full audit trail and verification steps.
This commit is contained in:
4
deer-flow/backend/app/gateway/__init__.py
Normal file
4
deer-flow/backend/app/gateway/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .app import app, create_app
|
||||
from .config import GatewayConfig, get_gateway_config
|
||||
|
||||
__all__ = ["app", "create_app", "GatewayConfig", "get_gateway_config"]
|
||||
221
deer-flow/backend/app/gateway/app.py
Normal file
221
deer-flow/backend/app/gateway/app.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.gateway.config import get_gateway_config
|
||||
from app.gateway.deps import langgraph_runtime
|
||||
from app.gateway.routers import (
|
||||
agents,
|
||||
artifacts,
|
||||
assistants_compat,
|
||||
channels,
|
||||
mcp,
|
||||
memory,
|
||||
models,
|
||||
runs,
|
||||
skills,
|
||||
suggestions,
|
||||
thread_runs,
|
||||
threads,
|
||||
uploads,
|
||||
)
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Application lifespan handler."""
|
||||
|
||||
# Load config and check necessary environment variables at startup
|
||||
try:
|
||||
get_app_config()
|
||||
logger.info("Configuration loaded successfully")
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
||||
logger.exception(error_msg)
|
||||
raise RuntimeError(error_msg) from e
|
||||
config = get_gateway_config()
|
||||
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
|
||||
|
||||
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
||||
async with langgraph_runtime(app):
|
||||
logger.info("LangGraph runtime initialised")
|
||||
|
||||
# Start IM channel service if any channels are configured
|
||||
try:
|
||||
from app.channels.service import start_channel_service
|
||||
|
||||
channel_service = await start_channel_service()
|
||||
logger.info("Channel service started: %s", channel_service.get_status())
|
||||
except Exception:
|
||||
logger.exception("No IM channels configured or channel service failed to start")
|
||||
|
||||
yield
|
||||
|
||||
# Stop channel service on shutdown
|
||||
try:
|
||||
from app.channels.service import stop_channel_service
|
||||
|
||||
await stop_channel_service()
|
||||
except Exception:
|
||||
logger.exception("Failed to stop channel service")
|
||||
|
||||
logger.info("Shutting down API Gateway")
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create and configure the FastAPI application.
|
||||
|
||||
Returns:
|
||||
Configured FastAPI application instance.
|
||||
"""
|
||||
|
||||
app = FastAPI(
|
||||
title="DeerFlow API Gateway",
|
||||
description="""
|
||||
## DeerFlow API Gateway
|
||||
|
||||
API Gateway for DeerFlow - A LangGraph-based AI agent backend with sandbox execution capabilities.
|
||||
|
||||
### Features
|
||||
|
||||
- **Models Management**: Query and retrieve available AI models
|
||||
- **MCP Configuration**: Manage Model Context Protocol (MCP) server configurations
|
||||
- **Memory Management**: Access and manage global memory data for personalized conversations
|
||||
- **Skills Management**: Query and manage skills and their enabled status
|
||||
- **Artifacts**: Access thread artifacts and generated files
|
||||
- **Health Monitoring**: System health check endpoints
|
||||
|
||||
### Architecture
|
||||
|
||||
LangGraph requests are handled by nginx reverse proxy.
|
||||
This gateway provides custom endpoints for models, MCP configuration, skills, and artifacts.
|
||||
""",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
openapi_tags=[
|
||||
{
|
||||
"name": "models",
|
||||
"description": "Operations for querying available AI models and their configurations",
|
||||
},
|
||||
{
|
||||
"name": "mcp",
|
||||
"description": "Manage Model Context Protocol (MCP) server configurations",
|
||||
},
|
||||
{
|
||||
"name": "memory",
|
||||
"description": "Access and manage global memory data for personalized conversations",
|
||||
},
|
||||
{
|
||||
"name": "skills",
|
||||
"description": "Manage skills and their configurations",
|
||||
},
|
||||
{
|
||||
"name": "artifacts",
|
||||
"description": "Access and download thread artifacts and generated files",
|
||||
},
|
||||
{
|
||||
"name": "uploads",
|
||||
"description": "Upload and manage user files for threads",
|
||||
},
|
||||
{
|
||||
"name": "threads",
|
||||
"description": "Manage DeerFlow thread-local filesystem data",
|
||||
},
|
||||
{
|
||||
"name": "agents",
|
||||
"description": "Create and manage custom agents with per-agent config and prompts",
|
||||
},
|
||||
{
|
||||
"name": "suggestions",
|
||||
"description": "Generate follow-up question suggestions for conversations",
|
||||
},
|
||||
{
|
||||
"name": "channels",
|
||||
"description": "Manage IM channel integrations (Feishu, Slack, Telegram)",
|
||||
},
|
||||
{
|
||||
"name": "assistants-compat",
|
||||
"description": "LangGraph Platform-compatible assistants API (stub)",
|
||||
},
|
||||
{
|
||||
"name": "runs",
|
||||
"description": "LangGraph Platform-compatible runs lifecycle (create, stream, cancel)",
|
||||
},
|
||||
{
|
||||
"name": "health",
|
||||
"description": "Health check and system status endpoints",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# CORS is handled by nginx - no need for FastAPI middleware
|
||||
|
||||
# Include routers
|
||||
# Models API is mounted at /api/models
|
||||
app.include_router(models.router)
|
||||
|
||||
# MCP API is mounted at /api/mcp
|
||||
app.include_router(mcp.router)
|
||||
|
||||
# Memory API is mounted at /api/memory
|
||||
app.include_router(memory.router)
|
||||
|
||||
# Skills API is mounted at /api/skills
|
||||
app.include_router(skills.router)
|
||||
|
||||
# Artifacts API is mounted at /api/threads/{thread_id}/artifacts
|
||||
app.include_router(artifacts.router)
|
||||
|
||||
# Uploads API is mounted at /api/threads/{thread_id}/uploads
|
||||
app.include_router(uploads.router)
|
||||
|
||||
# Thread cleanup API is mounted at /api/threads/{thread_id}
|
||||
app.include_router(threads.router)
|
||||
|
||||
# Agents API is mounted at /api/agents
|
||||
app.include_router(agents.router)
|
||||
|
||||
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
|
||||
app.include_router(suggestions.router)
|
||||
|
||||
# Channels API is mounted at /api/channels
|
||||
app.include_router(channels.router)
|
||||
|
||||
# Assistants compatibility API (LangGraph Platform stub)
|
||||
app.include_router(assistants_compat.router)
|
||||
|
||||
# Thread Runs API (LangGraph Platform-compatible runs lifecycle)
|
||||
app.include_router(thread_runs.router)
|
||||
|
||||
# Stateless Runs API (stream/wait without a pre-existing thread)
|
||||
app.include_router(runs.router)
|
||||
|
||||
@app.get("/health", tags=["health"])
|
||||
async def health_check() -> dict:
|
||||
"""Health check endpoint.
|
||||
|
||||
Returns:
|
||||
Service health status information.
|
||||
"""
|
||||
return {"status": "healthy", "service": "deer-flow-gateway"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# Create app instance for uvicorn
|
||||
app = create_app()
|
||||
27
deer-flow/backend/app/gateway/config.py
Normal file
27
deer-flow/backend/app/gateway/config.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import os
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GatewayConfig(BaseModel):
|
||||
"""Configuration for the API Gateway."""
|
||||
|
||||
host: str = Field(default="0.0.0.0", description="Host to bind the gateway server")
|
||||
port: int = Field(default=8001, description="Port to bind the gateway server")
|
||||
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"], description="Allowed CORS origins")
|
||||
|
||||
|
||||
_gateway_config: GatewayConfig | None = None
|
||||
|
||||
|
||||
def get_gateway_config() -> GatewayConfig:
|
||||
"""Get gateway config, loading from environment if available."""
|
||||
global _gateway_config
|
||||
if _gateway_config is None:
|
||||
cors_origins_str = os.getenv("CORS_ORIGINS", "http://localhost:3000")
|
||||
_gateway_config = GatewayConfig(
|
||||
host=os.getenv("GATEWAY_HOST", "0.0.0.0"),
|
||||
port=int(os.getenv("GATEWAY_PORT", "8001")),
|
||||
cors_origins=cors_origins_str.split(","),
|
||||
)
|
||||
return _gateway_config
|
||||
70
deer-flow/backend/app/gateway/deps.py
Normal file
70
deer-flow/backend/app/gateway/deps.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Centralized accessors for singleton objects stored on ``app.state``.
|
||||
|
||||
**Getters** (used by routers): raise 503 when a required dependency is
|
||||
missing, except ``get_store`` which returns ``None``.
|
||||
|
||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
|
||||
from deerflow.runtime import RunManager, StreamBridge
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Bootstrap and tear down all LangGraph runtime singletons.
|
||||
|
||||
Usage in ``app.py``::
|
||||
|
||||
async with langgraph_runtime(app):
|
||||
yield
|
||||
"""
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
from deerflow.runtime import make_store, make_stream_bridge
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
|
||||
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
|
||||
app.state.store = await stack.enter_async_context(make_store())
|
||||
app.state.run_manager = RunManager()
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Getters – called by routers per-request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_stream_bridge(request: Request) -> StreamBridge:
|
||||
"""Return the global :class:`StreamBridge`, or 503."""
|
||||
bridge = getattr(request.app.state, "stream_bridge", None)
|
||||
if bridge is None:
|
||||
raise HTTPException(status_code=503, detail="Stream bridge not available")
|
||||
return bridge
|
||||
|
||||
|
||||
def get_run_manager(request: Request) -> RunManager:
|
||||
"""Return the global :class:`RunManager`, or 503."""
|
||||
mgr = getattr(request.app.state, "run_manager", None)
|
||||
if mgr is None:
|
||||
raise HTTPException(status_code=503, detail="Run manager not available")
|
||||
return mgr
|
||||
|
||||
|
||||
def get_checkpointer(request: Request):
|
||||
"""Return the global checkpointer, or 503."""
|
||||
cp = getattr(request.app.state, "checkpointer", None)
|
||||
if cp is None:
|
||||
raise HTTPException(status_code=503, detail="Checkpointer not available")
|
||||
return cp
|
||||
|
||||
|
||||
def get_store(request: Request):
|
||||
"""Return the global store (may be ``None`` if not configured)."""
|
||||
return getattr(request.app.state, "store", None)
|
||||
28
deer-flow/backend/app/gateway/path_utils.py
Normal file
28
deer-flow/backend/app/gateway/path_utils.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Shared path resolution for thread virtual paths (e.g. mnt/user-data/outputs/...)."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from deerflow.config.paths import get_paths
|
||||
|
||||
|
||||
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
|
||||
"""Resolve a virtual path to the actual filesystem path under thread user-data.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
virtual_path: The virtual path as seen inside the sandbox
|
||||
(e.g., /mnt/user-data/outputs/file.txt).
|
||||
|
||||
Returns:
|
||||
The resolved filesystem path.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the path is invalid or outside allowed directories.
|
||||
"""
|
||||
try:
|
||||
return get_paths().resolve_virtual_path(thread_id, virtual_path)
|
||||
except ValueError as e:
|
||||
status = 403 if "traversal" in str(e) else 400
|
||||
raise HTTPException(status_code=status, detail=str(e))
|
||||
3
deer-flow/backend/app/gateway/routers/__init__.py
Normal file
3
deer-flow/backend/app/gateway/routers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads
|
||||
|
||||
__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
|
||||
383
deer-flow/backend/app/gateway/routers/agents.py
Normal file
383
deer-flow/backend/app/gateway/routers/agents.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""CRUD API for custom agents."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
|
||||
from deerflow.config.paths import get_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api", tags=["agents"])
|
||||
|
||||
AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
|
||||
"""Response model for a custom agent."""
|
||||
|
||||
name: str = Field(..., description="Agent name (hyphen-case)")
|
||||
description: str = Field(default="", description="Agent description")
|
||||
model: str | None = Field(default=None, description="Optional model override")
|
||||
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
|
||||
soul: str | None = Field(default=None, description="SOUL.md content")
|
||||
|
||||
|
||||
class AgentsListResponse(BaseModel):
|
||||
"""Response model for listing all custom agents."""
|
||||
|
||||
agents: list[AgentResponse]
|
||||
|
||||
|
||||
class AgentCreateRequest(BaseModel):
|
||||
"""Request body for creating a custom agent."""
|
||||
|
||||
name: str = Field(..., description="Agent name (must match ^[A-Za-z0-9-]+$, stored as lowercase)")
|
||||
description: str = Field(default="", description="Agent description")
|
||||
model: str | None = Field(default=None, description="Optional model override")
|
||||
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
|
||||
soul: str = Field(default="", description="SOUL.md content — agent personality and behavioral guardrails")
|
||||
|
||||
|
||||
class AgentUpdateRequest(BaseModel):
|
||||
"""Request body for updating a custom agent."""
|
||||
|
||||
description: str | None = Field(default=None, description="Updated description")
|
||||
model: str | None = Field(default=None, description="Updated model override")
|
||||
tool_groups: list[str] | None = Field(default=None, description="Updated tool group whitelist")
|
||||
soul: str | None = Field(default=None, description="Updated SOUL.md content")
|
||||
|
||||
|
||||
def _validate_agent_name(name: str) -> None:
|
||||
"""Validate agent name against allowed pattern.
|
||||
|
||||
Args:
|
||||
name: The agent name to validate.
|
||||
|
||||
Raises:
|
||||
HTTPException: 422 if the name is invalid.
|
||||
"""
|
||||
if not AGENT_NAME_PATTERN.match(name):
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Invalid agent name '{name}'. Must match ^[A-Za-z0-9-]+$ (letters, digits, and hyphens only).",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_agent_name(name: str) -> str:
|
||||
"""Normalize agent name to lowercase for filesystem storage."""
|
||||
return name.lower()
|
||||
|
||||
|
||||
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
|
||||
"""Convert AgentConfig to AgentResponse."""
|
||||
soul: str | None = None
|
||||
if include_soul:
|
||||
soul = load_agent_soul(agent_cfg.name) or ""
|
||||
|
||||
return AgentResponse(
|
||||
name=agent_cfg.name,
|
||||
description=agent_cfg.description,
|
||||
model=agent_cfg.model,
|
||||
tool_groups=agent_cfg.tool_groups,
|
||||
soul=soul,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
response_model=AgentsListResponse,
|
||||
summary="List Custom Agents",
|
||||
description="List all custom agents available in the agents directory, including their soul content.",
|
||||
)
|
||||
async def list_agents() -> AgentsListResponse:
|
||||
"""List all custom agents.
|
||||
|
||||
Returns:
|
||||
List of all custom agents with their metadata and soul content.
|
||||
"""
|
||||
try:
|
||||
agents = list_custom_agents()
|
||||
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list agents: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents/check",
|
||||
summary="Check Agent Name",
|
||||
description="Validate an agent name and check if it is available (case-insensitive).",
|
||||
)
|
||||
async def check_agent_name(name: str) -> dict:
|
||||
"""Check whether an agent name is valid and not yet taken.
|
||||
|
||||
Args:
|
||||
name: The agent name to check.
|
||||
|
||||
Returns:
|
||||
``{"available": true/false, "name": "<normalized>"}``
|
||||
|
||||
Raises:
|
||||
HTTPException: 422 if the name is invalid.
|
||||
"""
|
||||
_validate_agent_name(name)
|
||||
normalized = _normalize_agent_name(name)
|
||||
available = not get_paths().agent_dir(normalized).exists()
|
||||
return {"available": available, "name": normalized}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents/{name}",
|
||||
response_model=AgentResponse,
|
||||
summary="Get Custom Agent",
|
||||
description="Retrieve details and SOUL.md content for a specific custom agent.",
|
||||
)
|
||||
async def get_agent(name: str) -> AgentResponse:
|
||||
"""Get a specific custom agent by name.
|
||||
|
||||
Args:
|
||||
name: The agent name.
|
||||
|
||||
Returns:
|
||||
Agent details including SOUL.md content.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if agent not found.
|
||||
"""
|
||||
_validate_agent_name(name)
|
||||
name = _normalize_agent_name(name)
|
||||
|
||||
try:
|
||||
agent_cfg = load_agent_config(name)
|
||||
return _agent_config_to_response(agent_cfg, include_soul=True)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get agent '{name}': {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get agent: {str(e)}")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/agents",
|
||||
response_model=AgentResponse,
|
||||
status_code=201,
|
||||
summary="Create Custom Agent",
|
||||
description="Create a new custom agent with its config and SOUL.md.",
|
||||
)
|
||||
async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
|
||||
"""Create a new custom agent.
|
||||
|
||||
Args:
|
||||
request: The agent creation request.
|
||||
|
||||
Returns:
|
||||
The created agent details.
|
||||
|
||||
Raises:
|
||||
HTTPException: 409 if agent already exists, 422 if name is invalid.
|
||||
"""
|
||||
_validate_agent_name(request.name)
|
||||
normalized_name = _normalize_agent_name(request.name)
|
||||
|
||||
agent_dir = get_paths().agent_dir(normalized_name)
|
||||
|
||||
if agent_dir.exists():
|
||||
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
|
||||
|
||||
try:
|
||||
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write config.yaml
|
||||
config_data: dict = {"name": normalized_name}
|
||||
if request.description:
|
||||
config_data["description"] = request.description
|
||||
if request.model is not None:
|
||||
config_data["model"] = request.model
|
||||
if request.tool_groups is not None:
|
||||
config_data["tool_groups"] = request.tool_groups
|
||||
|
||||
config_file = agent_dir / "config.yaml"
|
||||
with open(config_file, "w", encoding="utf-8") as f:
|
||||
yaml.dump(config_data, f, default_flow_style=False, allow_unicode=True)
|
||||
|
||||
# Write SOUL.md
|
||||
soul_file = agent_dir / "SOUL.md"
|
||||
soul_file.write_text(request.soul, encoding="utf-8")
|
||||
|
||||
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
|
||||
|
||||
agent_cfg = load_agent_config(normalized_name)
|
||||
return _agent_config_to_response(agent_cfg, include_soul=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
# Clean up on failure
|
||||
if agent_dir.exists():
|
||||
shutil.rmtree(agent_dir)
|
||||
logger.error(f"Failed to create agent '{request.name}': {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create agent: {str(e)}")
|
||||
|
||||
|
||||
@router.put(
|
||||
"/agents/{name}",
|
||||
response_model=AgentResponse,
|
||||
summary="Update Custom Agent",
|
||||
description="Update an existing custom agent's config and/or SOUL.md.",
|
||||
)
|
||||
async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
|
||||
"""Update an existing custom agent.
|
||||
|
||||
Args:
|
||||
name: The agent name.
|
||||
request: The update request (all fields optional).
|
||||
|
||||
Returns:
|
||||
The updated agent details.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if agent not found.
|
||||
"""
|
||||
_validate_agent_name(name)
|
||||
name = _normalize_agent_name(name)
|
||||
|
||||
try:
|
||||
agent_cfg = load_agent_config(name)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||
|
||||
agent_dir = get_paths().agent_dir(name)
|
||||
|
||||
try:
|
||||
# Update config if any config fields changed
|
||||
config_changed = any(v is not None for v in [request.description, request.model, request.tool_groups])
|
||||
|
||||
if config_changed:
|
||||
updated: dict = {
|
||||
"name": agent_cfg.name,
|
||||
"description": request.description if request.description is not None else agent_cfg.description,
|
||||
}
|
||||
new_model = request.model if request.model is not None else agent_cfg.model
|
||||
if new_model is not None:
|
||||
updated["model"] = new_model
|
||||
|
||||
new_tool_groups = request.tool_groups if request.tool_groups is not None else agent_cfg.tool_groups
|
||||
if new_tool_groups is not None:
|
||||
updated["tool_groups"] = new_tool_groups
|
||||
|
||||
config_file = agent_dir / "config.yaml"
|
||||
with open(config_file, "w", encoding="utf-8") as f:
|
||||
yaml.dump(updated, f, default_flow_style=False, allow_unicode=True)
|
||||
|
||||
# Update SOUL.md if provided
|
||||
if request.soul is not None:
|
||||
soul_path = agent_dir / "SOUL.md"
|
||||
soul_path.write_text(request.soul, encoding="utf-8")
|
||||
|
||||
logger.info(f"Updated agent '{name}'")
|
||||
|
||||
refreshed_cfg = load_agent_config(name)
|
||||
return _agent_config_to_response(refreshed_cfg, include_soul=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update agent '{name}': {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update agent: {str(e)}")
|
||||
|
||||
|
||||
class UserProfileResponse(BaseModel):
|
||||
"""Response model for the global user profile (USER.md)."""
|
||||
|
||||
content: str | None = Field(default=None, description="USER.md content, or null if not yet created")
|
||||
|
||||
|
||||
class UserProfileUpdateRequest(BaseModel):
|
||||
"""Request body for setting the global user profile."""
|
||||
|
||||
content: str = Field(default="", description="USER.md content — describes the user's background and preferences")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/user-profile",
|
||||
response_model=UserProfileResponse,
|
||||
summary="Get User Profile",
|
||||
description="Read the global USER.md file that is injected into all custom agents.",
|
||||
)
|
||||
async def get_user_profile() -> UserProfileResponse:
|
||||
"""Return the current USER.md content.
|
||||
|
||||
Returns:
|
||||
UserProfileResponse with content=None if USER.md does not exist yet.
|
||||
"""
|
||||
try:
|
||||
user_md_path = get_paths().user_md_file
|
||||
if not user_md_path.exists():
|
||||
return UserProfileResponse(content=None)
|
||||
raw = user_md_path.read_text(encoding="utf-8").strip()
|
||||
return UserProfileResponse(content=raw or None)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read user profile: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to read user profile: {str(e)}")
|
||||
|
||||
|
||||
@router.put(
|
||||
"/user-profile",
|
||||
response_model=UserProfileResponse,
|
||||
summary="Update User Profile",
|
||||
description="Write the global USER.md file that is injected into all custom agents.",
|
||||
)
|
||||
async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileResponse:
|
||||
"""Create or overwrite the global USER.md.
|
||||
|
||||
Args:
|
||||
request: The update request with the new USER.md content.
|
||||
|
||||
Returns:
|
||||
UserProfileResponse with the saved content.
|
||||
"""
|
||||
try:
|
||||
paths = get_paths()
|
||||
paths.base_dir.mkdir(parents=True, exist_ok=True)
|
||||
paths.user_md_file.write_text(request.content, encoding="utf-8")
|
||||
logger.info(f"Updated USER.md at {paths.user_md_file}")
|
||||
return UserProfileResponse(content=request.content or None)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update user profile: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update user profile: {str(e)}")
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/agents/{name}",
|
||||
status_code=204,
|
||||
summary="Delete Custom Agent",
|
||||
description="Delete a custom agent and all its files (config, SOUL.md, memory).",
|
||||
)
|
||||
async def delete_agent(name: str) -> None:
|
||||
"""Delete a custom agent.
|
||||
|
||||
Args:
|
||||
name: The agent name.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if agent not found.
|
||||
"""
|
||||
_validate_agent_name(name)
|
||||
name = _normalize_agent_name(name)
|
||||
|
||||
agent_dir = get_paths().agent_dir(name)
|
||||
|
||||
if not agent_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||
|
||||
try:
|
||||
shutil.rmtree(agent_dir)
|
||||
logger.info(f"Deleted agent '{name}' from {agent_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete agent '{name}': {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete agent: {str(e)}")
|
||||
181
deer-flow/backend/app/gateway/routers/artifacts.py
Normal file
181
deer-flow/backend/app/gateway/routers/artifacts.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import FileResponse, PlainTextResponse, Response
|
||||
|
||||
from app.gateway.path_utils import resolve_thread_virtual_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["artifacts"])
|
||||
|
||||
ACTIVE_CONTENT_MIME_TYPES = {
|
||||
"text/html",
|
||||
"application/xhtml+xml",
|
||||
"image/svg+xml",
|
||||
}
|
||||
|
||||
|
||||
def _build_content_disposition(disposition_type: str, filename: str) -> str:
|
||||
"""Build an RFC 5987 encoded Content-Disposition header value."""
|
||||
return f"{disposition_type}; filename*=UTF-8''{quote(filename)}"
|
||||
|
||||
|
||||
def _build_attachment_headers(filename: str, extra_headers: dict[str, str] | None = None) -> dict[str, str]:
|
||||
headers = {"Content-Disposition": _build_content_disposition("attachment", filename)}
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
return headers
|
||||
|
||||
|
||||
def is_text_file_by_content(path: Path, sample_size: int = 8192) -> bool:
|
||||
"""Check if file is text by examining content for null bytes."""
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
chunk = f.read(sample_size)
|
||||
# Text files shouldn't contain null bytes
|
||||
return b"\x00" not in chunk
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None:
|
||||
"""Extract a file from a .skill ZIP archive.
|
||||
|
||||
Args:
|
||||
zip_path: Path to the .skill file (ZIP archive).
|
||||
internal_path: Path to the file inside the archive (e.g., "SKILL.md").
|
||||
|
||||
Returns:
|
||||
The file content as bytes, or None if not found.
|
||||
"""
|
||||
if not zipfile.is_zipfile(zip_path):
|
||||
return None
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
# List all files in the archive
|
||||
namelist = zip_ref.namelist()
|
||||
|
||||
# Try direct path first
|
||||
if internal_path in namelist:
|
||||
return zip_ref.read(internal_path)
|
||||
|
||||
# Try with any top-level directory prefix (e.g., "skill-name/SKILL.md")
|
||||
for name in namelist:
|
||||
if name.endswith("/" + internal_path) or name == internal_path:
|
||||
return zip_ref.read(name)
|
||||
|
||||
# Not found
|
||||
return None
|
||||
except (zipfile.BadZipFile, KeyError):
|
||||
return None
|
||||
|
||||
|
||||
@router.get(
|
||||
"/threads/{thread_id}/artifacts/{path:path}",
|
||||
summary="Get Artifact File",
|
||||
description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
|
||||
)
|
||||
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
|
||||
"""Get an artifact file by its path.
|
||||
|
||||
The endpoint automatically detects file types and returns appropriate content types.
|
||||
Use the `download` query parameter to force file download for non-active content.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
path: The artifact path with virtual prefix (e.g., mnt/user-data/outputs/file.txt).
|
||||
request: FastAPI request object (automatically injected).
|
||||
|
||||
Returns:
|
||||
The file content as a FileResponse with appropriate content type:
|
||||
- Active content (HTML/XHTML/SVG): Served as download attachment
|
||||
- Text files: Plain text with proper MIME type
|
||||
- Binary files: Inline display with download option
|
||||
|
||||
Raises:
|
||||
HTTPException:
|
||||
- 400 if path is invalid or not a file
|
||||
- 403 if access denied (path traversal detected)
|
||||
- 404 if file not found
|
||||
|
||||
Query Parameters:
|
||||
download (bool): If true, forces attachment download for file types that are
|
||||
otherwise returned inline or as plain text. Active HTML/XHTML/SVG content
|
||||
is always downloaded regardless of this flag.
|
||||
|
||||
Example:
|
||||
- Get text file inline: `/api/threads/abc123/artifacts/mnt/user-data/outputs/notes.txt`
|
||||
- Download file: `/api/threads/abc123/artifacts/mnt/user-data/outputs/data.csv?download=true`
|
||||
- Active web content such as `.html`, `.xhtml`, and `.svg` artifacts is always downloaded
|
||||
"""
|
||||
# Check if this is a request for a file inside a .skill archive (e.g., xxx.skill/SKILL.md)
|
||||
if ".skill/" in path:
|
||||
# Split the path at ".skill/" to get the ZIP file path and internal path
|
||||
skill_marker = ".skill/"
|
||||
marker_pos = path.find(skill_marker)
|
||||
skill_file_path = path[: marker_pos + len(".skill")] # e.g., "mnt/user-data/outputs/my-skill.skill"
|
||||
internal_path = path[marker_pos + len(skill_marker) :] # e.g., "SKILL.md"
|
||||
|
||||
actual_skill_path = resolve_thread_virtual_path(thread_id, skill_file_path)
|
||||
|
||||
if not actual_skill_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Skill file not found: {skill_file_path}")
|
||||
|
||||
if not actual_skill_path.is_file():
|
||||
raise HTTPException(status_code=400, detail=f"Path is not a file: {skill_file_path}")
|
||||
|
||||
# Extract the file from the .skill archive
|
||||
content = _extract_file_from_skill_archive(actual_skill_path, internal_path)
|
||||
if content is None:
|
||||
raise HTTPException(status_code=404, detail=f"File '{internal_path}' not found in skill archive")
|
||||
|
||||
# Determine MIME type based on the internal file
|
||||
mime_type, _ = mimetypes.guess_type(internal_path)
|
||||
# Add cache headers to avoid repeated ZIP extraction (cache for 5 minutes)
|
||||
cache_headers = {"Cache-Control": "private, max-age=300"}
|
||||
download_name = Path(internal_path).name or actual_skill_path.stem
|
||||
if download or mime_type in ACTIVE_CONTENT_MIME_TYPES:
|
||||
return Response(content=content, media_type=mime_type or "application/octet-stream", headers=_build_attachment_headers(download_name, cache_headers))
|
||||
|
||||
if mime_type and mime_type.startswith("text/"):
|
||||
return PlainTextResponse(content=content.decode("utf-8"), media_type=mime_type, headers=cache_headers)
|
||||
|
||||
# Default to plain text for unknown types that look like text
|
||||
try:
|
||||
return PlainTextResponse(content=content.decode("utf-8"), media_type="text/plain", headers=cache_headers)
|
||||
except UnicodeDecodeError:
|
||||
return Response(content=content, media_type=mime_type or "application/octet-stream", headers=cache_headers)
|
||||
|
||||
actual_path = resolve_thread_virtual_path(thread_id, path)
|
||||
|
||||
logger.info(f"Resolving artifact path: thread_id={thread_id}, requested_path={path}, actual_path={actual_path}")
|
||||
|
||||
if not actual_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Artifact not found: {path}")
|
||||
|
||||
if not actual_path.is_file():
|
||||
raise HTTPException(status_code=400, detail=f"Path is not a file: {path}")
|
||||
|
||||
mime_type, _ = mimetypes.guess_type(actual_path)
|
||||
|
||||
if download:
|
||||
return FileResponse(path=actual_path, filename=actual_path.name, media_type=mime_type, headers=_build_attachment_headers(actual_path.name))
|
||||
|
||||
# Always force download for active content types to prevent script execution
|
||||
# in the application origin when users open generated artifacts.
|
||||
if mime_type in ACTIVE_CONTENT_MIME_TYPES:
|
||||
return FileResponse(path=actual_path, filename=actual_path.name, media_type=mime_type, headers=_build_attachment_headers(actual_path.name))
|
||||
|
||||
if mime_type and mime_type.startswith("text/"):
|
||||
return PlainTextResponse(content=actual_path.read_text(encoding="utf-8"), media_type=mime_type)
|
||||
|
||||
if is_text_file_by_content(actual_path):
|
||||
return PlainTextResponse(content=actual_path.read_text(encoding="utf-8"), media_type=mime_type)
|
||||
|
||||
return Response(content=actual_path.read_bytes(), media_type=mime_type, headers={"Content-Disposition": _build_content_disposition("inline", actual_path.name)})
|
||||
149
deer-flow/backend/app/gateway/routers/assistants_compat.py
Normal file
149
deer-flow/backend/app/gateway/routers/assistants_compat.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Assistants compatibility endpoints.
|
||||
|
||||
Provides LangGraph Platform-compatible assistants API backed by the
|
||||
``langgraph.json`` graph registry and ``config.yaml`` agent definitions.
|
||||
|
||||
This is a minimal stub that satisfies the ``useStream`` React hook's
|
||||
initialization requirements (``assistants.search()`` and ``assistants.get()``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/assistants", tags=["assistants-compat"])
|
||||
|
||||
|
||||
class AssistantResponse(BaseModel):
|
||||
assistant_id: str
|
||||
graph_id: str
|
||||
name: str
|
||||
config: dict[str, Any] = Field(default_factory=dict)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
description: str | None = None
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
version: int = 1
|
||||
|
||||
|
||||
class AssistantSearchRequest(BaseModel):
|
||||
graph_id: str | None = None
|
||||
name: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
limit: int = 10
|
||||
offset: int = 0
|
||||
|
||||
|
||||
def _get_default_assistant() -> AssistantResponse:
|
||||
"""Return the default lead_agent assistant."""
|
||||
now = datetime.now(UTC).isoformat()
|
||||
return AssistantResponse(
|
||||
assistant_id="lead_agent",
|
||||
graph_id="lead_agent",
|
||||
name="lead_agent",
|
||||
config={},
|
||||
metadata={"created_by": "system"},
|
||||
description="DeerFlow lead agent",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
version=1,
|
||||
)
|
||||
|
||||
|
||||
def _list_assistants() -> list[AssistantResponse]:
|
||||
"""List all available assistants from config."""
|
||||
assistants = [_get_default_assistant()]
|
||||
|
||||
# Also include custom agents from config.yaml agents directory
|
||||
try:
|
||||
from deerflow.config.agents_config import list_custom_agents
|
||||
|
||||
for agent_cfg in list_custom_agents():
|
||||
now = datetime.now(UTC).isoformat()
|
||||
assistants.append(
|
||||
AssistantResponse(
|
||||
assistant_id=agent_cfg.name,
|
||||
graph_id="lead_agent", # All agents use the same graph
|
||||
name=agent_cfg.name,
|
||||
config={},
|
||||
metadata={"created_by": "user"},
|
||||
description=agent_cfg.description or "",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
version=1,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Could not load custom agents for assistants list")
|
||||
|
||||
return assistants
|
||||
|
||||
|
||||
@router.post("/search", response_model=list[AssistantResponse])
|
||||
async def search_assistants(body: AssistantSearchRequest | None = None) -> list[AssistantResponse]:
|
||||
"""Search assistants.
|
||||
|
||||
Returns all registered assistants (lead_agent + custom agents from config).
|
||||
"""
|
||||
assistants = _list_assistants()
|
||||
|
||||
if body and body.graph_id:
|
||||
assistants = [a for a in assistants if a.graph_id == body.graph_id]
|
||||
if body and body.name:
|
||||
assistants = [a for a in assistants if body.name.lower() in a.name.lower()]
|
||||
|
||||
offset = body.offset if body else 0
|
||||
limit = body.limit if body else 10
|
||||
return assistants[offset : offset + limit]
|
||||
|
||||
|
||||
@router.get("/{assistant_id}", response_model=AssistantResponse)
|
||||
async def get_assistant_compat(assistant_id: str) -> AssistantResponse:
|
||||
"""Get an assistant by ID."""
|
||||
for a in _list_assistants():
|
||||
if a.assistant_id == assistant_id:
|
||||
return a
|
||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
||||
|
||||
|
||||
@router.get("/{assistant_id}/graph")
|
||||
async def get_assistant_graph(assistant_id: str) -> dict:
|
||||
"""Get the graph structure for an assistant.
|
||||
|
||||
Returns a minimal graph description. Full graph introspection is
|
||||
not supported in the Gateway — this stub satisfies SDK validation.
|
||||
"""
|
||||
found = any(a.assistant_id == assistant_id for a in _list_assistants())
|
||||
if not found:
|
||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
||||
|
||||
return {
|
||||
"graph_id": "lead_agent",
|
||||
"nodes": [],
|
||||
"edges": [],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{assistant_id}/schemas")
|
||||
async def get_assistant_schemas(assistant_id: str) -> dict:
|
||||
"""Get JSON schemas for an assistant's input/output/state.
|
||||
|
||||
Returns empty schemas — full introspection not supported in Gateway.
|
||||
"""
|
||||
found = any(a.assistant_id == assistant_id for a in _list_assistants())
|
||||
if not found:
|
||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
||||
|
||||
return {
|
||||
"graph_id": "lead_agent",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
"state_schema": {},
|
||||
"config_schema": {},
|
||||
}
|
||||
52
deer-flow/backend/app/gateway/routers/channels.py
Normal file
52
deer-flow/backend/app/gateway/routers/channels.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Gateway router for IM channel management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/channels", tags=["channels"])
|
||||
|
||||
|
||||
class ChannelStatusResponse(BaseModel):
|
||||
service_running: bool
|
||||
channels: dict[str, dict]
|
||||
|
||||
|
||||
class ChannelRestartResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@router.get("/", response_model=ChannelStatusResponse)
|
||||
async def get_channels_status() -> ChannelStatusResponse:
|
||||
"""Get the status of all IM channels."""
|
||||
from app.channels.service import get_channel_service
|
||||
|
||||
service = get_channel_service()
|
||||
if service is None:
|
||||
return ChannelStatusResponse(service_running=False, channels={})
|
||||
status = service.get_status()
|
||||
return ChannelStatusResponse(**status)
|
||||
|
||||
|
||||
@router.post("/{name}/restart", response_model=ChannelRestartResponse)
|
||||
async def restart_channel(name: str) -> ChannelRestartResponse:
|
||||
"""Restart a specific IM channel."""
|
||||
from app.channels.service import get_channel_service
|
||||
|
||||
service = get_channel_service()
|
||||
if service is None:
|
||||
raise HTTPException(status_code=503, detail="Channel service is not running")
|
||||
|
||||
success = await service.restart_channel(name)
|
||||
if success:
|
||||
logger.info("Channel %s restarted successfully", name)
|
||||
return ChannelRestartResponse(success=True, message=f"Channel {name} restarted successfully")
|
||||
else:
|
||||
logger.warning("Failed to restart channel %s", name)
|
||||
return ChannelRestartResponse(success=False, message=f"Failed to restart channel {name}")
|
||||
169
deer-flow/backend/app/gateway/routers/mcp.py
Normal file
169
deer-flow/backend/app/gateway/routers/mcp.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, get_extensions_config, reload_extensions_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api", tags=["mcp"])
|
||||
|
||||
|
||||
class McpOAuthConfigResponse(BaseModel):
|
||||
"""OAuth configuration for an MCP server."""
|
||||
|
||||
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
|
||||
token_url: str = Field(default="", description="OAuth token endpoint URL")
|
||||
grant_type: Literal["client_credentials", "refresh_token"] = Field(default="client_credentials", description="OAuth grant type")
|
||||
client_id: str | None = Field(default=None, description="OAuth client ID")
|
||||
client_secret: str | None = Field(default=None, description="OAuth client secret")
|
||||
refresh_token: str | None = Field(default=None, description="OAuth refresh token")
|
||||
scope: str | None = Field(default=None, description="OAuth scope")
|
||||
audience: str | None = Field(default=None, description="OAuth audience")
|
||||
token_field: str = Field(default="access_token", description="Token response field containing access token")
|
||||
token_type_field: str = Field(default="token_type", description="Token response field containing token type")
|
||||
expires_in_field: str = Field(default="expires_in", description="Token response field containing expires-in seconds")
|
||||
default_token_type: str = Field(default="Bearer", description="Default token type when response omits token_type")
|
||||
refresh_skew_seconds: int = Field(default=60, description="Refresh this many seconds before expiry")
|
||||
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
|
||||
|
||||
|
||||
class McpServerConfigResponse(BaseModel):
|
||||
"""Response model for MCP server configuration."""
|
||||
|
||||
enabled: bool = Field(default=True, description="Whether this MCP server is enabled")
|
||||
type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'")
|
||||
command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)")
|
||||
args: list[str] = Field(default_factory=list, description="Arguments to pass to the command (for stdio type)")
|
||||
env: dict[str, str] = Field(default_factory=dict, description="Environment variables for the MCP server")
|
||||
url: str | None = Field(default=None, description="URL of the MCP server (for sse or http type)")
|
||||
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
|
||||
oauth: McpOAuthConfigResponse | None = Field(default=None, description="OAuth configuration for MCP HTTP/SSE servers")
|
||||
description: str = Field(default="", description="Human-readable description of what this MCP server provides")
|
||||
|
||||
|
||||
class McpConfigResponse(BaseModel):
|
||||
"""Response model for MCP configuration."""
|
||||
|
||||
mcp_servers: dict[str, McpServerConfigResponse] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of MCP server name to configuration",
|
||||
)
|
||||
|
||||
|
||||
class McpConfigUpdateRequest(BaseModel):
|
||||
"""Request model for updating MCP configuration."""
|
||||
|
||||
mcp_servers: dict[str, McpServerConfigResponse] = Field(
|
||||
...,
|
||||
description="Map of MCP server name to configuration",
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/mcp/config",
|
||||
response_model=McpConfigResponse,
|
||||
summary="Get MCP Configuration",
|
||||
description="Retrieve the current Model Context Protocol (MCP) server configurations.",
|
||||
)
|
||||
async def get_mcp_configuration() -> McpConfigResponse:
|
||||
"""Get the current MCP configuration.
|
||||
|
||||
Returns:
|
||||
The current MCP configuration with all servers.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"mcp_servers": {
|
||||
"github": {
|
||||
"enabled": true,
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-github"],
|
||||
"env": {"GITHUB_TOKEN": "ghp_xxx"},
|
||||
"description": "GitHub MCP server for repository operations"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
config = get_extensions_config()
|
||||
|
||||
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in config.mcp_servers.items()})
|
||||
|
||||
|
||||
@router.put(
|
||||
"/mcp/config",
|
||||
response_model=McpConfigResponse,
|
||||
summary="Update MCP Configuration",
|
||||
description="Update Model Context Protocol (MCP) server configurations and save to file.",
|
||||
)
|
||||
async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfigResponse:
|
||||
"""Update the MCP configuration.
|
||||
|
||||
This will:
|
||||
1. Save the new configuration to the mcp_config.json file
|
||||
2. Reload the configuration cache
|
||||
3. Reset MCP tools cache to trigger reinitialization
|
||||
|
||||
Args:
|
||||
request: The new MCP configuration to save.
|
||||
|
||||
Returns:
|
||||
The updated MCP configuration.
|
||||
|
||||
Raises:
|
||||
HTTPException: 500 if the configuration file cannot be written.
|
||||
|
||||
Example Request:
|
||||
```json
|
||||
{
|
||||
"mcp_servers": {
|
||||
"github": {
|
||||
"enabled": true,
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-github"],
|
||||
"env": {"GITHUB_TOKEN": "$GITHUB_TOKEN"},
|
||||
"description": "GitHub MCP server for repository operations"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
try:
|
||||
# Get the current config path (or determine where to save it)
|
||||
config_path = ExtensionsConfig.resolve_config_path()
|
||||
|
||||
# If no config file exists, create one in the parent directory (project root)
|
||||
if config_path is None:
|
||||
config_path = Path.cwd().parent / "extensions_config.json"
|
||||
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
|
||||
|
||||
# Load current config to preserve skills configuration
|
||||
current_config = get_extensions_config()
|
||||
|
||||
# Convert request to dict format for JSON serialization
|
||||
config_data = {
|
||||
"mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()},
|
||||
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
|
||||
}
|
||||
|
||||
# Write the configuration to file
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config_data, f, indent=2)
|
||||
|
||||
logger.info(f"MCP configuration updated and saved to: {config_path}")
|
||||
|
||||
# NOTE: No need to reload/reset cache here - LangGraph Server (separate process)
|
||||
# will detect config file changes via mtime and reinitialize MCP tools automatically
|
||||
|
||||
# Reload the configuration and update the global cache
|
||||
reloaded_config = reload_extensions_config()
|
||||
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_config.mcp_servers.items()})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update MCP configuration: {str(e)}")
|
||||
353
deer-flow/backend/app/gateway/routers/memory.py
Normal file
353
deer-flow/backend/app/gateway/routers/memory.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""Memory API router for retrieving and managing global memory data."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.agents.memory.updater import (
|
||||
clear_memory_data,
|
||||
create_memory_fact,
|
||||
delete_memory_fact,
|
||||
get_memory_data,
|
||||
import_memory_data,
|
||||
reload_memory_data,
|
||||
update_memory_fact,
|
||||
)
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["memory"])
|
||||
|
||||
|
||||
class ContextSection(BaseModel):
|
||||
"""Model for context sections (user and history)."""
|
||||
|
||||
summary: str = Field(default="", description="Summary content")
|
||||
updatedAt: str = Field(default="", description="Last update timestamp")
|
||||
|
||||
|
||||
class UserContext(BaseModel):
|
||||
"""Model for user context."""
|
||||
|
||||
workContext: ContextSection = Field(default_factory=ContextSection)
|
||||
personalContext: ContextSection = Field(default_factory=ContextSection)
|
||||
topOfMind: ContextSection = Field(default_factory=ContextSection)
|
||||
|
||||
|
||||
class HistoryContext(BaseModel):
|
||||
"""Model for history context."""
|
||||
|
||||
recentMonths: ContextSection = Field(default_factory=ContextSection)
|
||||
earlierContext: ContextSection = Field(default_factory=ContextSection)
|
||||
longTermBackground: ContextSection = Field(default_factory=ContextSection)
|
||||
|
||||
|
||||
class Fact(BaseModel):
|
||||
"""Model for a memory fact."""
|
||||
|
||||
id: str = Field(..., description="Unique identifier for the fact")
|
||||
content: str = Field(..., description="Fact content")
|
||||
category: str = Field(default="context", description="Fact category")
|
||||
confidence: float = Field(default=0.5, description="Confidence score (0-1)")
|
||||
createdAt: str = Field(default="", description="Creation timestamp")
|
||||
source: str = Field(default="unknown", description="Source thread ID")
|
||||
sourceError: str | None = Field(default=None, description="Optional description of the prior mistake or wrong approach")
|
||||
|
||||
|
||||
class MemoryResponse(BaseModel):
|
||||
"""Response model for memory data."""
|
||||
|
||||
version: str = Field(default="1.0", description="Memory schema version")
|
||||
lastUpdated: str = Field(default="", description="Last update timestamp")
|
||||
user: UserContext = Field(default_factory=UserContext)
|
||||
history: HistoryContext = Field(default_factory=HistoryContext)
|
||||
facts: list[Fact] = Field(default_factory=list)
|
||||
|
||||
|
||||
def _map_memory_fact_value_error(exc: ValueError) -> HTTPException:
|
||||
"""Convert updater validation errors into stable API responses."""
|
||||
if exc.args and exc.args[0] == "confidence":
|
||||
detail = "Invalid confidence value; must be between 0 and 1."
|
||||
else:
|
||||
detail = "Memory fact content cannot be empty."
|
||||
return HTTPException(status_code=400, detail=detail)
|
||||
|
||||
|
||||
class FactCreateRequest(BaseModel):
|
||||
"""Request model for creating a memory fact."""
|
||||
|
||||
content: str = Field(..., min_length=1, description="Fact content")
|
||||
category: str = Field(default="context", description="Fact category")
|
||||
confidence: float = Field(default=0.5, ge=0.0, le=1.0, description="Confidence score (0-1)")
|
||||
|
||||
|
||||
class FactPatchRequest(BaseModel):
|
||||
"""PATCH request model that preserves existing values for omitted fields."""
|
||||
|
||||
content: str | None = Field(default=None, min_length=1, description="Fact content")
|
||||
category: str | None = Field(default=None, description="Fact category")
|
||||
confidence: float | None = Field(default=None, ge=0.0, le=1.0, description="Confidence score (0-1)")
|
||||
|
||||
|
||||
class MemoryConfigResponse(BaseModel):
|
||||
"""Response model for memory configuration."""
|
||||
|
||||
enabled: bool = Field(..., description="Whether memory is enabled")
|
||||
storage_path: str = Field(..., description="Path to memory storage file")
|
||||
debounce_seconds: int = Field(..., description="Debounce time for memory updates")
|
||||
max_facts: int = Field(..., description="Maximum number of facts to store")
|
||||
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
|
||||
injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
|
||||
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
|
||||
|
||||
|
||||
class MemoryStatusResponse(BaseModel):
|
||||
"""Response model for memory status."""
|
||||
|
||||
config: MemoryConfigResponse
|
||||
data: MemoryResponse
|
||||
|
||||
|
||||
@router.get(
|
||||
"/memory",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Get Memory Data",
|
||||
description="Retrieve the current global memory data including user context, history, and facts.",
|
||||
)
|
||||
async def get_memory() -> MemoryResponse:
|
||||
"""Get the current global memory data.
|
||||
|
||||
Returns:
|
||||
The current memory data with user context, history, and facts.
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"version": "1.0",
|
||||
"lastUpdated": "2024-01-15T10:30:00Z",
|
||||
"user": {
|
||||
"workContext": {"summary": "Working on DeerFlow project", "updatedAt": "..."},
|
||||
"personalContext": {"summary": "Prefers concise responses", "updatedAt": "..."},
|
||||
"topOfMind": {"summary": "Building memory API", "updatedAt": "..."}
|
||||
},
|
||||
"history": {
|
||||
"recentMonths": {"summary": "Recent development activities", "updatedAt": "..."},
|
||||
"earlierContext": {"summary": "", "updatedAt": ""},
|
||||
"longTermBackground": {"summary": "", "updatedAt": ""}
|
||||
},
|
||||
"facts": [
|
||||
{
|
||||
"id": "fact_abc123",
|
||||
"content": "User prefers TypeScript over JavaScript",
|
||||
"category": "preference",
|
||||
"confidence": 0.9,
|
||||
"createdAt": "2024-01-15T10:30:00Z",
|
||||
"source": "thread_xyz"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
"""
|
||||
memory_data = get_memory_data()
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/memory/reload",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Reload Memory Data",
|
||||
description="Reload memory data from the storage file, refreshing the in-memory cache.",
|
||||
)
|
||||
async def reload_memory() -> MemoryResponse:
|
||||
"""Reload memory data from file.
|
||||
|
||||
This forces a reload of the memory data from the storage file,
|
||||
useful when the file has been modified externally.
|
||||
|
||||
Returns:
|
||||
The reloaded memory data.
|
||||
"""
|
||||
memory_data = reload_memory_data()
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/memory",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Clear All Memory Data",
|
||||
description="Delete all saved memory data and reset the memory structure to an empty state.",
|
||||
)
|
||||
async def clear_memory() -> MemoryResponse:
|
||||
"""Clear all persisted memory data."""
|
||||
try:
|
||||
memory_data = clear_memory_data()
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
|
||||
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/memory/facts",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Create Memory Fact",
|
||||
description="Create a single saved memory fact manually.",
|
||||
)
|
||||
async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryResponse:
|
||||
"""Create a single fact manually."""
|
||||
try:
|
||||
memory_data = create_memory_fact(
|
||||
content=request.content,
|
||||
category=request.category,
|
||||
confidence=request.confidence,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise _map_memory_fact_value_error(exc) from exc
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
|
||||
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/memory/facts/{fact_id}",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Delete Memory Fact",
|
||||
description="Delete a single saved memory fact by its fact id.",
|
||||
)
|
||||
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
|
||||
"""Delete a single fact from memory by fact id."""
|
||||
try:
|
||||
memory_data = delete_memory_fact(fact_id)
|
||||
except KeyError as exc:
|
||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
|
||||
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/memory/facts/{fact_id}",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Patch Memory Fact",
|
||||
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
|
||||
)
|
||||
async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -> MemoryResponse:
|
||||
"""Partially update a single fact manually."""
|
||||
try:
|
||||
memory_data = update_memory_fact(
|
||||
fact_id=fact_id,
|
||||
content=request.content,
|
||||
category=request.category,
|
||||
confidence=request.confidence,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise _map_memory_fact_value_error(exc) from exc
|
||||
except KeyError as exc:
|
||||
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
|
||||
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/memory/export",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Export Memory Data",
|
||||
description="Export the current global memory data as JSON for backup or transfer.",
|
||||
)
|
||||
async def export_memory() -> MemoryResponse:
|
||||
"""Export the current memory data."""
|
||||
memory_data = get_memory_data()
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/memory/import",
|
||||
response_model=MemoryResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Import Memory Data",
|
||||
description="Import and overwrite the current global memory data from a JSON payload.",
|
||||
)
|
||||
async def import_memory(request: MemoryResponse) -> MemoryResponse:
|
||||
"""Import and persist memory data."""
|
||||
try:
|
||||
memory_data = import_memory_data(request.model_dump())
|
||||
except OSError as exc:
|
||||
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
|
||||
|
||||
return MemoryResponse(**memory_data)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/memory/config",
|
||||
response_model=MemoryConfigResponse,
|
||||
summary="Get Memory Configuration",
|
||||
description="Retrieve the current memory system configuration.",
|
||||
)
|
||||
async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
||||
"""Get the memory system configuration.
|
||||
|
||||
Returns:
|
||||
The current memory configuration settings.
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"enabled": true,
|
||||
"storage_path": ".deer-flow/memory.json",
|
||||
"debounce_seconds": 30,
|
||||
"max_facts": 100,
|
||||
"fact_confidence_threshold": 0.7,
|
||||
"injection_enabled": true,
|
||||
"max_injection_tokens": 2000
|
||||
}
|
||||
```
|
||||
"""
|
||||
config = get_memory_config()
|
||||
return MemoryConfigResponse(
|
||||
enabled=config.enabled,
|
||||
storage_path=config.storage_path,
|
||||
debounce_seconds=config.debounce_seconds,
|
||||
max_facts=config.max_facts,
|
||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||
injection_enabled=config.injection_enabled,
|
||||
max_injection_tokens=config.max_injection_tokens,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/memory/status",
|
||||
response_model=MemoryStatusResponse,
|
||||
response_model_exclude_none=True,
|
||||
summary="Get Memory Status",
|
||||
description="Retrieve both memory configuration and current data in a single request.",
|
||||
)
|
||||
async def get_memory_status() -> MemoryStatusResponse:
|
||||
"""Get the memory system status including configuration and data.
|
||||
|
||||
Returns:
|
||||
Combined memory configuration and current data.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
memory_data = get_memory_data()
|
||||
|
||||
return MemoryStatusResponse(
|
||||
config=MemoryConfigResponse(
|
||||
enabled=config.enabled,
|
||||
storage_path=config.storage_path,
|
||||
debounce_seconds=config.debounce_seconds,
|
||||
max_facts=config.max_facts,
|
||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||
injection_enabled=config.injection_enabled,
|
||||
max_injection_tokens=config.max_injection_tokens,
|
||||
),
|
||||
data=MemoryResponse(**memory_data),
|
||||
)
|
||||
116
deer-flow/backend/app/gateway/routers/models.py
Normal file
116
deer-flow/backend/app/gateway/routers/models.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["models"])
|
||||
|
||||
|
||||
class ModelResponse(BaseModel):
|
||||
"""Response model for model information."""
|
||||
|
||||
name: str = Field(..., description="Unique identifier for the model")
|
||||
model: str = Field(..., description="Actual provider model identifier")
|
||||
display_name: str | None = Field(None, description="Human-readable name")
|
||||
description: str | None = Field(None, description="Model description")
|
||||
supports_thinking: bool = Field(default=False, description="Whether model supports thinking mode")
|
||||
supports_reasoning_effort: bool = Field(default=False, description="Whether model supports reasoning effort")
|
||||
|
||||
|
||||
class ModelsListResponse(BaseModel):
|
||||
"""Response model for listing all models."""
|
||||
|
||||
models: list[ModelResponse]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/models",
|
||||
response_model=ModelsListResponse,
|
||||
summary="List All Models",
|
||||
description="Retrieve a list of all available AI models configured in the system.",
|
||||
)
|
||||
async def list_models() -> ModelsListResponse:
|
||||
"""List all available models from configuration.
|
||||
|
||||
Returns model information suitable for frontend display,
|
||||
excluding sensitive fields like API keys and internal configuration.
|
||||
|
||||
Returns:
|
||||
A list of all configured models with their metadata.
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"display_name": "GPT-4",
|
||||
"description": "OpenAI GPT-4 model",
|
||||
"supports_thinking": false
|
||||
},
|
||||
{
|
||||
"name": "claude-3-opus",
|
||||
"display_name": "Claude 3 Opus",
|
||||
"description": "Anthropic Claude 3 Opus model",
|
||||
"supports_thinking": true
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
"""
|
||||
config = get_app_config()
|
||||
models = [
|
||||
ModelResponse(
|
||||
name=model.name,
|
||||
model=model.model,
|
||||
display_name=model.display_name,
|
||||
description=model.description,
|
||||
supports_thinking=model.supports_thinking,
|
||||
supports_reasoning_effort=model.supports_reasoning_effort,
|
||||
)
|
||||
for model in config.models
|
||||
]
|
||||
return ModelsListResponse(models=models)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/models/{model_name}",
|
||||
response_model=ModelResponse,
|
||||
summary="Get Model Details",
|
||||
description="Retrieve detailed information about a specific AI model by its name.",
|
||||
)
|
||||
async def get_model(model_name: str) -> ModelResponse:
|
||||
"""Get a specific model by name.
|
||||
|
||||
Args:
|
||||
model_name: The unique name of the model to retrieve.
|
||||
|
||||
Returns:
|
||||
Model information if found.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if model not found.
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"display_name": "GPT-4",
|
||||
"description": "OpenAI GPT-4 model",
|
||||
"supports_thinking": false
|
||||
}
|
||||
```
|
||||
"""
|
||||
config = get_app_config()
|
||||
model = config.get_model_config(model_name)
|
||||
if model is None:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
return ModelResponse(
|
||||
name=model.name,
|
||||
model=model.model,
|
||||
display_name=model.display_name,
|
||||
description=model.description,
|
||||
supports_thinking=model.supports_thinking,
|
||||
supports_reasoning_effort=model.supports_reasoning_effort,
|
||||
)
|
||||
87
deer-flow/backend/app/gateway/routers/runs.py
Normal file
87
deer-flow/backend/app/gateway/routers/runs.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Stateless runs endpoints -- stream and wait without a pre-existing thread.
|
||||
|
||||
These endpoints auto-create a temporary thread when no ``thread_id`` is
|
||||
supplied in the request body. When a ``thread_id`` **is** provided, it
|
||||
is reused so that conversation history is preserved across calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
|
||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/runs", tags=["runs"])
|
||||
|
||||
|
||||
def _resolve_thread_id(body: RunCreateRequest) -> str:
|
||||
"""Return the thread_id from the request body, or generate a new one."""
|
||||
thread_id = (body.config or {}).get("configurable", {}).get("thread_id")
|
||||
if thread_id:
|
||||
return str(thread_id)
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@router.post("/stream")
|
||||
async def stateless_stream(body: RunCreateRequest, request: Request) -> StreamingResponse:
|
||||
"""Create a run and stream events via SSE.
|
||||
|
||||
If ``config.configurable.thread_id`` is provided, the run is created
|
||||
on the given thread so that conversation history is preserved.
|
||||
Otherwise a new temporary thread is created.
|
||||
"""
|
||||
thread_id = _resolve_thread_id(body)
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
return StreamingResponse(
|
||||
sse_consumer(bridge, record, request, run_mgr),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/wait", response_model=dict)
|
||||
async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
||||
"""Create a run and block until completion.
|
||||
|
||||
If ``config.configurable.thread_id`` is provided, the run is created
|
||||
on the given thread so that conversation history is preserved.
|
||||
Otherwise a new temporary thread is created.
|
||||
"""
|
||||
thread_id = _resolve_thread_id(body)
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
if record.task is not None:
|
||||
try:
|
||||
await record.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
|
||||
return {"status": record.status.value, "error": record.error}
|
||||
356
deer-flow/backend/app/gateway/routers/skills.py
Normal file
356
deer-flow/backend/app/gateway/routers/skills.py
Normal file
@@ -0,0 +1,356 @@
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.path_utils import resolve_thread_virtual_path
|
||||
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
|
||||
from deerflow.skills import Skill, load_skills
|
||||
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
|
||||
from deerflow.skills.manager import (
|
||||
append_history,
|
||||
atomic_write,
|
||||
custom_skill_exists,
|
||||
ensure_custom_skill_is_editable,
|
||||
get_custom_skill_dir,
|
||||
get_custom_skill_file,
|
||||
get_skill_history_file,
|
||||
read_custom_skill_content,
|
||||
read_history,
|
||||
validate_skill_markdown_content,
|
||||
)
|
||||
from deerflow.skills.security_scanner import scan_skill_content
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["skills"])
|
||||
|
||||
|
||||
class SkillResponse(BaseModel):
|
||||
"""Response model for skill information."""
|
||||
|
||||
name: str = Field(..., description="Name of the skill")
|
||||
description: str = Field(..., description="Description of what the skill does")
|
||||
license: str | None = Field(None, description="License information")
|
||||
category: str = Field(..., description="Category of the skill (public or custom)")
|
||||
enabled: bool = Field(default=True, description="Whether this skill is enabled")
|
||||
|
||||
|
||||
class SkillsListResponse(BaseModel):
|
||||
"""Response model for listing all skills."""
|
||||
|
||||
skills: list[SkillResponse]
|
||||
|
||||
|
||||
class SkillUpdateRequest(BaseModel):
|
||||
"""Request model for updating a skill."""
|
||||
|
||||
enabled: bool = Field(..., description="Whether to enable or disable the skill")
|
||||
|
||||
|
||||
class SkillInstallRequest(BaseModel):
|
||||
"""Request model for installing a skill from a .skill file."""
|
||||
|
||||
thread_id: str = Field(..., description="The thread ID where the .skill file is located")
|
||||
path: str = Field(..., description="Virtual path to the .skill file (e.g., mnt/user-data/outputs/my-skill.skill)")
|
||||
|
||||
|
||||
class SkillInstallResponse(BaseModel):
|
||||
"""Response model for skill installation."""
|
||||
|
||||
success: bool = Field(..., description="Whether the installation was successful")
|
||||
skill_name: str = Field(..., description="Name of the installed skill")
|
||||
message: str = Field(..., description="Installation result message")
|
||||
|
||||
|
||||
class CustomSkillContentResponse(SkillResponse):
|
||||
content: str = Field(..., description="Raw SKILL.md content")
|
||||
|
||||
|
||||
class CustomSkillUpdateRequest(BaseModel):
|
||||
content: str = Field(..., description="Replacement SKILL.md content")
|
||||
|
||||
|
||||
class CustomSkillHistoryResponse(BaseModel):
|
||||
history: list[dict]
|
||||
|
||||
|
||||
class SkillRollbackRequest(BaseModel):
|
||||
history_index: int = Field(default=-1, description="History entry index to restore from, defaulting to the latest change.")
|
||||
|
||||
|
||||
def _skill_to_response(skill: Skill) -> SkillResponse:
|
||||
"""Convert a Skill object to a SkillResponse."""
|
||||
return SkillResponse(
|
||||
name=skill.name,
|
||||
description=skill.description,
|
||||
license=skill.license,
|
||||
category=skill.category,
|
||||
enabled=skill.enabled,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/skills",
|
||||
response_model=SkillsListResponse,
|
||||
summary="List All Skills",
|
||||
description="Retrieve a list of all available skills from both public and custom directories.",
|
||||
)
|
||||
async def list_skills() -> SkillsListResponse:
|
||||
try:
|
||||
skills = load_skills(enabled_only=False)
|
||||
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load skills: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to load skills: {str(e)}")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/skills/install",
|
||||
response_model=SkillInstallResponse,
|
||||
summary="Install Skill",
|
||||
description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.",
|
||||
)
|
||||
async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
|
||||
try:
|
||||
skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
|
||||
result = install_skill_from_archive(skill_file_path)
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
return SkillInstallResponse(**result)
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except SkillAlreadyExistsError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to install skill: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to install skill: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
|
||||
async def list_custom_skills() -> SkillsListResponse:
|
||||
try:
|
||||
skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
|
||||
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
|
||||
except Exception as e:
|
||||
logger.error("Failed to list custom skills: %s", e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to list custom skills: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
|
||||
async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
|
||||
try:
|
||||
skills = load_skills(enabled_only=False)
|
||||
skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
|
||||
if skill is None:
|
||||
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
||||
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to get custom skill %s: %s", skill_name, e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get custom skill: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
|
||||
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse:
|
||||
try:
|
||||
ensure_custom_skill_is_editable(skill_name)
|
||||
validate_skill_markdown_content(skill_name, request.content)
|
||||
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md")
|
||||
if scan.decision == "block":
|
||||
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
|
||||
skill_file = get_custom_skill_dir(skill_name) / "SKILL.md"
|
||||
prev_content = skill_file.read_text(encoding="utf-8")
|
||||
atomic_write(skill_file, request.content)
|
||||
append_history(
|
||||
skill_name,
|
||||
{
|
||||
"action": "human_edit",
|
||||
"author": "human",
|
||||
"thread_id": None,
|
||||
"file_path": "SKILL.md",
|
||||
"prev_content": prev_content,
|
||||
"new_content": request.content,
|
||||
"scanner": {"decision": scan.decision, "reason": scan.reason},
|
||||
},
|
||||
)
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
return await get_custom_skill(skill_name)
|
||||
except HTTPException:
|
||||
raise
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error("Failed to update custom skill %s: %s", skill_name, e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update custom skill: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
|
||||
async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
|
||||
try:
|
||||
ensure_custom_skill_is_editable(skill_name)
|
||||
skill_dir = get_custom_skill_dir(skill_name)
|
||||
prev_content = read_custom_skill_content(skill_name)
|
||||
append_history(
|
||||
skill_name,
|
||||
{
|
||||
"action": "human_delete",
|
||||
"author": "human",
|
||||
"thread_id": None,
|
||||
"file_path": "SKILL.md",
|
||||
"prev_content": prev_content,
|
||||
"new_content": None,
|
||||
"scanner": {"decision": "allow", "reason": "Deletion requested."},
|
||||
},
|
||||
)
|
||||
shutil.rmtree(skill_dir)
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
return {"success": True}
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error("Failed to delete custom skill %s: %s", skill_name, e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete custom skill: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
|
||||
async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse:
|
||||
try:
|
||||
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
|
||||
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
||||
return CustomSkillHistoryResponse(history=read_history(skill_name))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to read history for %s: %s", skill_name, e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to read history: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
|
||||
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse:
|
||||
try:
|
||||
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
|
||||
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
|
||||
history = read_history(skill_name)
|
||||
if not history:
|
||||
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
|
||||
record = history[request.history_index]
|
||||
target_content = record.get("prev_content")
|
||||
if target_content is None:
|
||||
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
|
||||
validate_skill_markdown_content(skill_name, target_content)
|
||||
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md")
|
||||
skill_file = get_custom_skill_file(skill_name)
|
||||
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
|
||||
history_entry = {
|
||||
"action": "rollback",
|
||||
"author": "human",
|
||||
"thread_id": None,
|
||||
"file_path": "SKILL.md",
|
||||
"prev_content": current_content,
|
||||
"new_content": target_content,
|
||||
"rollback_from_ts": record.get("ts"),
|
||||
"scanner": {"decision": scan.decision, "reason": scan.reason},
|
||||
}
|
||||
if scan.decision == "block":
|
||||
append_history(skill_name, history_entry)
|
||||
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
|
||||
atomic_write(skill_file, target_content)
|
||||
append_history(skill_name, history_entry)
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
return await get_custom_skill(skill_name)
|
||||
except HTTPException:
|
||||
raise
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=400, detail="history_index is out of range")
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error("Failed to roll back custom skill %s: %s", skill_name, e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to roll back custom skill: {str(e)}")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/skills/{skill_name}",
|
||||
response_model=SkillResponse,
|
||||
summary="Get Skill Details",
|
||||
description="Retrieve detailed information about a specific skill by its name.",
|
||||
)
|
||||
async def get_skill(skill_name: str) -> SkillResponse:
|
||||
try:
|
||||
skills = load_skills(enabled_only=False)
|
||||
skill = next((s for s in skills if s.name == skill_name), None)
|
||||
|
||||
if skill is None:
|
||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||
|
||||
return _skill_to_response(skill)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get skill {skill_name}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get skill: {str(e)}")
|
||||
|
||||
|
||||
@router.put(
|
||||
"/skills/{skill_name}",
|
||||
response_model=SkillResponse,
|
||||
summary="Update Skill",
|
||||
description="Update a skill's enabled status by modifying the extensions_config.json file.",
|
||||
)
|
||||
async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillResponse:
|
||||
try:
|
||||
skills = load_skills(enabled_only=False)
|
||||
skill = next((s for s in skills if s.name == skill_name), None)
|
||||
|
||||
if skill is None:
|
||||
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
|
||||
|
||||
config_path = ExtensionsConfig.resolve_config_path()
|
||||
if config_path is None:
|
||||
config_path = Path.cwd().parent / "extensions_config.json"
|
||||
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
|
||||
|
||||
extensions_config = get_extensions_config()
|
||||
extensions_config.skills[skill_name] = SkillStateConfig(enabled=request.enabled)
|
||||
|
||||
config_data = {
|
||||
"mcpServers": {name: server.model_dump() for name, server in extensions_config.mcp_servers.items()},
|
||||
"skills": {name: {"enabled": skill_config.enabled} for name, skill_config in extensions_config.skills.items()},
|
||||
}
|
||||
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config_data, f, indent=2)
|
||||
|
||||
logger.info(f"Skills configuration updated and saved to: {config_path}")
|
||||
reload_extensions_config()
|
||||
await refresh_skills_system_prompt_cache_async()
|
||||
|
||||
skills = load_skills(enabled_only=False)
|
||||
updated_skill = next((s for s in skills if s.name == skill_name), None)
|
||||
|
||||
if updated_skill is None:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to reload skill '{skill_name}' after update")
|
||||
|
||||
logger.info(f"Skill '{skill_name}' enabled status updated to {request.enabled}")
|
||||
return _skill_to_response(updated_skill)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update skill {skill_name}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update skill: {str(e)}")
|
||||
132
deer-flow/backend/app/gateway/routers/suggestions.py
Normal file
132
deer-flow/backend/app/gateway/routers/suggestions.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["suggestions"])
|
||||
|
||||
|
||||
class SuggestionMessage(BaseModel):
|
||||
role: str = Field(..., description="Message role: user|assistant")
|
||||
content: str = Field(..., description="Message content as plain text")
|
||||
|
||||
|
||||
class SuggestionsRequest(BaseModel):
|
||||
messages: list[SuggestionMessage] = Field(..., description="Recent conversation messages")
|
||||
n: int = Field(default=3, ge=1, le=5, description="Number of suggestions to generate")
|
||||
model_name: str | None = Field(default=None, description="Optional model override")
|
||||
|
||||
|
||||
class SuggestionsResponse(BaseModel):
|
||||
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
|
||||
|
||||
|
||||
def _strip_markdown_code_fence(text: str) -> str:
|
||||
stripped = text.strip()
|
||||
if not stripped.startswith("```"):
|
||||
return stripped
|
||||
lines = stripped.splitlines()
|
||||
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
|
||||
return "\n".join(lines[1:-1]).strip()
|
||||
return stripped
|
||||
|
||||
|
||||
def _parse_json_string_list(text: str) -> list[str] | None:
|
||||
candidate = _strip_markdown_code_fence(text)
|
||||
start = candidate.find("[")
|
||||
end = candidate.rfind("]")
|
||||
if start == -1 or end == -1 or end <= start:
|
||||
return None
|
||||
candidate = candidate[start : end + 1]
|
||||
try:
|
||||
data = json.loads(candidate)
|
||||
except Exception:
|
||||
return None
|
||||
if not isinstance(data, list):
|
||||
return None
|
||||
out: list[str] = []
|
||||
for item in data:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
s = item.strip()
|
||||
if not s:
|
||||
continue
|
||||
out.append(s)
|
||||
return out
|
||||
|
||||
|
||||
def _extract_response_text(content: object) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
parts.append(block)
|
||||
elif isinstance(block, dict) and block.get("type") in {"text", "output_text"}:
|
||||
text = block.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
return "\n".join(parts) if parts else ""
|
||||
if content is None:
|
||||
return ""
|
||||
return str(content)
|
||||
|
||||
|
||||
def _format_conversation(messages: list[SuggestionMessage]) -> str:
|
||||
parts: list[str] = []
|
||||
for m in messages:
|
||||
role = m.role.strip().lower()
|
||||
if role in ("user", "human"):
|
||||
parts.append(f"User: {m.content.strip()}")
|
||||
elif role in ("assistant", "ai"):
|
||||
parts.append(f"Assistant: {m.content.strip()}")
|
||||
else:
|
||||
parts.append(f"{m.role}: {m.content.strip()}")
|
||||
return "\n".join(parts).strip()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/threads/{thread_id}/suggestions",
|
||||
response_model=SuggestionsResponse,
|
||||
summary="Generate Follow-up Questions",
|
||||
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
|
||||
)
|
||||
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
|
||||
if not request.messages:
|
||||
return SuggestionsResponse(suggestions=[])
|
||||
|
||||
n = request.n
|
||||
conversation = _format_conversation(request.messages)
|
||||
if not conversation:
|
||||
return SuggestionsResponse(suggestions=[])
|
||||
|
||||
system_instruction = (
|
||||
"You are generating follow-up questions to help the user continue the conversation.\n"
|
||||
f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
|
||||
"Requirements:\n"
|
||||
"- Questions must be relevant to the preceding conversation.\n"
|
||||
"- Questions must be written in the same language as the user.\n"
|
||||
"- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
|
||||
"- Do NOT include numbering, markdown, or any extra text.\n"
|
||||
"- Output MUST be a JSON array of strings only.\n"
|
||||
)
|
||||
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
|
||||
|
||||
try:
|
||||
model = create_chat_model(name=request.model_name, thinking_enabled=False)
|
||||
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
|
||||
raw = _extract_response_text(response.content)
|
||||
suggestions = _parse_json_string_list(raw) or []
|
||||
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
|
||||
cleaned = cleaned[:n]
|
||||
return SuggestionsResponse(suggestions=cleaned)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to generate suggestions: thread_id=%s err=%s", thread_id, exc)
|
||||
return SuggestionsResponse(suggestions=[])
|
||||
267
deer-flow/backend/app/gateway/routers/thread_runs.py
Normal file
267
deer-flow/backend/app/gateway/routers/thread_runs.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Runs endpoints — create, stream, wait, cancel.
|
||||
|
||||
Implements the LangGraph Platform runs API on top of
|
||||
:class:`deerflow.agents.runs.RunManager` and
|
||||
:class:`deerflow.agents.stream_bridge.StreamBridge`.
|
||||
|
||||
SSE format is aligned with the LangGraph Platform protocol so that
|
||||
the ``useStream`` React hook from ``@langchain/langgraph-sdk/react``
|
||||
works without modification.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from deerflow.runtime import RunRecord, serialize_channel_values
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/threads", tags=["runs"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RunCreateRequest(BaseModel):
|
||||
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
|
||||
input: dict[str, Any] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
|
||||
command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
|
||||
metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
|
||||
config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
|
||||
context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
|
||||
webhook: str | None = Field(default=None, description="Completion callback URL")
|
||||
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
|
||||
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
||||
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
|
||||
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
|
||||
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
|
||||
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
|
||||
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
|
||||
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
|
||||
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
|
||||
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
|
||||
after_seconds: float | None = Field(default=None, description="Delayed execution")
|
||||
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
|
||||
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
||||
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
run_id: str
|
||||
thread_id: str
|
||||
assistant_id: str | None = None
|
||||
status: str
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
multitask_strategy: str = "reject"
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _record_to_response(record: RunRecord) -> RunResponse:
|
||||
return RunResponse(
|
||||
run_id=record.run_id,
|
||||
thread_id=record.thread_id,
|
||||
assistant_id=record.assistant_id,
|
||||
status=record.status.value,
|
||||
metadata=record.metadata,
|
||||
kwargs=record.kwargs,
|
||||
multitask_strategy=record.multitask_strategy,
|
||||
created_at=record.created_at,
|
||||
updated_at=record.updated_at,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs", response_model=RunResponse)
|
||||
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
|
||||
"""Create a background run (returns immediately)."""
|
||||
record = await start_run(body, thread_id, request)
|
||||
return _record_to_response(record)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/stream")
|
||||
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
|
||||
"""Create a run and stream events via SSE.
|
||||
|
||||
The response includes a ``Content-Location`` header with the run's
|
||||
resource URL, matching the LangGraph Platform protocol. The
|
||||
``useStream`` React hook uses this to extract run metadata.
|
||||
"""
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
return StreamingResponse(
|
||||
sse_consumer(bridge, record, request, run_mgr),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
# LangGraph Platform includes run metadata in this header.
|
||||
# The SDK uses a greedy regex to extract the run id from this path,
|
||||
# so it must point at the canonical run resource without extra suffixes.
|
||||
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/wait", response_model=dict)
|
||||
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
|
||||
"""Create a run and block until it completes, returning the final state."""
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
if record.task is not None:
|
||||
try:
|
||||
await record.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
|
||||
return {"status": record.status.value, "error": record.error}
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
|
||||
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
|
||||
"""List all runs for a thread."""
|
||||
run_mgr = get_run_manager(request)
|
||||
records = await run_mgr.list_by_thread(thread_id)
|
||||
return [_record_to_response(r) for r in records]
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
|
||||
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
|
||||
"""Get details of a specific run."""
|
||||
run_mgr = get_run_manager(request)
|
||||
record = run_mgr.get(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
return _record_to_response(record)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}/cancel")
|
||||
async def cancel_run(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
wait: bool = Query(default=False, description="Block until run completes after cancel"),
|
||||
action: Literal["interrupt", "rollback"] = Query(default="interrupt", description="Cancel action"),
|
||||
) -> Response:
|
||||
"""Cancel a running or pending run.
|
||||
|
||||
- action=interrupt: Stop execution, keep current checkpoint (can be resumed)
|
||||
- action=rollback: Stop execution, revert to pre-run checkpoint state
|
||||
- wait=true: Block until the run fully stops, return 204
|
||||
- wait=false: Return immediately with 202
|
||||
"""
|
||||
run_mgr = get_run_manager(request)
|
||||
record = run_mgr.get(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
|
||||
cancelled = await run_mgr.cancel(run_id, action=action)
|
||||
if not cancelled:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Run {run_id} is not cancellable (status: {record.status.value})",
|
||||
)
|
||||
|
||||
if wait and record.task is not None:
|
||||
try:
|
||||
await record.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
return Response(status_code=204)
|
||||
|
||||
return Response(status_code=202)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/join")
|
||||
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
|
||||
"""Join an existing run's SSE stream."""
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
record = run_mgr.get(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
|
||||
return StreamingResponse(
|
||||
sse_consumer(bridge, record, request, run_mgr),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
|
||||
async def stream_existing_run(
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
request: Request,
|
||||
action: Literal["interrupt", "rollback"] | None = Query(default=None, description="Cancel action"),
|
||||
wait: int = Query(default=0, description="Block until cancelled (1) or return immediately (0)"),
|
||||
):
|
||||
"""Join an existing run's SSE stream (GET), or cancel-then-stream (POST).
|
||||
|
||||
The LangGraph SDK's ``joinStream`` and ``useStream`` stop button both use
|
||||
``POST`` to this endpoint. When ``action=interrupt`` or ``action=rollback``
|
||||
is present the run is cancelled first; the response then streams any
|
||||
remaining buffered events so the client observes a clean shutdown.
|
||||
"""
|
||||
run_mgr = get_run_manager(request)
|
||||
record = run_mgr.get(run_id)
|
||||
if record is None or record.thread_id != thread_id:
|
||||
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
|
||||
|
||||
# Cancel if an action was requested (stop-button / interrupt flow)
|
||||
if action is not None:
|
||||
cancelled = await run_mgr.cancel(run_id, action=action)
|
||||
if cancelled and wait and record.task is not None:
|
||||
try:
|
||||
await record.task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
return Response(status_code=204)
|
||||
|
||||
bridge = get_stream_bridge(request)
|
||||
return StreamingResponse(
|
||||
sse_consumer(bridge, record, request, run_mgr),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
682
deer-flow/backend/app/gateway/routers/threads.py
Normal file
682
deer-flow/backend/app/gateway/routers/threads.py
Normal file
@@ -0,0 +1,682 @@
|
||||
"""Thread CRUD, state, and history endpoints.
|
||||
|
||||
Combines the existing thread-local filesystem cleanup with LangGraph
|
||||
Platform-compatible thread management backed by the checkpointer.
|
||||
|
||||
Channel values returned in state responses are serialized through
|
||||
:func:`deerflow.runtime.serialization.serialize_channel_values` to
|
||||
ensure LangChain message objects are converted to JSON-safe dicts
|
||||
matching the LangGraph Platform wire format expected by the
|
||||
``useStream`` React hook.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.deps import get_checkpointer, get_store
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Store namespace
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
THREADS_NS: tuple[str, ...] = ("threads",)
|
||||
"""Namespace used by the Store for thread metadata records."""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/threads", tags=["threads"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response / request models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ThreadDeleteResponse(BaseModel):
|
||||
"""Response model for thread cleanup."""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class ThreadResponse(BaseModel):
|
||||
"""Response model for a single thread."""
|
||||
|
||||
thread_id: str = Field(description="Unique thread identifier")
|
||||
status: str = Field(default="idle", description="Thread status: idle, busy, interrupted, error")
|
||||
created_at: str = Field(default="", description="ISO timestamp")
|
||||
updated_at: str = Field(default="", description="ISO timestamp")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
|
||||
values: dict[str, Any] = Field(default_factory=dict, description="Current state channel values")
|
||||
interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
|
||||
|
||||
|
||||
class ThreadCreateRequest(BaseModel):
|
||||
"""Request body for creating a thread."""
|
||||
|
||||
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
|
||||
|
||||
|
||||
class ThreadSearchRequest(BaseModel):
|
||||
"""Request body for searching threads."""
|
||||
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
|
||||
limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
|
||||
offset: int = Field(default=0, ge=0, description="Pagination offset")
|
||||
status: str | None = Field(default=None, description="Filter by thread status")
|
||||
|
||||
|
||||
class ThreadStateResponse(BaseModel):
|
||||
"""Response model for thread state."""
|
||||
|
||||
values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
|
||||
next: list[str] = Field(default_factory=list, description="Next tasks to execute")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
|
||||
checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
|
||||
checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
|
||||
parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
|
||||
created_at: str | None = Field(default=None, description="Checkpoint timestamp")
|
||||
tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
|
||||
|
||||
|
||||
class ThreadPatchRequest(BaseModel):
|
||||
"""Request body for patching thread metadata."""
|
||||
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge")
|
||||
|
||||
|
||||
class ThreadStateUpdateRequest(BaseModel):
|
||||
"""Request body for updating thread state (human-in-the-loop resume)."""
|
||||
|
||||
values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
|
||||
checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
|
||||
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
||||
as_node: str | None = Field(default=None, description="Node identity for the update")
|
||||
|
||||
|
||||
class HistoryEntry(BaseModel):
|
||||
"""Single checkpoint history entry."""
|
||||
|
||||
checkpoint_id: str
|
||||
parent_checkpoint_id: str | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
values: dict[str, Any] = Field(default_factory=dict)
|
||||
created_at: str | None = None
|
||||
next: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ThreadHistoryRequest(BaseModel):
|
||||
"""Request body for checkpoint history."""
|
||||
|
||||
limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
|
||||
before: str | None = Field(default=None, description="Cursor for pagination")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
|
||||
"""Delete local persisted filesystem data for a thread."""
|
||||
path_manager = paths or get_paths()
|
||||
try:
|
||||
path_manager.delete_thread_dir(thread_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
except FileNotFoundError:
|
||||
# Not critical — thread data may not exist on disk
|
||||
logger.debug("No local thread data to delete for %s", thread_id)
|
||||
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to delete thread data for %s", thread_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
|
||||
|
||||
logger.info("Deleted local thread data for %s", thread_id)
|
||||
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
|
||||
|
||||
|
||||
async def _store_get(store, thread_id: str) -> dict | None:
|
||||
"""Fetch a thread record from the Store; returns ``None`` if absent."""
|
||||
item = await store.aget(THREADS_NS, thread_id)
|
||||
return item.value if item is not None else None
|
||||
|
||||
|
||||
async def _store_put(store, record: dict) -> None:
|
||||
"""Write a thread record to the Store."""
|
||||
await store.aput(THREADS_NS, record["thread_id"], record)
|
||||
|
||||
|
||||
async def _store_upsert(store, thread_id: str, *, metadata: dict | None = None, values: dict | None = None) -> None:
|
||||
"""Create or refresh a thread record in the Store.
|
||||
|
||||
On creation the record is written with ``status="idle"``. On update only
|
||||
``updated_at`` (and optionally ``metadata`` / ``values``) are changed so
|
||||
that existing fields are preserved.
|
||||
|
||||
``values`` carries the agent-state snapshot exposed to the frontend
|
||||
(currently just ``{"title": "..."}``).
|
||||
"""
|
||||
now = time.time()
|
||||
existing = await _store_get(store, thread_id)
|
||||
if existing is None:
|
||||
await _store_put(
|
||||
store,
|
||||
{
|
||||
"thread_id": thread_id,
|
||||
"status": "idle",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"metadata": metadata or {},
|
||||
"values": values or {},
|
||||
},
|
||||
)
|
||||
else:
|
||||
val = dict(existing)
|
||||
val["updated_at"] = now
|
||||
if metadata:
|
||||
val.setdefault("metadata", {}).update(metadata)
|
||||
if values:
|
||||
val.setdefault("values", {}).update(values)
|
||||
await _store_put(store, val)
|
||||
|
||||
|
||||
def _derive_thread_status(checkpoint_tuple) -> str:
|
||||
"""Derive thread status from checkpoint metadata."""
|
||||
if checkpoint_tuple is None:
|
||||
return "idle"
|
||||
pending_writes = getattr(checkpoint_tuple, "pending_writes", None) or []
|
||||
|
||||
# Check for error in pending writes
|
||||
for pw in pending_writes:
|
||||
if len(pw) >= 2 and pw[1] == "__error__":
|
||||
return "error"
|
||||
|
||||
# Check for pending next tasks (indicates interrupt)
|
||||
tasks = getattr(checkpoint_tuple, "tasks", None)
|
||||
if tasks:
|
||||
return "interrupted"
|
||||
|
||||
return "idle"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
|
||||
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
|
||||
"""Delete local persisted filesystem data for a thread.
|
||||
|
||||
Cleans DeerFlow-managed thread directories, removes checkpoint data,
|
||||
and removes the thread record from the Store.
|
||||
"""
|
||||
# Clean local filesystem
|
||||
response = _delete_thread_data(thread_id)
|
||||
|
||||
# Remove from Store (best-effort)
|
||||
store = get_store(request)
|
||||
if store is not None:
|
||||
try:
|
||||
await store.adelete(THREADS_NS, thread_id)
|
||||
except Exception:
|
||||
logger.debug("Could not delete store record for thread %s (not critical)", thread_id)
|
||||
|
||||
# Remove checkpoints (best-effort)
|
||||
checkpointer = getattr(request.app.state, "checkpointer", None)
|
||||
if checkpointer is not None:
|
||||
try:
|
||||
if hasattr(checkpointer, "adelete_thread"):
|
||||
await checkpointer.adelete_thread(thread_id)
|
||||
except Exception:
|
||||
logger.debug("Could not delete checkpoints for thread %s (not critical)", thread_id)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.post("", response_model=ThreadResponse)
|
||||
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
|
||||
"""Create a new thread.
|
||||
|
||||
The thread record is written to the Store (for fast listing) and an
|
||||
empty checkpoint is written to the checkpointer (for state reads).
|
||||
Idempotent: returns the existing record when ``thread_id`` already exists.
|
||||
"""
|
||||
store = get_store(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
thread_id = body.thread_id or str(uuid.uuid4())
|
||||
now = time.time()
|
||||
|
||||
# Idempotency: return existing record from Store when already present
|
||||
if store is not None:
|
||||
existing_record = await _store_get(store, thread_id)
|
||||
if existing_record is not None:
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=existing_record.get("status", "idle"),
|
||||
created_at=str(existing_record.get("created_at", "")),
|
||||
updated_at=str(existing_record.get("updated_at", "")),
|
||||
metadata=existing_record.get("metadata", {}),
|
||||
)
|
||||
|
||||
# Write thread record to Store
|
||||
if store is not None:
|
||||
try:
|
||||
await _store_put(
|
||||
store,
|
||||
{
|
||||
"thread_id": thread_id,
|
||||
"status": "idle",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"metadata": body.metadata,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to write thread %s to store", thread_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to create thread")
|
||||
|
||||
# Write an empty checkpoint so state endpoints work immediately
|
||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
try:
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
|
||||
ckpt_metadata = {
|
||||
"step": -1,
|
||||
"source": "input",
|
||||
"writes": None,
|
||||
"parents": {},
|
||||
**body.metadata,
|
||||
"created_at": now,
|
||||
}
|
||||
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
|
||||
except Exception:
|
||||
logger.exception("Failed to create checkpoint for thread %s", thread_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to create thread")
|
||||
|
||||
logger.info("Thread created: %s", thread_id)
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status="idle",
|
||||
created_at=str(now),
|
||||
updated_at=str(now),
|
||||
metadata=body.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/search", response_model=list[ThreadResponse])
|
||||
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
|
||||
"""Search and list threads.
|
||||
|
||||
Two-phase approach:
|
||||
|
||||
**Phase 1 — Store (fast path, O(threads))**: returns threads that were
|
||||
created or run through this Gateway. Store records are tiny metadata
|
||||
dicts so fetching all of them at once is cheap.
|
||||
|
||||
**Phase 2 — Checkpointer supplement (lazy migration)**: threads that
|
||||
were created directly by LangGraph Server (and therefore absent from the
|
||||
Store) are discovered here by iterating the shared checkpointer. Any
|
||||
newly found thread is immediately written to the Store so that the next
|
||||
search skips Phase 2 for that thread — the Store converges to a full
|
||||
index over time without a one-shot migration job.
|
||||
"""
|
||||
store = get_store(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 1: Store
|
||||
# -----------------------------------------------------------------------
|
||||
merged: dict[str, ThreadResponse] = {}
|
||||
|
||||
if store is not None:
|
||||
try:
|
||||
items = await store.asearch(THREADS_NS, limit=10_000)
|
||||
except Exception:
|
||||
logger.warning("Store search failed — falling back to checkpointer only", exc_info=True)
|
||||
items = []
|
||||
|
||||
for item in items:
|
||||
val = item.value
|
||||
merged[val["thread_id"]] = ThreadResponse(
|
||||
thread_id=val["thread_id"],
|
||||
status=val.get("status", "idle"),
|
||||
created_at=str(val.get("created_at", "")),
|
||||
updated_at=str(val.get("updated_at", "")),
|
||||
metadata=val.get("metadata", {}),
|
||||
values=val.get("values", {}),
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 2: Checkpointer supplement
|
||||
# Discovers threads not yet in the Store (e.g. created by LangGraph
|
||||
# Server) and lazily migrates them so future searches skip this phase.
|
||||
# -----------------------------------------------------------------------
|
||||
try:
|
||||
async for checkpoint_tuple in checkpointer.alist(None):
|
||||
cfg = getattr(checkpoint_tuple, "config", {})
|
||||
thread_id = cfg.get("configurable", {}).get("thread_id")
|
||||
if not thread_id or thread_id in merged:
|
||||
continue
|
||||
|
||||
# Skip sub-graph checkpoints (checkpoint_ns is non-empty for those)
|
||||
if cfg.get("configurable", {}).get("checkpoint_ns", ""):
|
||||
continue
|
||||
|
||||
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||
# Strip LangGraph internal keys from the user-visible metadata dict
|
||||
user_meta = {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
|
||||
|
||||
# Extract state values (title) from the checkpoint's channel_values
|
||||
checkpoint_data = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint_data.get("channel_values", {})
|
||||
ckpt_values = {}
|
||||
if title := channel_values.get("title"):
|
||||
ckpt_values["title"] = title
|
||||
|
||||
thread_resp = ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=_derive_thread_status(checkpoint_tuple),
|
||||
created_at=str(ckpt_meta.get("created_at", "")),
|
||||
updated_at=str(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
|
||||
metadata=user_meta,
|
||||
values=ckpt_values,
|
||||
)
|
||||
merged[thread_id] = thread_resp
|
||||
|
||||
# Lazy migration — write to Store so the next search finds it there
|
||||
if store is not None:
|
||||
try:
|
||||
await _store_upsert(store, thread_id, metadata=user_meta, values=ckpt_values or None)
|
||||
except Exception:
|
||||
logger.debug("Failed to migrate thread %s to store (non-fatal)", thread_id)
|
||||
except Exception:
|
||||
logger.exception("Checkpointer scan failed during thread search")
|
||||
# Don't raise — return whatever was collected from Store + partial scan
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 3: Filter → sort → paginate
|
||||
# -----------------------------------------------------------------------
|
||||
results = list(merged.values())
|
||||
|
||||
if body.metadata:
|
||||
results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]
|
||||
|
||||
if body.status:
|
||||
results = [r for r in results if r.status == body.status]
|
||||
|
||||
results.sort(key=lambda r: r.updated_at, reverse=True)
|
||||
return results[body.offset : body.offset + body.limit]
|
||||
|
||||
|
||||
@router.patch("/{thread_id}", response_model=ThreadResponse)
|
||||
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
||||
"""Merge metadata into a thread record."""
|
||||
store = get_store(request)
|
||||
if store is None:
|
||||
raise HTTPException(status_code=503, detail="Store not available")
|
||||
|
||||
record = await _store_get(store, thread_id)
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
now = time.time()
|
||||
updated = dict(record)
|
||||
updated.setdefault("metadata", {}).update(body.metadata)
|
||||
updated["updated_at"] = now
|
||||
|
||||
try:
|
||||
await _store_put(store, updated)
|
||||
except Exception:
|
||||
logger.exception("Failed to patch thread %s", thread_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to update thread")
|
||||
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=updated.get("status", "idle"),
|
||||
created_at=str(updated.get("created_at", "")),
|
||||
updated_at=str(now),
|
||||
metadata=updated.get("metadata", {}),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{thread_id}", response_model=ThreadResponse)
|
||||
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
"""Get thread info.
|
||||
|
||||
Reads metadata from the Store and derives the accurate execution
|
||||
status from the checkpointer. Falls back to the checkpointer alone
|
||||
for threads that pre-date Store adoption (backward compat).
|
||||
"""
|
||||
store = get_store(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
record: dict | None = None
|
||||
if store is not None:
|
||||
record = await _store_get(store, thread_id)
|
||||
|
||||
# Derive accurate status from the checkpointer
|
||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
except Exception:
|
||||
logger.exception("Failed to get checkpoint for thread %s", thread_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to get thread")
|
||||
|
||||
if record is None and checkpoint_tuple is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
# If the thread exists in the checkpointer but not the store (e.g. legacy
|
||||
# data), synthesize a minimal store record from the checkpoint metadata.
|
||||
if record is None and checkpoint_tuple is not None:
|
||||
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||
record = {
|
||||
"thread_id": thread_id,
|
||||
"status": "idle",
|
||||
"created_at": ckpt_meta.get("created_at", ""),
|
||||
"updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
|
||||
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
|
||||
}
|
||||
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=status,
|
||||
created_at=str(record.get("created_at", "")),
|
||||
updated_at=str(record.get("updated_at", "")),
|
||||
metadata=record.get("metadata", {}),
|
||||
values=serialize_channel_values(channel_values),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
|
||||
"""Get the latest state snapshot for a thread.
|
||||
|
||||
Channel values are serialized to ensure LangChain message objects
|
||||
are converted to JSON-safe dicts.
|
||||
"""
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
except Exception:
|
||||
logger.exception("Failed to get state for thread %s", thread_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
||||
|
||||
if checkpoint_tuple is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||
checkpoint_id = None
|
||||
ckpt_config = getattr(checkpoint_tuple, "config", {})
|
||||
if ckpt_config:
|
||||
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id")
|
||||
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
|
||||
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
||||
parent_checkpoint_id = None
|
||||
if parent_config:
|
||||
parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id")
|
||||
|
||||
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
||||
|
||||
return ThreadStateResponse(
|
||||
values=serialize_channel_values(channel_values),
|
||||
next=next_tasks,
|
||||
metadata=metadata,
|
||||
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
|
||||
checkpoint_id=checkpoint_id,
|
||||
parent_checkpoint_id=parent_checkpoint_id,
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
tasks=tasks,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
|
||||
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
|
||||
"""Update thread state (e.g. for human-in-the-loop resume or title rename).
|
||||
|
||||
Writes a new checkpoint that merges *body.values* into the latest
|
||||
channel values, then syncs any updated ``title`` field back to the Store
|
||||
so that ``/threads/search`` reflects the change immediately.
|
||||
"""
|
||||
checkpointer = get_checkpointer(request)
|
||||
store = get_store(request)
|
||||
|
||||
# checkpoint_ns must be present in the config for aput — default to ""
|
||||
# (the root graph namespace). checkpoint_id is optional; omitting it
|
||||
# fetches the latest checkpoint for the thread.
|
||||
read_config: dict[str, Any] = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": "",
|
||||
}
|
||||
}
|
||||
if body.checkpoint_id:
|
||||
read_config["configurable"]["checkpoint_id"] = body.checkpoint_id
|
||||
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
|
||||
except Exception:
|
||||
logger.exception("Failed to get state for thread %s", thread_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to get thread state")
|
||||
|
||||
if checkpoint_tuple is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
# Work on mutable copies so we don't accidentally mutate cached objects.
|
||||
checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
|
||||
metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
|
||||
channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
|
||||
|
||||
if body.values:
|
||||
channel_values.update(body.values)
|
||||
|
||||
checkpoint["channel_values"] = channel_values
|
||||
metadata["updated_at"] = time.time()
|
||||
|
||||
if body.as_node:
|
||||
metadata["source"] = "update"
|
||||
metadata["step"] = metadata.get("step", 0) + 1
|
||||
metadata["writes"] = {body.as_node: body.values}
|
||||
|
||||
# aput requires checkpoint_ns in the config — use the same config used for the
|
||||
# read (which always includes checkpoint_ns=""). Do NOT include checkpoint_id
|
||||
# so that aput generates a fresh checkpoint ID for the new snapshot.
|
||||
write_config: dict[str, Any] = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
"checkpoint_ns": "",
|
||||
}
|
||||
}
|
||||
try:
|
||||
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
|
||||
except Exception:
|
||||
logger.exception("Failed to update state for thread %s", thread_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to update thread state")
|
||||
|
||||
new_checkpoint_id: str | None = None
|
||||
if isinstance(new_config, dict):
|
||||
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
|
||||
|
||||
# Sync title changes to the Store so /threads/search reflects them immediately.
|
||||
if store is not None and body.values and "title" in body.values:
|
||||
try:
|
||||
await _store_upsert(store, thread_id, values={"title": body.values["title"]})
|
||||
except Exception:
|
||||
logger.debug("Failed to sync title to store for thread %s (non-fatal)", thread_id)
|
||||
|
||||
return ThreadStateResponse(
|
||||
values=serialize_channel_values(channel_values),
|
||||
next=[],
|
||||
metadata=metadata,
|
||||
checkpoint_id=new_checkpoint_id,
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
|
||||
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
|
||||
"""Get checkpoint history for a thread."""
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
|
||||
if body.before:
|
||||
config["configurable"]["checkpoint_id"] = body.before
|
||||
|
||||
entries: list[HistoryEntry] = []
|
||||
try:
|
||||
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
|
||||
ckpt_config = getattr(checkpoint_tuple, "config", {})
|
||||
parent_config = getattr(checkpoint_tuple, "parent_config", None)
|
||||
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
|
||||
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
|
||||
parent_id = None
|
||||
if parent_config:
|
||||
parent_id = parent_config.get("configurable", {}).get("checkpoint_id")
|
||||
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
|
||||
# Derive next tasks
|
||||
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
|
||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||
|
||||
entries.append(
|
||||
HistoryEntry(
|
||||
checkpoint_id=checkpoint_id,
|
||||
parent_checkpoint_id=parent_id,
|
||||
metadata=metadata,
|
||||
values=serialize_channel_values(channel_values),
|
||||
created_at=str(metadata.get("created_at", "")),
|
||||
next=next_tasks,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to get history for thread %s", thread_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to get thread history")
|
||||
|
||||
return entries
|
||||
168
deer-flow/backend/app/gateway/routers/uploads.py
Normal file
168
deer-flow/backend/app/gateway/routers/uploads.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Upload router for handling file uploads."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import stat
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
from deerflow.config.paths import get_paths
|
||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||
from deerflow.uploads.manager import (
|
||||
PathTraversalError,
|
||||
delete_file_safe,
|
||||
enrich_file_listing,
|
||||
ensure_uploads_dir,
|
||||
get_uploads_dir,
|
||||
list_files_in_dir,
|
||||
normalize_filename,
|
||||
upload_artifact_url,
|
||||
upload_virtual_path,
|
||||
)
|
||||
from deerflow.utils.file_conversion import CONVERTIBLE_EXTENSIONS, convert_file_to_markdown
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/threads/{thread_id}/uploads", tags=["uploads"])
|
||||
|
||||
|
||||
class UploadResponse(BaseModel):
|
||||
"""Response model for file upload."""
|
||||
|
||||
success: bool
|
||||
files: list[dict[str, str]]
|
||||
message: str
|
||||
|
||||
|
||||
def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
||||
"""Ensure uploaded files remain writable when mounted into non-local sandboxes.
|
||||
|
||||
In AIO sandbox mode, the gateway writes the authoritative host-side file
|
||||
first, then the sandbox runtime may rewrite the same mounted path. Granting
|
||||
world-writable access here prevents permission mismatches between the
|
||||
gateway user and the sandbox runtime user.
|
||||
"""
|
||||
file_stat = os.lstat(file_path)
|
||||
if stat.S_ISLNK(file_stat.st_mode):
|
||||
logger.warning("Skipping sandbox chmod for symlinked upload path: %s", file_path)
|
||||
return
|
||||
|
||||
writable_mode = stat.S_IMODE(file_stat.st_mode) | stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH
|
||||
chmod_kwargs = {"follow_symlinks": False} if os.chmod in os.supports_follow_symlinks else {}
|
||||
os.chmod(file_path, writable_mode, **chmod_kwargs)
|
||||
|
||||
|
||||
@router.post("", response_model=UploadResponse)
|
||||
async def upload_files(
|
||||
thread_id: str,
|
||||
files: list[UploadFile] = File(...),
|
||||
) -> UploadResponse:
|
||||
"""Upload multiple files to a thread's uploads directory."""
|
||||
if not files:
|
||||
raise HTTPException(status_code=400, detail="No files provided")
|
||||
|
||||
try:
|
||||
uploads_dir = ensure_uploads_dir(thread_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
|
||||
uploaded_files = []
|
||||
|
||||
sandbox_provider = get_sandbox_provider()
|
||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||
sandbox = sandbox_provider.get(sandbox_id)
|
||||
|
||||
for file in files:
|
||||
if not file.filename:
|
||||
continue
|
||||
|
||||
try:
|
||||
safe_filename = normalize_filename(file.filename)
|
||||
except ValueError:
|
||||
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
|
||||
continue
|
||||
|
||||
try:
|
||||
content = await file.read()
|
||||
file_path = uploads_dir / safe_filename
|
||||
file_path.write_bytes(content)
|
||||
|
||||
virtual_path = upload_virtual_path(safe_filename)
|
||||
|
||||
if sandbox_id != "local":
|
||||
_make_file_sandbox_writable(file_path)
|
||||
sandbox.update_file(virtual_path, content)
|
||||
|
||||
file_info = {
|
||||
"filename": safe_filename,
|
||||
"size": str(len(content)),
|
||||
"path": str(sandbox_uploads / safe_filename),
|
||||
"virtual_path": virtual_path,
|
||||
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
||||
}
|
||||
|
||||
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
|
||||
|
||||
file_ext = file_path.suffix.lower()
|
||||
if file_ext in CONVERTIBLE_EXTENSIONS:
|
||||
md_path = await convert_file_to_markdown(file_path)
|
||||
if md_path:
|
||||
md_virtual_path = upload_virtual_path(md_path.name)
|
||||
|
||||
if sandbox_id != "local":
|
||||
_make_file_sandbox_writable(md_path)
|
||||
sandbox.update_file(md_virtual_path, md_path.read_bytes())
|
||||
|
||||
file_info["markdown_file"] = md_path.name
|
||||
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
|
||||
file_info["markdown_virtual_path"] = md_virtual_path
|
||||
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
|
||||
|
||||
uploaded_files.append(file_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload {file.filename}: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
|
||||
|
||||
return UploadResponse(
|
||||
success=True,
|
||||
files=uploaded_files,
|
||||
message=f"Successfully uploaded {len(uploaded_files)} file(s)",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/list", response_model=dict)
|
||||
async def list_uploaded_files(thread_id: str) -> dict:
|
||||
"""List all files in a thread's uploads directory."""
|
||||
try:
|
||||
uploads_dir = get_uploads_dir(thread_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
result = list_files_in_dir(uploads_dir)
|
||||
enrich_file_listing(result, thread_id)
|
||||
|
||||
# Gateway additionally includes the sandbox-relative path.
|
||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
|
||||
for f in result["files"]:
|
||||
f["path"] = str(sandbox_uploads / f["filename"])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/{filename}")
|
||||
async def delete_uploaded_file(thread_id: str, filename: str) -> dict:
|
||||
"""Delete a file from a thread's uploads directory."""
|
||||
try:
|
||||
uploads_dir = get_uploads_dir(thread_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
try:
|
||||
return delete_file_safe(uploads_dir, filename, convertible_extensions=CONVERTIBLE_EXTENSIONS)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"File not found: {filename}")
|
||||
except PathTraversalError:
|
||||
raise HTTPException(status_code=400, detail="Invalid path")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete {filename}: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete {filename}: {str(e)}")
|
||||
367
deer-flow/backend/app/gateway/services.py
Normal file
367
deer-flow/backend/app/gateway/services.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""Run lifecycle service layer.
|
||||
|
||||
Centralizes the business logic for creating runs, formatting SSE
|
||||
frames, and consuming stream bridge events. Router modules
|
||||
(``thread_runs``, ``runs``) are thin HTTP handlers that delegate here.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
|
||||
from deerflow.runtime import (
|
||||
END_SENTINEL,
|
||||
HEARTBEAT_SENTINEL,
|
||||
ConflictError,
|
||||
DisconnectMode,
|
||||
RunManager,
|
||||
RunRecord,
|
||||
RunStatus,
|
||||
StreamBridge,
|
||||
UnsupportedStrategyError,
|
||||
run_agent,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSE formatting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def format_sse(event: str, data: Any, *, event_id: str | None = None) -> str:
|
||||
"""Format a single SSE frame.
|
||||
|
||||
Field order: ``event:`` -> ``data:`` -> ``id:`` (optional) -> blank line.
|
||||
This matches the LangGraph Platform wire format consumed by the
|
||||
``useStream`` React hook and the Python ``langgraph-sdk`` SSE decoder.
|
||||
"""
|
||||
payload = json.dumps(data, default=str, ensure_ascii=False)
|
||||
parts = [f"event: {event}", f"data: {payload}"]
|
||||
if event_id:
|
||||
parts.append(f"id: {event_id}")
|
||||
parts.append("")
|
||||
parts.append("")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input / config helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def normalize_stream_modes(raw: list[str] | str | None) -> list[str]:
|
||||
"""Normalize the stream_mode parameter to a list.
|
||||
|
||||
Default matches what ``useStream`` expects: values + messages-tuple.
|
||||
"""
|
||||
if raw is None:
|
||||
return ["values"]
|
||||
if isinstance(raw, str):
|
||||
return [raw]
|
||||
return raw if raw else ["values"]
|
||||
|
||||
|
||||
def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""Convert LangGraph Platform input format to LangChain state dict."""
|
||||
if raw_input is None:
|
||||
return {}
|
||||
messages = raw_input.get("messages")
|
||||
if messages and isinstance(messages, list):
|
||||
converted = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, dict):
|
||||
role = msg.get("role", msg.get("type", "user"))
|
||||
content = msg.get("content", "")
|
||||
if role in ("user", "human"):
|
||||
converted.append(HumanMessage(content=content))
|
||||
else:
|
||||
# TODO: handle other message types (system, ai, tool)
|
||||
converted.append(HumanMessage(content=content))
|
||||
else:
|
||||
converted.append(msg)
|
||||
return {**raw_input, "messages": converted}
|
||||
return raw_input
|
||||
|
||||
|
||||
_DEFAULT_ASSISTANT_ID = "lead_agent"
|
||||
|
||||
|
||||
def resolve_agent_factory(assistant_id: str | None):
|
||||
"""Resolve the agent factory callable from config.
|
||||
|
||||
Custom agents are implemented as ``lead_agent`` + an ``agent_name``
|
||||
injected into ``configurable`` — see :func:`build_run_config`. All
|
||||
``assistant_id`` values therefore map to the same factory; the routing
|
||||
happens inside ``make_lead_agent`` when it reads ``cfg["agent_name"]``.
|
||||
"""
|
||||
from deerflow.agents.lead_agent.agent import make_lead_agent
|
||||
|
||||
return make_lead_agent
|
||||
|
||||
|
||||
def build_run_config(
|
||||
thread_id: str,
|
||||
request_config: dict[str, Any] | None,
|
||||
metadata: dict[str, Any] | None,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a RunnableConfig dict for the agent.
|
||||
|
||||
When *assistant_id* refers to a custom agent (anything other than
|
||||
``"lead_agent"`` / ``None``), the name is forwarded as
|
||||
``configurable["agent_name"]``. ``make_lead_agent`` reads this key to
|
||||
load the matching ``agents/<name>/SOUL.md`` and per-agent config —
|
||||
without it the agent silently runs as the default lead agent.
|
||||
|
||||
This mirrors the channel manager's ``_resolve_run_params`` logic so that
|
||||
the LangGraph Platform-compatible HTTP API and the IM channel path behave
|
||||
identically.
|
||||
"""
|
||||
config: dict[str, Any] = {"recursion_limit": 100}
|
||||
if request_config:
|
||||
# LangGraph >= 0.6.0 introduced ``context`` as the preferred way to
|
||||
# pass thread-level data and rejects requests that include both
|
||||
# ``configurable`` and ``context``. If the caller already sends
|
||||
# ``context``, honour it and skip our own ``configurable`` dict.
|
||||
if "context" in request_config:
|
||||
if "configurable" in request_config:
|
||||
logger.warning(
|
||||
"build_run_config: client sent both 'context' and 'configurable'; preferring 'context' (LangGraph >= 0.6.0). thread_id=%s, caller_configurable keys=%s",
|
||||
thread_id,
|
||||
list(request_config.get("configurable", {}).keys()),
|
||||
)
|
||||
config["context"] = request_config["context"]
|
||||
else:
|
||||
configurable = {"thread_id": thread_id}
|
||||
configurable.update(request_config.get("configurable", {}))
|
||||
config["configurable"] = configurable
|
||||
for k, v in request_config.items():
|
||||
if k not in ("configurable", "context"):
|
||||
config[k] = v
|
||||
else:
|
||||
config["configurable"] = {"thread_id": thread_id}
|
||||
|
||||
# Inject custom agent name when the caller specified a non-default assistant.
|
||||
# Honour an explicit configurable["agent_name"] in the request if already set.
|
||||
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "configurable" in config:
|
||||
if "agent_name" not in config["configurable"]:
|
||||
normalized = assistant_id.strip().lower().replace("_", "-")
|
||||
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
||||
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
||||
config["configurable"]["agent_name"] = normalized
|
||||
if metadata:
|
||||
config.setdefault("metadata", {}).update(metadata)
|
||||
return config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Run lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _upsert_thread_in_store(store, thread_id: str, metadata: dict | None) -> None:
|
||||
"""Create or refresh the thread record in the Store.
|
||||
|
||||
Called from :func:`start_run` so that threads created via the stateless
|
||||
``/runs/stream`` endpoint (which never calls ``POST /threads``) still
|
||||
appear in ``/threads/search`` results.
|
||||
"""
|
||||
# Deferred import to avoid circular import with the threads router module.
|
||||
from app.gateway.routers.threads import _store_upsert
|
||||
|
||||
try:
|
||||
await _store_upsert(store, thread_id, metadata=metadata)
|
||||
except Exception:
|
||||
logger.warning("Failed to upsert thread %s in store (non-fatal)", thread_id)
|
||||
|
||||
|
||||
async def _sync_thread_title_after_run(
|
||||
run_task: asyncio.Task,
|
||||
thread_id: str,
|
||||
checkpointer: Any,
|
||||
store: Any,
|
||||
) -> None:
|
||||
"""Wait for *run_task* to finish, then persist the generated title to the Store.
|
||||
|
||||
TitleMiddleware writes the generated title to the LangGraph agent state
|
||||
(checkpointer) but the Gateway's Store record is not updated automatically.
|
||||
This coroutine closes that gap by reading the final checkpoint after the
|
||||
run completes and syncing ``values.title`` into the Store record so that
|
||||
subsequent ``/threads/search`` responses include the correct title.
|
||||
|
||||
Runs as a fire-and-forget :func:`asyncio.create_task`; failures are
|
||||
logged at DEBUG level and never propagate.
|
||||
"""
|
||||
# Wait for the background run task to complete (any outcome).
|
||||
# asyncio.wait does not propagate task exceptions — it just returns
|
||||
# when the task is done, cancelled, or failed.
|
||||
await asyncio.wait({run_task})
|
||||
|
||||
# Deferred import to avoid circular import with the threads router module.
|
||||
from app.gateway.routers.threads import _store_get, _store_put
|
||||
|
||||
try:
|
||||
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
|
||||
if ckpt_tuple is None:
|
||||
return
|
||||
|
||||
channel_values = ckpt_tuple.checkpoint.get("channel_values", {})
|
||||
title = channel_values.get("title")
|
||||
if not title:
|
||||
return
|
||||
|
||||
existing = await _store_get(store, thread_id)
|
||||
if existing is None:
|
||||
return
|
||||
|
||||
updated = dict(existing)
|
||||
updated.setdefault("values", {})["title"] = title
|
||||
updated["updated_at"] = time.time()
|
||||
await _store_put(store, updated)
|
||||
logger.debug("Synced title %r for thread %s", title, thread_id)
|
||||
except Exception:
|
||||
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id, exc_info=True)
|
||||
|
||||
|
||||
async def start_run(
|
||||
body: Any,
|
||||
thread_id: str,
|
||||
request: Request,
|
||||
) -> RunRecord:
|
||||
"""Create a RunRecord and launch the background agent task.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
body : RunCreateRequest
|
||||
The validated request body (typed as Any to avoid circular import
|
||||
with the router module that defines the Pydantic model).
|
||||
thread_id : str
|
||||
Target thread.
|
||||
request : Request
|
||||
FastAPI request — used to retrieve singletons from ``app.state``.
|
||||
"""
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
store = get_store(request)
|
||||
|
||||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
||||
|
||||
try:
|
||||
record = await run_mgr.create_or_reject(
|
||||
thread_id,
|
||||
body.assistant_id,
|
||||
on_disconnect=disconnect,
|
||||
metadata=body.metadata or {},
|
||||
kwargs={"input": body.input, "config": body.config},
|
||||
multitask_strategy=body.multitask_strategy,
|
||||
)
|
||||
except ConflictError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||
except UnsupportedStrategyError as exc:
|
||||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||
|
||||
# Ensure the thread is visible in /threads/search, even for threads that
|
||||
# were never explicitly created via POST /threads (e.g. stateless runs).
|
||||
store = get_store(request)
|
||||
if store is not None:
|
||||
await _upsert_thread_in_store(store, thread_id, body.metadata)
|
||||
|
||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||
graph_input = normalize_input(body.input)
|
||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||
|
||||
# Merge DeerFlow-specific context overrides into configurable.
|
||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||
context = getattr(body, "context", None)
|
||||
if context:
|
||||
_CONTEXT_CONFIGURABLE_KEYS = {
|
||||
"model_name",
|
||||
"mode",
|
||||
"thinking_enabled",
|
||||
"reasoning_effort",
|
||||
"is_plan_mode",
|
||||
"subagent_enabled",
|
||||
"max_concurrent_subagents",
|
||||
}
|
||||
configurable = config.setdefault("configurable", {})
|
||||
for key in _CONTEXT_CONFIGURABLE_KEYS:
|
||||
if key in context:
|
||||
configurable.setdefault(key, context[key])
|
||||
|
||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||
|
||||
task = asyncio.create_task(
|
||||
run_agent(
|
||||
bridge,
|
||||
run_mgr,
|
||||
record,
|
||||
checkpointer=checkpointer,
|
||||
store=store,
|
||||
agent_factory=agent_factory,
|
||||
graph_input=graph_input,
|
||||
config=config,
|
||||
stream_modes=stream_modes,
|
||||
stream_subgraphs=body.stream_subgraphs,
|
||||
interrupt_before=body.interrupt_before,
|
||||
interrupt_after=body.interrupt_after,
|
||||
)
|
||||
)
|
||||
record.task = task
|
||||
|
||||
# After the run completes, sync the title generated by TitleMiddleware from
|
||||
# the checkpointer into the Store record so that /threads/search returns the
|
||||
# correct title instead of an empty values dict.
|
||||
if store is not None:
|
||||
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store))
|
||||
|
||||
return record
|
||||
|
||||
|
||||
async def sse_consumer(
|
||||
bridge: StreamBridge,
|
||||
record: RunRecord,
|
||||
request: Request,
|
||||
run_mgr: RunManager,
|
||||
):
|
||||
"""Async generator that yields SSE frames from the bridge.
|
||||
|
||||
The ``finally`` block implements ``on_disconnect`` semantics:
|
||||
- ``cancel``: abort the background task on client disconnect.
|
||||
- ``continue``: let the task run; events are discarded.
|
||||
"""
|
||||
last_event_id = request.headers.get("Last-Event-ID")
|
||||
try:
|
||||
async for entry in bridge.subscribe(record.run_id, last_event_id=last_event_id):
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
if entry is HEARTBEAT_SENTINEL:
|
||||
yield ": heartbeat\n\n"
|
||||
continue
|
||||
|
||||
if entry is END_SENTINEL:
|
||||
yield format_sse("end", None, event_id=entry.id or None)
|
||||
return
|
||||
|
||||
yield format_sse(entry.event, entry.data, event_id=entry.id or None)
|
||||
|
||||
finally:
|
||||
if record.status in (RunStatus.pending, RunStatus.running):
|
||||
if record.on_disconnect == DisconnectMode.cancel:
|
||||
await run_mgr.cancel(record.run_id)
|
||||
Reference in New Issue
Block a user