Initial commit: hardened DeerFlow factory

Vendored deer-flow upstream (bytedance/deer-flow) plus prompt-injection
hardening:

- New deerflow.security package: content_delimiter, html_cleaner,
  sanitizer (8 layers — invisible chars, control chars, symbols, NFC,
  PUA, tag chars, horizontal whitespace collapse with newline/tab
  preservation, length cap)
- New deerflow.community.searx package: web_search, web_fetch,
  image_search backed by a private SearX instance, every external
  string sanitized and wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>>
  delimiters
- All native community web providers (ddg_search, tavily, exa,
  firecrawl, jina_ai, infoquest, image_search) replaced with hard-fail
  stubs that raise NativeWebToolDisabledError at import time, so a
  misconfigured tool.use path fails loud rather than silently falling
  back to unsanitized output
- Native client back-doors (jina_client.py, infoquest_client.py)
  stubbed too
- Native-tool tests quarantined under tests/_disabled_native/
  (collect_ignore_glob via local conftest.py)
- Sanitizer Layer 7 fix: only collapse horizontal whitespace, preserve
  newlines and tabs so list/table structure survives
- Hardened runtime config.yaml references only the searx-backed tools
- Factory overlay (backend/) kept in sync with deer-flow tree as a
  reference / source

See HARDENING.md for the full audit trail and verification steps.
This commit is contained in:
2026-04-12 14:23:57 +02:00
commit 6de0bf9f5b
889 changed files with 173052 additions and 0 deletions

View File

@@ -0,0 +1,24 @@
from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointer
from .factory import create_deerflow_agent
from .features import Next, Prev, RuntimeFeatures
from .lead_agent import make_lead_agent
from .lead_agent.prompt import prime_enabled_skills_cache
from .thread_state import SandboxState, ThreadState
# LangGraph imports deerflow.agents when registering the graph. Prime the
# enabled-skills cache here so the request path can usually read a warm cache
# without forcing synchronous filesystem work during prompt module import.
prime_enabled_skills_cache()
__all__ = [
"create_deerflow_agent",
"RuntimeFeatures",
"Next",
"Prev",
"make_lead_agent",
"SandboxState",
"ThreadState",
"get_checkpointer",
"reset_checkpointer",
"make_checkpointer",
]

View File

@@ -0,0 +1,9 @@
from .async_provider import make_checkpointer
from .provider import checkpointer_context, get_checkpointer, reset_checkpointer
__all__ = [
"get_checkpointer",
"reset_checkpointer",
"checkpointer_context",
"make_checkpointer",
]

View File

@@ -0,0 +1,106 @@
"""Async checkpointer factory.
Provides an **async context manager** for long-running async servers that need
proper resource cleanup.
Supported backends: memory, sqlite, postgres.
Usage (e.g. FastAPI lifespan)::
from deerflow.agents.checkpointer.async_provider import make_checkpointer
async with make_checkpointer() as checkpointer:
app.state.checkpointer = checkpointer # InMemorySaver if not configured
For sync usage see :mod:`deerflow.agents.checkpointer.provider`.
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
from collections.abc import AsyncIterator
from langgraph.types import Checkpointer
from deerflow.agents.checkpointer.provider import (
POSTGRES_CONN_REQUIRED,
POSTGRES_INSTALL,
SQLITE_INSTALL,
)
from deerflow.config.app_config import get_app_config
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Async factory
# ---------------------------------------------------------------------------
@contextlib.asynccontextmanager
async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
"""Async context manager that constructs and tears down a checkpointer."""
if config.type == "memory":
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
if config.type == "sqlite":
try:
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
await asyncio.to_thread(ensure_sqlite_parent_dir, conn_str)
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
await saver.setup()
yield saver
return
if config.type == "postgres":
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not config.connection_string:
raise ValueError(POSTGRES_CONN_REQUIRED)
async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver:
await saver.setup()
yield saver
return
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
# ---------------------------------------------------------------------------
# Public async context manager
# ---------------------------------------------------------------------------
@contextlib.asynccontextmanager
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
"""Async context manager that yields a checkpointer for the caller's lifetime.
Resources are opened on enter and closed on exit — no global state::
async with make_checkpointer() as checkpointer:
app.state.checkpointer = checkpointer
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
async with _async_checkpointer(config.checkpointer) as saver:
yield saver

View File

@@ -0,0 +1,191 @@
"""Sync checkpointer factory.
Provides a **sync singleton** and a **sync context manager** for LangGraph
graph compilation and CLI tools.
Supported backends: memory, sqlite, postgres.
Usage::
from deerflow.agents.checkpointer.provider import get_checkpointer, checkpointer_context
# Singleton — reused across calls, closed on process exit
cp = get_checkpointer()
# One-shot — fresh connection, closed on block exit
with checkpointer_context() as cp:
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
"""
from __future__ import annotations
import contextlib
import logging
from collections.abc import Iterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Error message constants — imported by aio.provider too
# ---------------------------------------------------------------------------
SQLITE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite checkpointer. Install it with: uv add langgraph-checkpoint-sqlite"
POSTGRES_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
# ---------------------------------------------------------------------------
# Sync factory
# ---------------------------------------------------------------------------
@contextlib.contextmanager
def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
"""Context manager that creates and tears down a sync checkpointer.
Returns a configured ``Checkpointer`` instance. Resource cleanup for any
underlying connections or pools is handled by higher-level helpers in
this module (such as the singleton factory or context manager); this
function does not return a separate cleanup callback.
"""
if config.type == "memory":
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
yield InMemorySaver()
return
if config.type == "sqlite":
try:
from langgraph.checkpoint.sqlite import SqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
with SqliteSaver.from_conn_string(conn_str) as saver:
saver.setup()
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
yield saver
return
if config.type == "postgres":
try:
from langgraph.checkpoint.postgres import PostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not config.connection_string:
raise ValueError(POSTGRES_CONN_REQUIRED)
with PostgresSaver.from_conn_string(config.connection_string) as saver:
saver.setup()
logger.info("Checkpointer: using PostgresSaver")
yield saver
return
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
# ---------------------------------------------------------------------------
# Sync singleton
# ---------------------------------------------------------------------------
_checkpointer: Checkpointer | None = None
_checkpointer_ctx = None # open context manager keeping the connection alive
def get_checkpointer() -> Checkpointer:
"""Return the global sync checkpointer singleton, creating it on first call.
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
Raises:
ImportError: If the required package for the configured backend is not installed.
ValueError: If ``connection_string`` is missing for a backend that requires it.
"""
global _checkpointer, _checkpointer_ctx
if _checkpointer is not None:
return _checkpointer
# Ensure app config is loaded before checking checkpointer config
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
# but hasn't been loaded yet
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
# Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk.
try:
get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected.
pass
config = get_checkpointer_config()
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
_checkpointer = InMemorySaver()
return _checkpointer
_checkpointer_ctx = _sync_checkpointer_cm(config)
_checkpointer = _checkpointer_ctx.__enter__()
return _checkpointer
def reset_checkpointer() -> None:
"""Reset the sync singleton, forcing recreation on the next call.
Closes any open backend connections and clears the cached instance.
Useful in tests or after a configuration change.
"""
global _checkpointer, _checkpointer_ctx
if _checkpointer_ctx is not None:
try:
_checkpointer_ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during checkpointer cleanup", exc_info=True)
_checkpointer_ctx = None
_checkpointer = None
# ---------------------------------------------------------------------------
# Sync context manager
# ---------------------------------------------------------------------------
@contextlib.contextmanager
def checkpointer_context() -> Iterator[Checkpointer]:
"""Sync context manager that yields a checkpointer and cleans up on exit.
Unlike :func:`get_checkpointer`, this does **not** cache the instance —
each ``with`` block creates and destroys its own connection. Use it in
CLI scripts or tests where you want deterministic cleanup::
with checkpointer_context() as cp:
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
with _sync_checkpointer_cm(config.checkpointer) as saver:
yield saver

View File

@@ -0,0 +1,372 @@
"""Pure-argument factory for DeerFlow agents.
``create_deerflow_agent`` accepts plain Python arguments — no YAML files, no
global singletons. It is the SDK-level entry point sitting between the raw
``langchain.agents.create_agent`` primitive and the config-driven
``make_lead_agent`` application factory.
Note: the factory assembly itself is config-free, but some injected runtime
components (e.g. ``task_tool`` for subagent) may still read global config at
invocation time. Full config-free runtime is a Phase 2 goal.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
from deerflow.agents.features import RuntimeFeatures
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
from deerflow.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
from deerflow.agents.middlewares.tool_error_handling_middleware import ToolErrorHandlingMiddleware
from deerflow.agents.thread_state import ThreadState
from deerflow.tools.builtins import ask_clarification_tool
if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph.state import CompiledStateGraph
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# TodoMiddleware prompts (minimal SDK version)
# ---------------------------------------------------------------------------
_TODO_SYSTEM_PROMPT = """
<todo_list_system>
You have access to the `write_todos` tool to help you manage and track complex multi-step objectives.
**CRITICAL RULES:**
- Mark todos as completed IMMEDIATELY after finishing each step - do NOT batch completions
- Keep EXACTLY ONE task as `in_progress` at any time (unless tasks can run in parallel)
- Update the todo list in REAL-TIME as you work - this gives users visibility into your progress
- DO NOT use this tool for simple tasks (< 3 steps) - just complete them directly
</todo_list_system>
"""
_TODO_TOOL_DESCRIPTION = "Use this tool to create and manage a structured task list for complex work sessions. Only use for complex tasks (3+ steps)."
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def create_deerflow_agent(
model: BaseChatModel,
tools: list[BaseTool] | None = None,
*,
system_prompt: str | None = None,
middleware: list[AgentMiddleware] | None = None,
features: RuntimeFeatures | None = None,
extra_middleware: list[AgentMiddleware] | None = None,
plan_mode: bool = False,
state_schema: type | None = None,
checkpointer: BaseCheckpointSaver | None = None,
name: str = "default",
) -> CompiledStateGraph:
"""Create a DeerFlow agent from plain Python arguments.
The factory assembly itself reads no config files. Some injected runtime
components (e.g. ``task_tool``) may still depend on global config at
invocation time — see Phase 2 roadmap for full config-free runtime.
Parameters
----------
model:
Chat model instance.
tools:
User-provided tools. Feature-injected tools are appended automatically.
system_prompt:
System message. ``None`` uses a minimal default.
middleware:
**Full takeover** — if provided, this exact list is used.
Cannot be combined with *features* or *extra_middleware*.
features:
Declarative feature flags. Cannot be combined with *middleware*.
extra_middleware:
Additional middlewares inserted into the auto-assembled chain via
``@Next``/``@Prev`` positioning. Cannot be used with *middleware*.
plan_mode:
Enable TodoMiddleware for task tracking.
state_schema:
LangGraph state type. Defaults to ``ThreadState``.
checkpointer:
Optional persistence backend.
name:
Agent name (passed to middleware that cares, e.g. ``MemoryMiddleware``).
Raises
------
ValueError
If both *middleware* and *features*/*extra_middleware* are provided.
"""
if middleware is not None and features is not None:
raise ValueError("Cannot specify both 'middleware' and 'features'. Use one or the other.")
if middleware is not None and extra_middleware:
raise ValueError("Cannot use 'extra_middleware' with 'middleware' (full takeover).")
if extra_middleware:
for mw in extra_middleware:
if not isinstance(mw, AgentMiddleware):
raise TypeError(f"extra_middleware items must be AgentMiddleware instances, got {type(mw).__name__}")
effective_tools: list[BaseTool] = list(tools or [])
effective_state = state_schema or ThreadState
if middleware is not None:
effective_middleware = list(middleware)
else:
feat = features or RuntimeFeatures()
effective_middleware, extra_tools = _assemble_from_features(
feat,
name=name,
plan_mode=plan_mode,
extra_middleware=extra_middleware or [],
)
# Deduplicate by tool name — user-provided tools take priority.
existing_names = {t.name for t in effective_tools}
for t in extra_tools:
if t.name not in existing_names:
effective_tools.append(t)
existing_names.add(t.name)
return create_agent(
model=model,
tools=effective_tools or None,
middleware=effective_middleware,
system_prompt=system_prompt,
state_schema=effective_state,
checkpointer=checkpointer,
name=name,
)
# ---------------------------------------------------------------------------
# Internal: feature-driven middleware assembly
# ---------------------------------------------------------------------------
def _assemble_from_features(
feat: RuntimeFeatures,
*,
name: str = "default",
plan_mode: bool = False,
extra_middleware: list[AgentMiddleware] | None = None,
) -> tuple[list[AgentMiddleware], list[BaseTool]]:
"""Build an ordered middleware chain + extra tools from *feat*.
Middleware order matches ``make_lead_agent`` (14 middlewares):
0-2. Sandbox infrastructure (ThreadData → Uploads → Sandbox)
3. DanglingToolCallMiddleware (always)
4. GuardrailMiddleware (guardrail feature)
5. ToolErrorHandlingMiddleware (always)
6. SummarizationMiddleware (summarization feature)
7. TodoMiddleware (plan_mode parameter)
8. TitleMiddleware (auto_title feature)
9. MemoryMiddleware (memory feature)
10. ViewImageMiddleware (vision feature)
11. SubagentLimitMiddleware (subagent feature)
12. LoopDetectionMiddleware (always)
13. ClarificationMiddleware (always last)
Two-phase ordering:
1. Built-in chain — fixed sequential append.
2. Extra middleware — inserted via @Next/@Prev.
Each feature value is handled as:
- ``False``: skip
- ``True``: create the built-in default middleware (not available for
``summarization`` and ``guardrail`` — these require a custom instance)
- ``AgentMiddleware`` instance: use directly (custom replacement)
"""
chain: list[AgentMiddleware] = []
extra_tools: list[BaseTool] = []
# --- [0-2] Sandbox infrastructure ---
if feat.sandbox is not False:
if isinstance(feat.sandbox, AgentMiddleware):
chain.append(feat.sandbox)
else:
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
from deerflow.sandbox.middleware import SandboxMiddleware
chain.append(ThreadDataMiddleware(lazy_init=True))
chain.append(UploadsMiddleware())
chain.append(SandboxMiddleware(lazy_init=True))
# --- [3] DanglingToolCall (always) ---
chain.append(DanglingToolCallMiddleware())
# --- [4] Guardrail ---
if feat.guardrail is not False:
if isinstance(feat.guardrail, AgentMiddleware):
chain.append(feat.guardrail)
else:
raise ValueError("guardrail=True requires a custom AgentMiddleware instance (no built-in GuardrailMiddleware yet)")
# --- [5] ToolErrorHandling (always) ---
chain.append(ToolErrorHandlingMiddleware())
# --- [6] Summarization ---
if feat.summarization is not False:
if isinstance(feat.summarization, AgentMiddleware):
chain.append(feat.summarization)
else:
raise ValueError("summarization=True requires a custom AgentMiddleware instance (SummarizationMiddleware needs a model argument)")
# --- [7] TodoMiddleware (plan_mode) ---
if plan_mode:
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
chain.append(TodoMiddleware(system_prompt=_TODO_SYSTEM_PROMPT, tool_description=_TODO_TOOL_DESCRIPTION))
# --- [8] Auto Title ---
if feat.auto_title is not False:
if isinstance(feat.auto_title, AgentMiddleware):
chain.append(feat.auto_title)
else:
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
chain.append(TitleMiddleware())
# --- [9] Memory ---
if feat.memory is not False:
if isinstance(feat.memory, AgentMiddleware):
chain.append(feat.memory)
else:
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
chain.append(MemoryMiddleware(agent_name=name))
# --- [10] Vision ---
if feat.vision is not False:
if isinstance(feat.vision, AgentMiddleware):
chain.append(feat.vision)
else:
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
chain.append(ViewImageMiddleware())
from deerflow.tools.builtins import view_image_tool
extra_tools.append(view_image_tool)
# --- [11] Subagent ---
if feat.subagent is not False:
if isinstance(feat.subagent, AgentMiddleware):
chain.append(feat.subagent)
else:
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
chain.append(SubagentLimitMiddleware())
from deerflow.tools.builtins import task_tool
extra_tools.append(task_tool)
# --- [12] LoopDetection (always) ---
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
chain.append(LoopDetectionMiddleware())
# --- [13] Clarification (always last among built-ins) ---
chain.append(ClarificationMiddleware())
extra_tools.append(ask_clarification_tool)
# --- Insert extra_middleware via @Next/@Prev ---
if extra_middleware:
_insert_extra(chain, extra_middleware)
# Invariant: ClarificationMiddleware must always be last.
# @Next(ClarificationMiddleware) could push it off the tail.
clar_idx = next(i for i, m in enumerate(chain) if isinstance(m, ClarificationMiddleware))
if clar_idx != len(chain) - 1:
chain.append(chain.pop(clar_idx))
return chain, extra_tools
# ---------------------------------------------------------------------------
# Internal: extra middleware insertion with @Next/@Prev
# ---------------------------------------------------------------------------
def _insert_extra(chain: list[AgentMiddleware], extras: list[AgentMiddleware]) -> None:
"""Insert extra middlewares into *chain* using ``@Next``/``@Prev`` anchors.
Algorithm:
1. Validate: no middleware has both @Next and @Prev.
2. Conflict detection: two extras targeting same anchor (same or opposite direction) → error.
3. Insert unanchored extras before ClarificationMiddleware.
4. Insert anchored extras iteratively (supports cross-external anchoring).
5. If an anchor cannot be resolved after all rounds → error.
"""
next_targets: dict[type, type] = {}
prev_targets: dict[type, type] = {}
anchored: list[tuple[AgentMiddleware, str, type]] = []
unanchored: list[AgentMiddleware] = []
for mw in extras:
next_anchor = getattr(type(mw), "_next_anchor", None)
prev_anchor = getattr(type(mw), "_prev_anchor", None)
if next_anchor and prev_anchor:
raise ValueError(f"{type(mw).__name__} cannot have both @Next and @Prev")
if next_anchor:
if next_anchor in next_targets:
raise ValueError(f"Conflict: {type(mw).__name__} and {next_targets[next_anchor].__name__} both @Next({next_anchor.__name__})")
if next_anchor in prev_targets:
raise ValueError(f"Conflict: {type(mw).__name__} @Next({next_anchor.__name__}) and {prev_targets[next_anchor].__name__} @Prev({next_anchor.__name__}) — use cross-anchoring between extras instead")
next_targets[next_anchor] = type(mw)
anchored.append((mw, "next", next_anchor))
elif prev_anchor:
if prev_anchor in prev_targets:
raise ValueError(f"Conflict: {type(mw).__name__} and {prev_targets[prev_anchor].__name__} both @Prev({prev_anchor.__name__})")
if prev_anchor in next_targets:
raise ValueError(f"Conflict: {type(mw).__name__} @Prev({prev_anchor.__name__}) and {next_targets[prev_anchor].__name__} @Next({prev_anchor.__name__}) — use cross-anchoring between extras instead")
prev_targets[prev_anchor] = type(mw)
anchored.append((mw, "prev", prev_anchor))
else:
unanchored.append(mw)
# Unanchored → before ClarificationMiddleware
clarification_idx = next(i for i, m in enumerate(chain) if isinstance(m, ClarificationMiddleware))
for mw in unanchored:
chain.insert(clarification_idx, mw)
clarification_idx += 1
# Anchored → iterative insertion (supports external-to-external anchoring)
pending = list(anchored)
max_rounds = len(pending) + 1
for _ in range(max_rounds):
if not pending:
break
remaining = []
for mw, direction, anchor in pending:
idx = next(
(i for i, m in enumerate(chain) if isinstance(m, anchor)),
None,
)
if idx is None:
remaining.append((mw, direction, anchor))
continue
if direction == "next":
chain.insert(idx + 1, mw)
else:
chain.insert(idx, mw)
if len(remaining) == len(pending):
names = [type(m).__name__ for m, _, _ in remaining]
anchor_types = {a for _, _, a in remaining}
remaining_types = {type(m) for m, _, _ in remaining}
circular = anchor_types & remaining_types
if circular:
raise ValueError(f"Circular dependency among extra middlewares: {', '.join(t.__name__ for t in circular)}")
raise ValueError(f"Cannot resolve positions for {', '.join(names)} — anchors {', '.join(a.__name__ for _, _, a in remaining)} not found in chain")
pending = remaining

View File

@@ -0,0 +1,62 @@
"""Declarative feature flags and middleware positioning for create_deerflow_agent.
Pure data classes and decorators — no I/O, no side effects.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
from langchain.agents.middleware import AgentMiddleware
@dataclass
class RuntimeFeatures:
"""Declarative feature flags for ``create_deerflow_agent``.
Most features accept:
- ``True``: use the built-in default middleware
- ``False``: disable
- An ``AgentMiddleware`` instance: use this custom implementation instead
``summarization`` and ``guardrail`` have no built-in default — they only
accept ``False`` (disable) or an ``AgentMiddleware`` instance (custom).
"""
sandbox: bool | AgentMiddleware = True
memory: bool | AgentMiddleware = False
summarization: Literal[False] | AgentMiddleware = False
subagent: bool | AgentMiddleware = False
vision: bool | AgentMiddleware = False
auto_title: bool | AgentMiddleware = False
guardrail: Literal[False] | AgentMiddleware = False
# ---------------------------------------------------------------------------
# Middleware positioning decorators
# ---------------------------------------------------------------------------
def Next(anchor: type[AgentMiddleware]):
"""Declare this middleware should be placed after *anchor* in the chain."""
if not (isinstance(anchor, type) and issubclass(anchor, AgentMiddleware)):
raise TypeError(f"@Next expects an AgentMiddleware subclass, got {anchor!r}")
def decorator(cls: type[AgentMiddleware]) -> type[AgentMiddleware]:
cls._next_anchor = anchor # type: ignore[attr-defined]
return cls
return decorator
def Prev(anchor: type[AgentMiddleware]):
"""Declare this middleware should be placed before *anchor* in the chain."""
if not (isinstance(anchor, type) and issubclass(anchor, AgentMiddleware)):
raise TypeError(f"@Prev expects an AgentMiddleware subclass, got {anchor!r}")
def decorator(cls: type[AgentMiddleware]) -> type[AgentMiddleware]:
cls._prev_anchor = anchor # type: ignore[attr-defined]
return cls
return decorator

View File

@@ -0,0 +1,3 @@
from .agent import make_lead_agent
__all__ = ["make_lead_agent"]

View File

@@ -0,0 +1,350 @@
import logging
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
from langchain_core.runnables import RunnableConfig
from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config
from deerflow.config.app_config import get_app_config
from deerflow.config.summarization_config import get_summarization_config
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
def _resolve_model_name(requested_model_name: str | None = None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
app_config = get_app_config()
default_model_name = app_config.models[0].name if app_config.models else None
if default_model_name is None:
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
if requested_model_name and app_config.get_model_config(requested_model_name):
return requested_model_name
if requested_model_name and requested_model_name != default_model_name:
logger.warning(f"Model '{requested_model_name}' not found in config; fallback to default model '{default_model_name}'.")
return default_model_name
def _create_summarization_middleware() -> SummarizationMiddleware | None:
"""Create and configure the summarization middleware from config."""
config = get_summarization_config()
if not config.enabled:
return None
# Prepare trigger parameter
trigger = None
if config.trigger is not None:
if isinstance(config.trigger, list):
trigger = [t.to_tuple() for t in config.trigger]
else:
trigger = config.trigger.to_tuple()
# Prepare keep parameter
keep = config.keep.to_tuple()
# Prepare model parameter
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
else:
# Use a lightweight model for summarization to save costs
# Falls back to default model if not explicitly specified
model = create_chat_model(thinking_enabled=False)
# Prepare kwargs
kwargs = {
"model": model,
"trigger": trigger,
"keep": keep,
}
if config.trim_tokens_to_summarize is not None:
kwargs["trim_tokens_to_summarize"] = config.trim_tokens_to_summarize
if config.summary_prompt is not None:
kwargs["summary_prompt"] = config.summary_prompt
return SummarizationMiddleware(**kwargs)
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
"""Create and configure the TodoList middleware.
Args:
is_plan_mode: Whether to enable plan mode with TodoList middleware.
Returns:
TodoMiddleware instance if plan mode is enabled, None otherwise.
"""
if not is_plan_mode:
return None
# Custom prompts matching DeerFlow's style
system_prompt = """
<todo_list_system>
You have access to the `write_todos` tool to help you manage and track complex multi-step objectives.
**CRITICAL RULES:**
- Mark todos as completed IMMEDIATELY after finishing each step - do NOT batch completions
- Keep EXACTLY ONE task as `in_progress` at any time (unless tasks can run in parallel)
- Update the todo list in REAL-TIME as you work - this gives users visibility into your progress
- DO NOT use this tool for simple tasks (< 3 steps) - just complete them directly
**When to Use:**
This tool is designed for complex objectives that require systematic tracking:
- Complex multi-step tasks requiring 3+ distinct steps
- Non-trivial tasks needing careful planning and execution
- User explicitly requests a todo list
- User provides multiple tasks (numbered or comma-separated list)
- The plan may need revisions based on intermediate results
**When NOT to Use:**
- Single, straightforward tasks
- Trivial tasks (< 3 steps)
- Purely conversational or informational requests
- Simple tool calls where the approach is obvious
**Best Practices:**
- Break down complex tasks into smaller, actionable steps
- Use clear, descriptive task names
- Remove tasks that become irrelevant
- Add new tasks discovered during implementation
- Don't be afraid to revise the todo list as you learn more
**Task Management:**
Writing todos takes time and tokens - use it when helpful for managing complex problems, not for simple requests.
</todo_list_system>
"""
tool_description = """Use this tool to create and manage a structured task list for complex work sessions.
**IMPORTANT: Only use this tool for complex tasks (3+ steps). For simple requests, just do the work directly.**
## When to Use
Use this tool in these scenarios:
1. **Complex multi-step tasks**: When a task requires 3 or more distinct steps or actions
2. **Non-trivial tasks**: Tasks requiring careful planning or multiple operations
3. **User explicitly requests todo list**: When the user directly asks you to track tasks
4. **Multiple tasks**: When users provide a list of things to be done
5. **Dynamic planning**: When the plan may need updates based on intermediate results
## When NOT to Use
Skip this tool when:
1. The task is straightforward and takes less than 3 steps
2. The task is trivial and tracking provides no benefit
3. The task is purely conversational or informational
4. It's clear what needs to be done and you can just do it
## How to Use
1. **Starting a task**: Mark it as `in_progress` BEFORE beginning work
2. **Completing a task**: Mark it as `completed` IMMEDIATELY after finishing
3. **Updating the list**: Add new tasks, remove irrelevant ones, or update descriptions as needed
4. **Multiple updates**: You can make several updates at once (e.g., complete one task and start the next)
## Task States
- `pending`: Task not yet started
- `in_progress`: Currently working on (can have multiple if tasks run in parallel)
- `completed`: Task finished successfully
## Task Completion Requirements
**CRITICAL: Only mark a task as completed when you have FULLY accomplished it.**
Never mark a task as completed if:
- There are unresolved issues or errors
- Work is partial or incomplete
- You encountered blockers preventing completion
- You couldn't find necessary resources or dependencies
- Quality standards haven't been met
If blocked, keep the task as `in_progress` and create a new task describing what needs to be resolved.
## Best Practices
- Create specific, actionable items
- Break complex tasks into smaller, manageable steps
- Use clear, descriptive task names
- Update task status in real-time as you work
- Mark tasks complete IMMEDIATELY after finishing (don't batch completions)
- Remove tasks that are no longer relevant
- **IMPORTANT**: When you write the todo list, mark your first task(s) as `in_progress` immediately
- **IMPORTANT**: Unless all tasks are completed, always have at least one task `in_progress` to show progress
Being proactive with task management demonstrates thoroughness and ensures all requirements are completed successfully.
**Remember**: If you only need a few tool calls to complete a task and it's clear what to do, it's better to just do the task directly and NOT use this tool at all.
"""
return TodoMiddleware(system_prompt=system_prompt, tool_description=tool_description)
# ThreadDataMiddleware must be before SandboxMiddleware to ensure thread_id is available
# UploadsMiddleware should be after ThreadDataMiddleware to access thread_id
# DanglingToolCallMiddleware patches missing ToolMessages before model sees the history
# SummarizationMiddleware should be early to reduce context before other processing
# TodoListMiddleware should be before ClarificationMiddleware to allow todo management
# TitleMiddleware generates title after first exchange
# MemoryMiddleware queues conversation for memory update (after TitleMiddleware)
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
# ClarificationMiddleware should be last to intercept clarification requests after model calls
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None):
"""Build middleware chain based on runtime configuration.
Args:
config: Runtime configuration containing configurable options like is_plan_mode.
agent_name: If provided, MemoryMiddleware will use per-agent memory storage.
custom_middlewares: Optional list of custom middlewares to inject into the chain.
Returns:
List of middleware instances.
"""
middlewares = build_lead_runtime_middlewares(lazy_init=True)
# Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware()
if summarization_middleware is not None:
middlewares.append(summarization_middleware)
# Add TodoList middleware if plan mode is enabled
is_plan_mode = config.get("configurable", {}).get("is_plan_mode", False)
todo_list_middleware = _create_todo_list_middleware(is_plan_mode)
if todo_list_middleware is not None:
middlewares.append(todo_list_middleware)
# Add TokenUsageMiddleware when token_usage tracking is enabled
if get_app_config().token_usage.enabled:
middlewares.append(TokenUsageMiddleware())
# Add TitleMiddleware
middlewares.append(TitleMiddleware())
# Add MemoryMiddleware (after TitleMiddleware)
middlewares.append(MemoryMiddleware(agent_name=agent_name))
# Add ViewImageMiddleware only if the current model supports vision.
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
app_config = get_app_config()
model_config = app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware())
# Add DeferredToolFilterMiddleware to hide deferred tool schemas from model binding
if app_config.tool_search.enabled:
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
middlewares.append(DeferredToolFilterMiddleware())
# Add SubagentLimitMiddleware to truncate excess parallel task calls
subagent_enabled = config.get("configurable", {}).get("subagent_enabled", False)
if subagent_enabled:
max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3)
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
# LoopDetectionMiddleware — detect and break repetitive tool call loops
middlewares.append(LoopDetectionMiddleware())
# Inject custom middlewares before ClarificationMiddleware
if custom_middlewares:
middlewares.extend(custom_middlewares)
# ClarificationMiddleware should always be last
middlewares.append(ClarificationMiddleware())
return middlewares
def make_lead_agent(config: RunnableConfig):
# Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent
cfg = config.get("configurable", {})
thinking_enabled = cfg.get("thinking_enabled", True)
reasoning_effort = cfg.get("reasoning_effort", None)
requested_model_name: str | None = cfg.get("model_name") or cfg.get("model")
is_plan_mode = cfg.get("is_plan_mode", False)
subagent_enabled = cfg.get("subagent_enabled", False)
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
is_bootstrap = cfg.get("is_bootstrap", False)
agent_name = cfg.get("agent_name")
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
agent_model_name = agent_config.model if agent_config and agent_config.model else None
# Final model name resolution: request → agent config → global default, with fallback for unknown names
model_name = _resolve_model_name(requested_model_name or agent_model_name)
app_config = get_app_config()
model_config = app_config.get_model_config(model_name)
if model_config is None:
raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.")
if thinking_enabled and not model_config.supports_thinking:
logger.warning(f"Thinking mode is enabled but model '{model_name}' does not support it; fallback to non-thinking mode.")
thinking_enabled = False
logger.info(
"Create Agent(%s) -> thinking_enabled: %s, reasoning_effort: %s, model_name: %s, is_plan_mode: %s, subagent_enabled: %s, max_concurrent_subagents: %s",
agent_name or "default",
thinking_enabled,
reasoning_effort,
model_name,
is_plan_mode,
subagent_enabled,
max_concurrent_subagents,
)
# Inject run metadata for LangSmith trace tagging
if "metadata" not in config:
config["metadata"] = {}
config["metadata"].update(
{
"agent_name": agent_name or "default",
"model_name": model_name or "default",
"thinking_enabled": thinking_enabled,
"reasoning_effort": reasoning_effort,
"is_plan_mode": is_plan_mode,
"subagent_enabled": subagent_enabled,
}
)
if is_bootstrap:
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent],
middleware=_build_middlewares(config, model_name=model_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
state_schema=ThreadState,
)
# Default lead agent (unchanged behavior)
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
),
state_schema=ThreadState,
)

View File

@@ -0,0 +1,727 @@
import asyncio
import logging
import threading
from datetime import datetime
from functools import lru_cache
from deerflow.config.agents_config import load_agent_soul
from deerflow.skills import load_skills
from deerflow.skills.types import Skill
from deerflow.subagents import get_available_subagent_names
logger = logging.getLogger(__name__)
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
_enabled_skills_lock = threading.Lock()
_enabled_skills_cache: list[Skill] | None = None
_enabled_skills_refresh_active = False
_enabled_skills_refresh_version = 0
_enabled_skills_refresh_event = threading.Event()
def _load_enabled_skills_sync() -> list[Skill]:
return list(load_skills(enabled_only=True))
def _start_enabled_skills_refresh_thread() -> None:
threading.Thread(
target=_refresh_enabled_skills_cache_worker,
name="deerflow-enabled-skills-loader",
daemon=True,
).start()
def _refresh_enabled_skills_cache_worker() -> None:
global _enabled_skills_cache, _enabled_skills_refresh_active
while True:
with _enabled_skills_lock:
target_version = _enabled_skills_refresh_version
try:
skills = _load_enabled_skills_sync()
except Exception:
logger.exception("Failed to load enabled skills for prompt injection")
skills = []
with _enabled_skills_lock:
if _enabled_skills_refresh_version == target_version:
_enabled_skills_cache = skills
_enabled_skills_refresh_active = False
_enabled_skills_refresh_event.set()
return
# A newer invalidation happened while loading. Keep the worker alive
# and loop again so the cache always converges on the latest version.
_enabled_skills_cache = None
def _ensure_enabled_skills_cache() -> threading.Event:
global _enabled_skills_refresh_active
with _enabled_skills_lock:
if _enabled_skills_cache is not None:
_enabled_skills_refresh_event.set()
return _enabled_skills_refresh_event
if _enabled_skills_refresh_active:
return _enabled_skills_refresh_event
_enabled_skills_refresh_active = True
_enabled_skills_refresh_event.clear()
_start_enabled_skills_refresh_thread()
return _enabled_skills_refresh_event
def _invalidate_enabled_skills_cache() -> threading.Event:
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
_get_cached_skills_prompt_section.cache_clear()
with _enabled_skills_lock:
_enabled_skills_cache = None
_enabled_skills_refresh_version += 1
_enabled_skills_refresh_event.clear()
if _enabled_skills_refresh_active:
return _enabled_skills_refresh_event
_enabled_skills_refresh_active = True
_start_enabled_skills_refresh_thread()
return _enabled_skills_refresh_event
def prime_enabled_skills_cache() -> None:
_ensure_enabled_skills_cache()
def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS) -> bool:
if _ensure_enabled_skills_cache().wait(timeout=timeout_seconds):
return True
logger.warning("Timed out waiting %.1fs for enabled skills cache warm-up", timeout_seconds)
return False
def _get_enabled_skills():
with _enabled_skills_lock:
cached = _enabled_skills_cache
if cached is not None:
return list(cached)
_ensure_enabled_skills_cache()
return []
def _skill_mutability_label(category: str) -> str:
return "[custom, editable]" if category == "custom" else "[built-in]"
def clear_skills_system_prompt_cache() -> None:
_invalidate_enabled_skills_cache()
async def refresh_skills_system_prompt_cache_async() -> None:
await asyncio.to_thread(_invalidate_enabled_skills_cache().wait)
def _reset_skills_system_prompt_cache_state() -> None:
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
_get_cached_skills_prompt_section.cache_clear()
with _enabled_skills_lock:
_enabled_skills_cache = None
_enabled_skills_refresh_active = False
_enabled_skills_refresh_version = 0
_enabled_skills_refresh_event.clear()
def _refresh_enabled_skills_cache() -> None:
"""Backward-compatible test helper for direct synchronous reload."""
try:
skills = _load_enabled_skills_sync()
except Exception:
logger.exception("Failed to load enabled skills for prompt injection")
skills = []
with _enabled_skills_lock:
_enabled_skills_cache = skills
_enabled_skills_refresh_active = False
_enabled_skills_refresh_event.set()
def _build_skill_evolution_section(skill_evolution_enabled: bool) -> str:
if not skill_evolution_enabled:
return ""
return """
## Skill Self-Evolution
After completing a task, consider creating or updating a skill when:
- The task required 5+ tool calls to resolve
- You overcame non-obvious errors or pitfalls
- The user corrected your approach and the corrected version worked
- You discovered a non-trivial, recurring workflow
If you used a skill and encountered issues not covered by it, patch it immediately.
Prefer patch over edit. Before creating a new skill, confirm with the user first.
Skip simple one-off tasks.
"""
def _build_subagent_section(max_concurrent: int) -> str:
"""Build the subagent system prompt section with dynamic concurrency limit.
Args:
max_concurrent: Maximum number of concurrent subagent calls allowed per response.
Returns:
Formatted subagent section string.
"""
n = max_concurrent
bash_available = "bash" in get_available_subagent_names()
available_subagents = (
"- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.\n- **bash**: For command execution (git, build, test, deploy operations)"
if bash_available
else "- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.\n"
"- **bash**: Not available in the current sandbox configuration. Use direct file/web tools or switch to AioSandboxProvider for isolated shell access."
)
direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc."
direct_execution_example = (
'# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()'
if bash_available
else '# User asks: "Read the README"\n# Thinking: Single straightforward file read\n# → Execute directly\n\nread_file("/mnt/user-data/workspace/README.md") # Direct execution, not task()'
)
return f"""<subagent_system>
**🚀 SUBAGENT MODE ACTIVE - DECOMPOSE, DELEGATE, SYNTHESIZE**
You are running with subagent capabilities enabled. Your role is to be a **task orchestrator**:
1. **DECOMPOSE**: Break complex tasks into parallel sub-tasks
2. **DELEGATE**: Launch multiple subagents simultaneously using parallel `task` calls
3. **SYNTHESIZE**: Collect and integrate results into a coherent answer
**CORE PRINCIPLE: Complex tasks should be decomposed and distributed across multiple subagents for parallel execution.**
**⛔ HARD CONCURRENCY LIMIT: MAXIMUM {n} `task` CALLS PER RESPONSE. THIS IS NOT OPTIONAL.**
- Each response, you may include **at most {n}** `task` tool calls. Any excess calls are **silently discarded** by the system — you will lose that work.
- **Before launching subagents, you MUST count your sub-tasks in your thinking:**
- If count ≤ {n}: Launch all in this response.
- If count > {n}: **Pick the {n} most important/foundational sub-tasks for this turn.** Save the rest for the next turn.
- **Multi-batch execution** (for >{n} sub-tasks):
- Turn 1: Launch sub-tasks 1-{n} in parallel → wait for results
- Turn 2: Launch next batch in parallel → wait for results
- ... continue until all sub-tasks are complete
- Final turn: Synthesize ALL results into a coherent answer
- **Example thinking pattern**: "I identified 6 sub-tasks. Since the limit is {n} per turn, I will launch the first {n} now, and the rest in the next turn."
**Available Subagents:**
{available_subagents}
**Your Orchestration Strategy:**
✅ **DECOMPOSE + PARALLEL EXECUTION (Preferred Approach):**
For complex queries, break them down into focused sub-tasks and execute in parallel batches (max {n} per turn):
**Example 1: "Why is Tencent's stock price declining?" (3 sub-tasks → 1 batch)**
→ Turn 1: Launch 3 subagents in parallel:
- Subagent 1: Recent financial reports, earnings data, and revenue trends
- Subagent 2: Negative news, controversies, and regulatory issues
- Subagent 3: Industry trends, competitor performance, and market sentiment
→ Turn 2: Synthesize results
**Example 2: "Compare 5 cloud providers" (5 sub-tasks → multi-batch)**
→ Turn 1: Launch {n} subagents in parallel (first batch)
→ Turn 2: Launch remaining subagents in parallel
→ Final turn: Synthesize ALL results into comprehensive comparison
**Example 3: "Refactor the authentication system"**
→ Turn 1: Launch 3 subagents in parallel:
- Subagent 1: Analyze current auth implementation and technical debt
- Subagent 2: Research best practices and security patterns
- Subagent 3: Review related tests, documentation, and vulnerabilities
→ Turn 2: Synthesize results
✅ **USE Parallel Subagents (max {n} per turn) when:**
- **Complex research questions**: Requires multiple information sources or perspectives
- **Multi-aspect analysis**: Task has several independent dimensions to explore
- **Large codebases**: Need to analyze different parts simultaneously
- **Comprehensive investigations**: Questions requiring thorough coverage from multiple angles
❌ **DO NOT use subagents (execute directly) when:**
- **Task cannot be decomposed**: If you can't break it into 2+ meaningful parallel sub-tasks, execute directly
- **Ultra-simple actions**: Read one file, quick edits, single commands
- **Need immediate clarification**: Must ask user before proceeding
- **Meta conversation**: Questions about conversation history
- **Sequential dependencies**: Each step depends on previous results (do steps yourself sequentially)
**CRITICAL WORKFLOW** (STRICTLY follow this before EVERY action):
1. **COUNT**: In your thinking, list all sub-tasks and count them explicitly: "I have N sub-tasks"
2. **PLAN BATCHES**: If N > {n}, explicitly plan which sub-tasks go in which batch:
- "Batch 1 (this turn): first {n} sub-tasks"
- "Batch 2 (next turn): next batch of sub-tasks"
3. **EXECUTE**: Launch ONLY the current batch (max {n} `task` calls). Do NOT launch sub-tasks from future batches.
4. **REPEAT**: After results return, launch the next batch. Continue until all batches complete.
5. **SYNTHESIZE**: After ALL batches are done, synthesize all results.
6. **Cannot decompose** → Execute directly using available tools ({direct_tool_examples})
**⛔ VIOLATION: Launching more than {n} `task` calls in a single response is a HARD ERROR. The system WILL discard excess calls and you WILL lose work. Always batch.**
**Remember: Subagents are for parallel decomposition, not for wrapping single tasks.**
**How It Works:**
- The task tool runs subagents asynchronously in the background
- The backend automatically polls for completion (you don't need to poll)
- The tool call will block until the subagent completes its work
- Once complete, the result is returned to you directly
**Usage Example 1 - Single Batch (≤{n} sub-tasks):**
```python
# User asks: "Why is Tencent's stock price declining?"
# Thinking: 3 sub-tasks → fits in 1 batch
# Turn 1: Launch 3 subagents in parallel
task(description="Tencent financial data", prompt="...", subagent_type="general-purpose")
task(description="Tencent news & regulation", prompt="...", subagent_type="general-purpose")
task(description="Industry & market trends", prompt="...", subagent_type="general-purpose")
# All 3 run in parallel → synthesize results
```
**Usage Example 2 - Multiple Batches (>{n} sub-tasks):**
```python
# User asks: "Compare AWS, Azure, GCP, Alibaba Cloud, and Oracle Cloud"
# Thinking: 5 sub-tasks → need multiple batches (max {n} per batch)
# Turn 1: Launch first batch of {n}
task(description="AWS analysis", prompt="...", subagent_type="general-purpose")
task(description="Azure analysis", prompt="...", subagent_type="general-purpose")
task(description="GCP analysis", prompt="...", subagent_type="general-purpose")
# Turn 2: Launch remaining batch (after first batch completes)
task(description="Alibaba Cloud analysis", prompt="...", subagent_type="general-purpose")
task(description="Oracle Cloud analysis", prompt="...", subagent_type="general-purpose")
# Turn 3: Synthesize ALL results from both batches
```
**Counter-Example - Direct Execution (NO subagents):**
```python
{direct_execution_example}
```
**CRITICAL**:
- **Max {n} `task` calls per turn** - the system enforces this, excess calls are discarded
- Only use `task` when you can launch 2+ subagents in parallel
- Single task = No value from subagents = Execute directly
- For >{n} sub-tasks, use sequential batches of {n} across multiple turns
</subagent_system>"""
SYSTEM_PROMPT_TEMPLATE = """
<role>
You are {agent_name}, an open-source super agent.
</role>
{soul}
{memory_context}
<thinking_style>
- Think concisely and strategically about the user's request BEFORE taking action
- Break down the task: What is clear? What is ambiguous? What is missing?
- **PRIORITY CHECK: If anything is unclear, missing, or has multiple interpretations, you MUST ask for clarification FIRST - do NOT proceed with work**
{subagent_thinking}- Never write down your full final answer or report in thinking process, but only outline
- CRITICAL: After thinking, you MUST provide your actual response to the user. Thinking is for planning, the response is for delivery.
- Your response must contain the actual answer, not just a reference to what you thought about
</thinking_style>
<clarification_system>
**WORKFLOW PRIORITY: CLARIFY → PLAN → ACT**
1. **FIRST**: Analyze the request in your thinking - identify what's unclear, missing, or ambiguous
2. **SECOND**: If clarification is needed, call `ask_clarification` tool IMMEDIATELY - do NOT start working
3. **THIRD**: Only after all clarifications are resolved, proceed with planning and execution
**CRITICAL RULE: Clarification ALWAYS comes BEFORE action. Never start working and clarify mid-execution.**
**MANDATORY Clarification Scenarios - You MUST call ask_clarification BEFORE starting work when:**
1. **Missing Information** (`missing_info`): Required details not provided
- Example: User says "create a web scraper" but doesn't specify the target website
- Example: "Deploy the app" without specifying environment
- **REQUIRED ACTION**: Call ask_clarification to get the missing information
2. **Ambiguous Requirements** (`ambiguous_requirement`): Multiple valid interpretations exist
- Example: "Optimize the code" could mean performance, readability, or memory usage
- Example: "Make it better" is unclear what aspect to improve
- **REQUIRED ACTION**: Call ask_clarification to clarify the exact requirement
3. **Approach Choices** (`approach_choice`): Several valid approaches exist
- Example: "Add authentication" could use JWT, OAuth, session-based, or API keys
- Example: "Store data" could use database, files, cache, etc.
- **REQUIRED ACTION**: Call ask_clarification to let user choose the approach
4. **Risky Operations** (`risk_confirmation`): Destructive actions need confirmation
- Example: Deleting files, modifying production configs, database operations
- Example: Overwriting existing code or data
- **REQUIRED ACTION**: Call ask_clarification to get explicit confirmation
5. **Suggestions** (`suggestion`): You have a recommendation but want approval
- Example: "I recommend refactoring this code. Should I proceed?"
- **REQUIRED ACTION**: Call ask_clarification to get approval
**STRICT ENFORCEMENT:**
- ❌ DO NOT start working and then ask for clarification mid-execution - clarify FIRST
- ❌ DO NOT skip clarification for "efficiency" - accuracy matters more than speed
- ❌ DO NOT make assumptions when information is missing - ALWAYS ask
- ❌ DO NOT proceed with guesses - STOP and call ask_clarification first
- ✅ Analyze the request in thinking → Identify unclear aspects → Ask BEFORE any action
- ✅ If you identify the need for clarification in your thinking, you MUST call the tool IMMEDIATELY
- ✅ After calling ask_clarification, execution will be interrupted automatically
- ✅ Wait for user response - do NOT continue with assumptions
**How to Use:**
```python
ask_clarification(
question="Your specific question here?",
clarification_type="missing_info", # or other type
context="Why you need this information", # optional but recommended
options=["option1", "option2"] # optional, for choices
)
```
**Example:**
User: "Deploy the application"
You (thinking): Missing environment info - I MUST ask for clarification
You (action): ask_clarification(
question="Which environment should I deploy to?",
clarification_type="approach_choice",
context="I need to know the target environment for proper configuration",
options=["development", "staging", "production"]
)
[Execution stops - wait for user response]
User: "staging"
You: "Deploying to staging..." [proceed]
</clarification_system>
{skills_section}
{deferred_tools_section}
{subagent_section}
<working_directory existed="true">
- User uploads: `/mnt/user-data/uploads` - Files uploaded by the user (automatically listed in context)
- User workspace: `/mnt/user-data/workspace` - Working directory for temporary files
- Output files: `/mnt/user-data/outputs` - Final deliverables must be saved here
**File Management:**
- Uploaded files are automatically listed in the <uploaded_files> section before each request
- Use `read_file` tool to read uploaded files using their paths from the list
- For PDF, PPT, Excel, and Word files, converted Markdown versions (*.md) are available alongside originals
- All temporary work happens in `/mnt/user-data/workspace`
- Treat `/mnt/user-data/workspace` as your default current working directory for coding and file-editing tasks
- When writing scripts or commands that create/read files from the workspace, prefer relative paths such as `hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`
- Avoid hardcoding `/mnt/user-data/...` inside generated scripts when a relative path from the workspace is enough
- Final deliverables must be copied to `/mnt/user-data/outputs` and presented using `present_file` tool
{acp_section}
</working_directory>
<response_style>
- Clear and Concise: Avoid over-formatting unless requested
- Natural Tone: Use paragraphs and prose, not bullet points by default
- Action-Oriented: Focus on delivering results, not explaining processes
</response_style>
<citations>
**CRITICAL: Always include citations when using web search results**
- **When to Use**: MANDATORY after web_search, web_fetch, or any external information source
- **Format**: Use Markdown link format `[citation:TITLE](URL)` immediately after the claim
- **Placement**: Inline citations should appear right after the sentence or claim they support
- **Sources Section**: Also collect all citations in a "Sources" section at the end of reports
**Example - Inline Citations:**
```markdown
The key AI trends for 2026 include enhanced reasoning capabilities and multimodal integration
[citation:AI Trends 2026](https://techcrunch.com/ai-trends).
Recent breakthroughs in language models have also accelerated progress
[citation:OpenAI Research](https://openai.com/research).
```
**Example - Deep Research Report with Citations:**
```markdown
## Executive Summary
DeerFlow is an open-source AI agent framework that gained significant traction in early 2026
[citation:GitHub Repository](https://github.com/bytedance/deer-flow). The project focuses on
providing a production-ready agent system with sandbox execution and memory management
[citation:DeerFlow Documentation](https://deer-flow.dev/docs).
## Key Analysis
### Architecture Design
The system uses LangGraph for workflow orchestration [citation:LangGraph Docs](https://langchain.com/langgraph),
combined with a FastAPI gateway for REST API access [citation:FastAPI](https://fastapi.tiangolo.com).
## Sources
### Primary Sources
- [GitHub Repository](https://github.com/bytedance/deer-flow) - Official source code and documentation
- [DeerFlow Documentation](https://deer-flow.dev/docs) - Technical specifications
### Media Coverage
- [AI Trends 2026](https://techcrunch.com/ai-trends) - Industry analysis
```
**CRITICAL: Sources section format:**
- Every item in the Sources section MUST be a clickable markdown link with URL
- Use standard markdown link `[Title](URL) - Description` format (NOT `[citation:...]` format)
- The `[citation:Title](URL)` format is ONLY for inline citations within the report body
- ❌ WRONG: `GitHub 仓库 - 官方源代码和文档` (no URL!)
- ❌ WRONG in Sources: `[citation:GitHub Repository](url)` (citation prefix is for inline only!)
- ✅ RIGHT in Sources: `[GitHub Repository](https://github.com/bytedance/deer-flow) - 官方源代码和文档`
**WORKFLOW for Research Tasks:**
1. Use web_search to find sources → Extract {{title, url, snippet}} from results
2. Write content with inline citations: `claim [citation:Title](url)`
3. Collect all citations in a "Sources" section at the end
4. NEVER write claims without citations when sources are available
**CRITICAL RULES:**
- ❌ DO NOT write research content without citations
- ❌ DO NOT forget to extract URLs from search results
- ✅ ALWAYS add `[citation:Title](URL)` after claims from external sources
- ✅ ALWAYS include a "Sources" section listing all references
</citations>
<critical_reminders>
- **Clarification First**: ALWAYS clarify unclear/missing/ambiguous requirements BEFORE starting work - never assume or guess
{subagent_reminder}- Skill First: Always load the relevant skill before starting **complex** tasks.
- Progressive Loading: Load resources incrementally as referenced in skills
- Output Files: Final deliverables must be in `/mnt/user-data/outputs`
- Clarity: Be direct and helpful, avoid unnecessary meta-commentary
- Including Images and Mermaid: Images and Mermaid diagrams are always welcomed in the Markdown format, and you're encouraged to use `![Image Description](image_path)\n\n` or "```mermaid" to display images in response or Markdown files
- Multi-task: Better utilize parallel tool calling to call multiple tools at one time for better performance
- Language Consistency: Keep using the same language as user's
- Always Respond: Your thinking is internal. You MUST always provide a visible response to the user after thinking.
</critical_reminders>
"""
def _get_memory_context(agent_name: str | None = None) -> str:
"""Get memory context for injection into system prompt.
Args:
agent_name: If provided, loads per-agent memory. If None, loads global memory.
Returns:
Formatted memory context string wrapped in XML tags, or empty string if disabled.
"""
try:
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
from deerflow.config.memory_config import get_memory_config
config = get_memory_config()
if not config.enabled or not config.injection_enabled:
return ""
memory_data = get_memory_data(agent_name)
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
if not memory_content.strip():
return ""
return f"""<memory>
{memory_content}
</memory>
"""
except Exception as e:
logger.error("Failed to load memory context: %s", e)
return ""
@lru_cache(maxsize=32)
def _get_cached_skills_prompt_section(
skill_signature: tuple[tuple[str, str, str, str], ...],
available_skills_key: tuple[str, ...] | None,
container_base_path: str,
skill_evolution_section: str,
) -> str:
filtered = [(name, description, category, location) for name, description, category, location in skill_signature if available_skills_key is None or name in available_skills_key]
skills_list = ""
if filtered:
skill_items = "\n".join(
f" <skill>\n <name>{name}</name>\n <description>{description} {_skill_mutability_label(category)}</description>\n <location>{location}</location>\n </skill>"
for name, description, category, location in filtered
)
skills_list = f"<available_skills>\n{skill_items}\n</available_skills>"
return f"""<skill_system>
You have access to skills that provide optimized workflows for specific tasks. Each skill contains best practices, frameworks, and references to additional resources.
**Progressive Loading Pattern:**
1. When a user query matches a skill's use case, immediately call `read_file` on the skill's main file using the path attribute provided in the skill tag below
2. Read and understand the skill's workflow and instructions
3. The skill file contains references to external resources under the same folder
4. Load referenced resources only when needed during execution
5. Follow the skill's instructions precisely
**Skills are located at:** {container_base_path}
{skill_evolution_section}
{skills_list}
</skill_system>"""
def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
"""Generate the skills prompt section with available skills list."""
skills = _get_enabled_skills()
try:
from deerflow.config import get_app_config
config = get_app_config()
container_base_path = config.skills.container_path
skill_evolution_enabled = config.skill_evolution.enabled
except Exception:
container_base_path = "/mnt/skills"
skill_evolution_enabled = False
if not skills and not skill_evolution_enabled:
return ""
if available_skills is not None and not any(skill.name in available_skills for skill in skills):
return ""
skill_signature = tuple((skill.name, skill.description, skill.category, skill.get_container_file_path(container_base_path)) for skill in skills)
available_key = tuple(sorted(available_skills)) if available_skills is not None else None
if not skill_signature and available_key is not None:
return ""
skill_evolution_section = _build_skill_evolution_section(skill_evolution_enabled)
return _get_cached_skills_prompt_section(skill_signature, available_key, container_base_path, skill_evolution_section)
def get_agent_soul(agent_name: str | None) -> str:
# Append SOUL.md (agent personality) if present
soul = load_agent_soul(agent_name)
if soul:
return f"<soul>\n{soul}\n</soul>\n" if soul else ""
return ""
def get_deferred_tools_prompt_section() -> str:
"""Generate <available-deferred-tools> block for the system prompt.
Lists only deferred tool names so the agent knows what exists
and can use tool_search to load them.
Returns empty string when tool_search is disabled or no tools are deferred.
"""
from deerflow.tools.builtins.tool_search import get_deferred_registry
try:
from deerflow.config import get_app_config
if not get_app_config().tool_search.enabled:
return ""
except Exception:
return ""
registry = get_deferred_registry()
if not registry:
return ""
names = "\n".join(e.name for e in registry.entries)
return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
def _build_acp_section() -> str:
"""Build the ACP agent prompt section, only if ACP agents are configured."""
try:
from deerflow.config.acp_config import get_acp_agents
agents = get_acp_agents()
if not agents:
return ""
except Exception:
return ""
return (
"\n**ACP Agent Tasks (invoke_acp_agent):**\n"
"- ACP agents (e.g. codex, claude_code) run in their own independent workspace — NOT in `/mnt/user-data/`\n"
"- When writing prompts for ACP agents, describe the task only — do NOT reference `/mnt/user-data` paths\n"
"- ACP agent results are accessible at `/mnt/acp-workspace/` (read-only) — use `ls`, `read_file`, or `bash cp` to retrieve output files\n"
"- To deliver ACP output to the user: copy from `/mnt/acp-workspace/<file>` to `/mnt/user-data/outputs/<file>`, then use `present_file`"
)
def _build_custom_mounts_section() -> str:
"""Build a prompt section for explicitly configured sandbox mounts."""
try:
from deerflow.config import get_app_config
mounts = get_app_config().sandbox.mounts or []
except Exception:
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
return ""
if not mounts:
return ""
lines = []
for mount in mounts:
access = "read-only" if mount.read_only else "read-write"
lines.append(f"- Custom mount: `{mount.container_path}` - Host directory mapped into the sandbox ({access})")
mounts_list = "\n".join(lines)
return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
# Get memory context
memory_context = _get_memory_context(agent_name)
# Include subagent section only if enabled (from runtime parameter)
n = max_concurrent_subagents
subagent_section = _build_subagent_section(n) if subagent_enabled else ""
# Add subagent reminder to critical_reminders if enabled
subagent_reminder = (
"- **Orchestrator Mode**: You are a task orchestrator - decompose complex tasks into parallel sub-tasks. "
f"**HARD LIMIT: max {n} `task` calls per response.** "
f"If >{n} sub-tasks, split into sequential batches of ≤{n}. Synthesize after ALL batches complete.\n"
if subagent_enabled
else ""
)
# Add subagent thinking guidance if enabled
subagent_thinking = (
"- **DECOMPOSITION CHECK: Can this task be broken into 2+ parallel sub-tasks? If YES, COUNT them. "
f"If count > {n}, you MUST plan batches of ≤{n} and only launch the FIRST batch now. "
f"NEVER launch more than {n} `task` calls in one response.**\n"
if subagent_enabled
else ""
)
# Get skills section
skills_section = get_skills_prompt_section(available_skills)
# Get deferred tools section (tool_search)
deferred_tools_section = get_deferred_tools_prompt_section()
# Build ACP agent section only if ACP agents are configured
acp_section = _build_acp_section()
custom_mounts_section = _build_custom_mounts_section()
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
# Format the prompt with dynamic skills and memory
prompt = SYSTEM_PROMPT_TEMPLATE.format(
agent_name=agent_name or "DeerFlow 2.0",
soul=get_agent_soul(agent_name),
skills_section=skills_section,
deferred_tools_section=deferred_tools_section,
memory_context=memory_context,
subagent_section=subagent_section,
subagent_reminder=subagent_reminder,
subagent_thinking=subagent_thinking,
acp_section=acp_and_mounts_section,
)
return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"

View File

@@ -0,0 +1,57 @@
"""Memory module for DeerFlow.
This module provides a global memory mechanism that:
- Stores user context and conversation history in memory.json
- Uses LLM to summarize and extract facts from conversations
- Injects relevant memory into system prompts for personalized responses
"""
from deerflow.agents.memory.prompt import (
FACT_EXTRACTION_PROMPT,
MEMORY_UPDATE_PROMPT,
format_conversation_for_update,
format_memory_for_injection,
)
from deerflow.agents.memory.queue import (
ConversationContext,
MemoryUpdateQueue,
get_memory_queue,
reset_memory_queue,
)
from deerflow.agents.memory.storage import (
FileMemoryStorage,
MemoryStorage,
get_memory_storage,
)
from deerflow.agents.memory.updater import (
MemoryUpdater,
clear_memory_data,
delete_memory_fact,
get_memory_data,
reload_memory_data,
update_memory_from_conversation,
)
__all__ = [
# Prompt utilities
"MEMORY_UPDATE_PROMPT",
"FACT_EXTRACTION_PROMPT",
"format_memory_for_injection",
"format_conversation_for_update",
# Queue
"ConversationContext",
"MemoryUpdateQueue",
"get_memory_queue",
"reset_memory_queue",
# Storage
"MemoryStorage",
"FileMemoryStorage",
"get_memory_storage",
# Updater
"MemoryUpdater",
"clear_memory_data",
"delete_memory_fact",
"get_memory_data",
"reload_memory_data",
"update_memory_from_conversation",
]

View File

@@ -0,0 +1,363 @@
"""Prompt templates for memory update and injection."""
import math
import re
from typing import Any
try:
import tiktoken
TIKTOKEN_AVAILABLE = True
except ImportError:
TIKTOKEN_AVAILABLE = False
# Prompt template for updating memory based on conversation
MEMORY_UPDATE_PROMPT = """You are a memory management system. Your task is to analyze a conversation and update the user's memory profile.
Current Memory State:
<current_memory>
{current_memory}
</current_memory>
New Conversation to Process:
<conversation>
{conversation}
</conversation>
Instructions:
1. Analyze the conversation for important information about the user
2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies)
3. Update the memory sections as needed following the detailed length guidelines below
Before extracting facts, perform a structured reflection on the conversation:
1. Error/Retry Detection: Did the agent encounter errors, require retries, or produce incorrect results?
If yes, record the root cause and correct approach as a high-confidence fact with category "correction".
2. User Correction Detection: Did the user correct the agent's direction, understanding, or output?
If yes, record the correct interpretation or approach as a high-confidence fact with category "correction".
Include what went wrong in "sourceError" only when category is "correction" and the mistake is explicit in the conversation.
3. Project Constraint Discovery: Were any project-specific constraints discovered during the conversation?
If yes, record them as facts with the most appropriate category and confidence.
{correction_hint}
Memory Section Guidelines:
**User Context** (Current state - concise summaries):
- workContext: Professional role, company, key projects, main technologies (2-3 sentences)
Example: Core contributor, project names with metrics (16k+ stars), technical stack
- personalContext: Languages, communication preferences, key interests (1-2 sentences)
Example: Bilingual capabilities, specific interest areas, expertise domains
- topOfMind: Multiple ongoing focus areas and priorities (3-5 sentences, detailed paragraph)
Example: Primary project work, parallel technical investigations, ongoing learning/tracking
Include: Active implementation work, troubleshooting issues, market/research interests
Note: This captures SEVERAL concurrent focus areas, not just one task
**History** (Temporal context - rich paragraphs):
- recentMonths: Detailed summary of recent activities (4-6 sentences or 1-2 paragraphs)
Timeline: Last 1-3 months of interactions
Include: Technologies explored, projects worked on, problems solved, interests demonstrated
- earlierContext: Important historical patterns (3-5 sentences or 1 paragraph)
Timeline: 3-12 months ago
Include: Past projects, learning journeys, established patterns
- longTermBackground: Persistent background and foundational context (2-4 sentences)
Timeline: Overall/foundational information
Include: Core expertise, longstanding interests, fundamental working style
**Facts Extraction**:
- Extract specific, quantifiable details (e.g., "16k+ GitHub stars", "200+ datasets")
- Include proper nouns (company names, project names, technology names)
- Preserve technical terminology and version numbers
- Categories:
* preference: Tools, styles, approaches user prefers/dislikes
* knowledge: Specific expertise, technologies mastered, domain knowledge
* context: Background facts (job title, projects, locations, languages)
* behavior: Working patterns, communication habits, problem-solving approaches
* goal: Stated objectives, learning targets, project ambitions
* correction: Explicit agent mistakes or user corrections, including the correct approach
- Confidence levels:
* 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y")
* 0.7-0.8: Strongly implied from actions/discussions
* 0.5-0.6: Inferred patterns (use sparingly, only for clear patterns)
**What Goes Where**:
- workContext: Current job, active projects, primary tech stack
- personalContext: Languages, personality, interests outside direct work tasks
- topOfMind: Multiple ongoing priorities and focus areas user cares about recently (gets updated most frequently)
Should capture 3-5 concurrent themes: main work, side explorations, learning/tracking interests
- recentMonths: Detailed account of recent technical explorations and work
- earlierContext: Patterns from slightly older interactions still relevant
- longTermBackground: Unchanging foundational facts about the user
**Multilingual Content**:
- Preserve original language for proper nouns and company names
- Keep technical terms in their original form (DeepSeek, LangGraph, etc.)
- Note language capabilities in personalContext
Output Format (JSON):
{{
"user": {{
"workContext": {{ "summary": "...", "shouldUpdate": true/false }},
"personalContext": {{ "summary": "...", "shouldUpdate": true/false }},
"topOfMind": {{ "summary": "...", "shouldUpdate": true/false }}
}},
"history": {{
"recentMonths": {{ "summary": "...", "shouldUpdate": true/false }},
"earlierContext": {{ "summary": "...", "shouldUpdate": true/false }},
"longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }}
}},
"newFacts": [
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
],
"factsToRemove": ["fact_id_1", "fact_id_2"]
}}
Important Rules:
- Only set shouldUpdate=true if there's meaningful new information
- Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs)
- Include specific metrics, version numbers, and proper nouns in facts
- Only add facts that are clearly stated (0.9+) or strongly implied (0.7+)
- Use category "correction" for explicit agent mistakes or user corrections; assign confidence >= 0.95 when the correction is explicit
- Include "sourceError" only for explicit correction facts when the prior mistake or wrong approach is clearly stated; omit it otherwise
- Remove facts that are contradicted by new information
- When updating topOfMind, integrate new focus areas while removing completed/abandoned ones
Keep 3-5 concurrent focus themes that are still active and relevant
- For history sections, integrate new information chronologically into appropriate time period
- Preserve technical accuracy - keep exact names of technologies, companies, projects
- Focus on information useful for future interactions and personalization
- IMPORTANT: Do NOT record file upload events in memory. Uploaded files are
session-specific and ephemeral — they will not be accessible in future sessions.
Recording upload events causes confusion in subsequent conversations.
Return ONLY valid JSON, no explanation or markdown."""
# Prompt template for extracting facts from a single message
FACT_EXTRACTION_PROMPT = """Extract factual information about the user from this message.
Message:
{message}
Extract facts in this JSON format:
{{
"facts": [
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
]
}}
Categories:
- preference: User preferences (likes/dislikes, styles, tools)
- knowledge: User's expertise or knowledge areas
- context: Background context (location, job, projects)
- behavior: Behavioral patterns
- goal: User's goals or objectives
- correction: Explicit corrections or mistakes to avoid repeating
Rules:
- Only extract clear, specific facts
- Confidence should reflect certainty (explicit statement = 0.9+, implied = 0.6-0.8)
- Skip vague or temporary information
Return ONLY valid JSON."""
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
"""Count tokens in text using tiktoken.
Args:
text: The text to count tokens for.
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
Returns:
The number of tokens in the text.
"""
if not TIKTOKEN_AVAILABLE:
# Fallback to character-based estimation if tiktoken is not available
return len(text) // 4
try:
encoding = tiktoken.get_encoding(encoding_name)
return len(encoding.encode(text))
except Exception:
# Fallback to character-based estimation on error
return len(text) // 4
def _coerce_confidence(value: Any, default: float = 0.0) -> float:
"""Coerce a confidence-like value to a bounded float in [0, 1].
Non-finite values (NaN, inf, -inf) are treated as invalid and fall back
to the default before clamping, preventing them from dominating ranking.
The ``default`` parameter is assumed to be a finite value.
"""
try:
confidence = float(value)
except (TypeError, ValueError):
return max(0.0, min(1.0, default))
if not math.isfinite(confidence):
return max(0.0, min(1.0, default))
return max(0.0, min(1.0, confidence))
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
"""Format memory data for injection into system prompt.
Args:
memory_data: The memory data dictionary.
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
Returns:
Formatted memory string for system prompt injection.
"""
if not memory_data:
return ""
sections = []
# Format user context
user_data = memory_data.get("user", {})
if user_data:
user_sections = []
work_ctx = user_data.get("workContext", {})
if work_ctx.get("summary"):
user_sections.append(f"Work: {work_ctx['summary']}")
personal_ctx = user_data.get("personalContext", {})
if personal_ctx.get("summary"):
user_sections.append(f"Personal: {personal_ctx['summary']}")
top_of_mind = user_data.get("topOfMind", {})
if top_of_mind.get("summary"):
user_sections.append(f"Current Focus: {top_of_mind['summary']}")
if user_sections:
sections.append("User Context:\n" + "\n".join(f"- {s}" for s in user_sections))
# Format history
history_data = memory_data.get("history", {})
if history_data:
history_sections = []
recent = history_data.get("recentMonths", {})
if recent.get("summary"):
history_sections.append(f"Recent: {recent['summary']}")
earlier = history_data.get("earlierContext", {})
if earlier.get("summary"):
history_sections.append(f"Earlier: {earlier['summary']}")
background = history_data.get("longTermBackground", {})
if background.get("summary"):
history_sections.append(f"Background: {background['summary']}")
if history_sections:
sections.append("History:\n" + "\n".join(f"- {s}" for s in history_sections))
# Format facts (sorted by confidence; include as many as token budget allows)
facts_data = memory_data.get("facts", [])
if isinstance(facts_data, list) and facts_data:
ranked_facts = sorted(
(f for f in facts_data if isinstance(f, dict) and isinstance(f.get("content"), str) and f.get("content").strip()),
key=lambda fact: _coerce_confidence(fact.get("confidence"), default=0.0),
reverse=True,
)
# Compute token count for existing sections once, then account
# incrementally for each fact line to avoid full-string re-tokenization.
base_text = "\n\n".join(sections)
base_tokens = _count_tokens(base_text) if base_text else 0
# Account for the separator between existing sections and the facts section.
facts_header = "Facts:\n"
separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header)
running_tokens = base_tokens + separator_tokens
fact_lines: list[str] = []
for fact in ranked_facts:
content_value = fact.get("content")
if not isinstance(content_value, str):
continue
content = content_value.strip()
if not content:
continue
category = str(fact.get("category", "context")).strip() or "context"
confidence = _coerce_confidence(fact.get("confidence"), default=0.0)
source_error = fact.get("sourceError")
if category == "correction" and isinstance(source_error, str) and source_error.strip():
line = f"- [{category} | {confidence:.2f}] {content} (avoid: {source_error.strip()})"
else:
line = f"- [{category} | {confidence:.2f}] {content}"
# Each additional line is preceded by a newline (except the first).
line_text = ("\n" + line) if fact_lines else line
line_tokens = _count_tokens(line_text)
if running_tokens + line_tokens <= max_tokens:
fact_lines.append(line)
running_tokens += line_tokens
else:
break
if fact_lines:
sections.append("Facts:\n" + "\n".join(fact_lines))
if not sections:
return ""
result = "\n\n".join(sections)
# Use accurate token counting with tiktoken
token_count = _count_tokens(result)
if token_count > max_tokens:
# Truncate to fit within token limit
# Estimate characters to remove based on token ratio
char_per_token = len(result) / token_count
target_chars = int(max_tokens * char_per_token * 0.95) # 95% to leave margin
result = result[:target_chars] + "\n..."
return result
def format_conversation_for_update(messages: list[Any]) -> str:
"""Format conversation messages for memory update prompt.
Args:
messages: List of conversation messages.
Returns:
Formatted conversation string.
"""
lines = []
for msg in messages:
role = getattr(msg, "type", "unknown")
content = getattr(msg, "content", str(msg))
# Handle content that might be a list (multimodal)
if isinstance(content, list):
text_parts = []
for p in content:
if isinstance(p, str):
text_parts.append(p)
elif isinstance(p, dict):
text_val = p.get("text")
if isinstance(text_val, str):
text_parts.append(text_val)
content = " ".join(text_parts) if text_parts else str(content)
# Strip uploaded_files tags from human messages to avoid persisting
# ephemeral file path info into long-term memory. Skip the turn entirely
# when nothing remains after stripping (upload-only message).
if role == "human":
content = re.sub(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", "", str(content)).strip()
if not content:
continue
# Truncate very long messages
if len(str(content)) > 1000:
content = str(content)[:1000] + "..."
if role == "human":
lines.append(f"User: {content}")
elif role == "ai":
lines.append(f"Assistant: {content}")
return "\n\n".join(lines)

View File

@@ -0,0 +1,219 @@
"""Memory update queue with debounce mechanism."""
import logging
import threading
import time
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
from deerflow.config.memory_config import get_memory_config
logger = logging.getLogger(__name__)
@dataclass
class ConversationContext:
"""Context for a conversation to be processed for memory update."""
thread_id: str
messages: list[Any]
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
agent_name: str | None = None
correction_detected: bool = False
reinforcement_detected: bool = False
class MemoryUpdateQueue:
"""Queue for memory updates with debounce mechanism.
This queue collects conversation contexts and processes them after
a configurable debounce period. Multiple conversations received within
the debounce window are batched together.
"""
def __init__(self):
"""Initialize the memory update queue."""
self._queue: list[ConversationContext] = []
self._lock = threading.Lock()
self._timer: threading.Timer | None = None
self._processing = False
def add(
self,
thread_id: str,
messages: list[Any],
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
"""Add a conversation to the update queue.
Args:
thread_id: The thread ID.
messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
"""
config = get_memory_config()
if not config.enabled:
return
with self._lock:
existing_context = next(
(context for context in self._queue if context.thread_id == thread_id),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
)
# Check if this thread already has a pending update
# If so, replace it with the newer one
self._queue = [c for c in self._queue if c.thread_id != thread_id]
self._queue.append(context)
# Reset or start the debounce timer
self._reset_timer()
logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue))
def _reset_timer(self) -> None:
"""Reset the debounce timer."""
config = get_memory_config()
# Cancel existing timer if any
if self._timer is not None:
self._timer.cancel()
# Start new timer
self._timer = threading.Timer(
config.debounce_seconds,
self._process_queue,
)
self._timer.daemon = True
self._timer.start()
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
def _process_queue(self) -> None:
"""Process all queued conversation contexts."""
# Import here to avoid circular dependency
from deerflow.agents.memory.updater import MemoryUpdater
with self._lock:
if self._processing:
# Already processing, reschedule
self._reset_timer()
return
if not self._queue:
return
self._processing = True
contexts_to_process = self._queue.copy()
self._queue.clear()
self._timer = None
logger.info("Processing %d queued memory updates", len(contexts_to_process))
try:
updater = MemoryUpdater()
for context in contexts_to_process:
try:
logger.info("Updating memory for thread %s", context.thread_id)
success = updater.update_memory(
messages=context.messages,
thread_id=context.thread_id,
agent_name=context.agent_name,
correction_detected=context.correction_detected,
reinforcement_detected=context.reinforcement_detected,
)
if success:
logger.info("Memory updated successfully for thread %s", context.thread_id)
else:
logger.warning("Memory update skipped/failed for thread %s", context.thread_id)
except Exception as e:
logger.error("Error updating memory for thread %s: %s", context.thread_id, e)
# Small delay between updates to avoid rate limiting
if len(contexts_to_process) > 1:
time.sleep(0.5)
finally:
with self._lock:
self._processing = False
def flush(self) -> None:
"""Force immediate processing of the queue.
This is useful for testing or graceful shutdown.
"""
with self._lock:
if self._timer is not None:
self._timer.cancel()
self._timer = None
self._process_queue()
def clear(self) -> None:
"""Clear the queue without processing.
This is useful for testing.
"""
with self._lock:
if self._timer is not None:
self._timer.cancel()
self._timer = None
self._queue.clear()
self._processing = False
@property
def pending_count(self) -> int:
"""Get the number of pending updates."""
with self._lock:
return len(self._queue)
@property
def is_processing(self) -> bool:
"""Check if the queue is currently being processed."""
with self._lock:
return self._processing
# Global singleton instance
_memory_queue: MemoryUpdateQueue | None = None
_queue_lock = threading.Lock()
def get_memory_queue() -> MemoryUpdateQueue:
"""Get the global memory update queue singleton.
Returns:
The memory update queue instance.
"""
global _memory_queue
with _queue_lock:
if _memory_queue is None:
_memory_queue = MemoryUpdateQueue()
return _memory_queue
def reset_memory_queue() -> None:
"""Reset the global memory queue.
This is useful for testing.
"""
global _memory_queue
with _queue_lock:
if _memory_queue is not None:
_memory_queue.clear()
_memory_queue = None

View File

@@ -0,0 +1,205 @@
"""Memory storage providers."""
import abc
import json
import logging
import threading
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.memory_config import get_memory_config
from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__)
def utc_now_iso_z() -> str:
"""Current UTC time as ISO-8601 with ``Z`` suffix (matches prior naive-UTC output)."""
return datetime.now(UTC).isoformat().removesuffix("+00:00") + "Z"
def create_empty_memory() -> dict[str, Any]:
"""Create an empty memory structure."""
return {
"version": "1.0",
"lastUpdated": utc_now_iso_z(),
"user": {
"workContext": {"summary": "", "updatedAt": ""},
"personalContext": {"summary": "", "updatedAt": ""},
"topOfMind": {"summary": "", "updatedAt": ""},
},
"history": {
"recentMonths": {"summary": "", "updatedAt": ""},
"earlierContext": {"summary": "", "updatedAt": ""},
"longTermBackground": {"summary": "", "updatedAt": ""},
},
"facts": [],
}
class MemoryStorage(abc.ABC):
"""Abstract base class for memory storage providers."""
@abc.abstractmethod
def load(self, agent_name: str | None = None) -> dict[str, Any]:
"""Load memory data for the given agent."""
pass
@abc.abstractmethod
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
"""Force reload memory data for the given agent."""
pass
@abc.abstractmethod
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
"""Save memory data for the given agent."""
pass
class FileMemoryStorage(MemoryStorage):
"""File-based memory storage provider."""
def __init__(self):
"""Initialize the file memory storage."""
# Per-agent memory cache: keyed by agent_name (None = global)
# Value: (memory_data, file_mtime)
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
def _validate_agent_name(self, agent_name: str) -> None:
"""Validate that the agent name is safe to use in filesystem paths.
Uses the repository's established AGENT_NAME_PATTERN to ensure consistency
across the codebase and prevent path traversal or other problematic characters.
"""
if not agent_name:
raise ValueError("Agent name must be a non-empty string.")
if not AGENT_NAME_PATTERN.match(agent_name):
raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")
def _get_memory_file_path(self, agent_name: str | None = None) -> Path:
"""Get the path to the memory file."""
if agent_name is not None:
self._validate_agent_name(agent_name)
return get_paths().agent_memory_file(agent_name)
config = get_memory_config()
if config.storage_path:
p = Path(config.storage_path)
return p if p.is_absolute() else get_paths().base_dir / p
return get_paths().memory_file
def _load_memory_from_file(self, agent_name: str | None = None) -> dict[str, Any]:
"""Load memory data from file."""
file_path = self._get_memory_file_path(agent_name)
if not file_path.exists():
return create_empty_memory()
try:
with open(file_path, encoding="utf-8") as f:
data = json.load(f)
return data
except (json.JSONDecodeError, OSError) as e:
logger.warning("Failed to load memory file: %s", e)
return create_empty_memory()
def load(self, agent_name: str | None = None) -> dict[str, Any]:
"""Load memory data (cached with file modification time check)."""
file_path = self._get_memory_file_path(agent_name)
try:
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
except OSError:
current_mtime = None
cached = self._memory_cache.get(agent_name)
if cached is None or cached[1] != current_mtime:
memory_data = self._load_memory_from_file(agent_name)
self._memory_cache[agent_name] = (memory_data, current_mtime)
return memory_data
return cached[0]
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
"""Reload memory data from file, forcing cache invalidation."""
file_path = self._get_memory_file_path(agent_name)
memory_data = self._load_memory_from_file(agent_name)
try:
mtime = file_path.stat().st_mtime if file_path.exists() else None
except OSError:
mtime = None
self._memory_cache[agent_name] = (memory_data, mtime)
return memory_data
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
"""Save memory data to file and update cache."""
file_path = self._get_memory_file_path(agent_name)
try:
file_path.parent.mkdir(parents=True, exist_ok=True)
memory_data["lastUpdated"] = utc_now_iso_z()
temp_path = file_path.with_suffix(".tmp")
with open(temp_path, "w", encoding="utf-8") as f:
json.dump(memory_data, f, indent=2, ensure_ascii=False)
temp_path.replace(file_path)
try:
mtime = file_path.stat().st_mtime
except OSError:
mtime = None
self._memory_cache[agent_name] = (memory_data, mtime)
logger.info("Memory saved to %s", file_path)
return True
except OSError as e:
logger.error("Failed to save memory file: %s", e)
return False
_storage_instance: MemoryStorage | None = None
_storage_lock = threading.Lock()
def get_memory_storage() -> MemoryStorage:
"""Get the configured memory storage instance."""
global _storage_instance
if _storage_instance is not None:
return _storage_instance
with _storage_lock:
if _storage_instance is not None:
return _storage_instance
config = get_memory_config()
storage_class_path = config.storage_class
try:
module_path, class_name = storage_class_path.rsplit(".", 1)
import importlib
module = importlib.import_module(module_path)
storage_class = getattr(module, class_name)
# Validate that the configured storage is a MemoryStorage implementation
if not isinstance(storage_class, type):
raise TypeError(f"Configured memory storage '{storage_class_path}' is not a class: {storage_class!r}")
if not issubclass(storage_class, MemoryStorage):
raise TypeError(f"Configured memory storage '{storage_class_path}' is not a subclass of MemoryStorage")
_storage_instance = storage_class()
except Exception as e:
logger.error(
"Failed to load memory storage %s, falling back to FileMemoryStorage: %s",
storage_class_path,
e,
)
_storage_instance = FileMemoryStorage()
return _storage_instance

View File

@@ -0,0 +1,472 @@
"""Memory updater for reading, writing, and updating memory data."""
import json
import logging
import math
import re
import uuid
from typing import Any
from deerflow.agents.memory.prompt import (
MEMORY_UPDATE_PROMPT,
format_conversation_for_update,
)
from deerflow.agents.memory.storage import (
create_empty_memory,
get_memory_storage,
utc_now_iso_z,
)
from deerflow.config.memory_config import get_memory_config
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
def _create_empty_memory() -> dict[str, Any]:
"""Backward-compatible wrapper around the storage-layer empty-memory factory."""
return create_empty_memory()
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
"""Backward-compatible wrapper around the configured memory storage save path."""
return get_memory_storage().save(memory_data, agent_name)
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
"""Get the current memory data via storage provider."""
return get_memory_storage().load(agent_name)
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
"""Reload memory data via storage provider."""
return get_memory_storage().reload(agent_name)
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]:
"""Persist imported memory data via storage provider.
Args:
memory_data: Full memory payload to persist.
agent_name: If provided, imports into per-agent memory.
Returns:
The saved memory data after storage normalization.
Raises:
OSError: If persisting the imported memory fails.
"""
storage = get_memory_storage()
if not storage.save(memory_data, agent_name):
raise OSError("Failed to save imported memory data")
return storage.load(agent_name)
def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]:
"""Clear all stored memory data and persist an empty structure."""
cleared_memory = create_empty_memory()
if not _save_memory_to_file(cleared_memory, agent_name):
raise OSError("Failed to save cleared memory data")
return cleared_memory
def _validate_confidence(confidence: float) -> float:
"""Validate persisted fact confidence so stored JSON stays standards-compliant."""
if not math.isfinite(confidence) or confidence < 0 or confidence > 1:
raise ValueError("confidence")
return confidence
def create_memory_fact(
content: str,
category: str = "context",
confidence: float = 0.5,
agent_name: str | None = None,
) -> dict[str, Any]:
"""Create a new fact and persist the updated memory data."""
normalized_content = content.strip()
if not normalized_content:
raise ValueError("content")
normalized_category = category.strip() or "context"
validated_confidence = _validate_confidence(confidence)
now = utc_now_iso_z()
memory_data = get_memory_data(agent_name)
updated_memory = dict(memory_data)
facts = list(memory_data.get("facts", []))
facts.append(
{
"id": f"fact_{uuid.uuid4().hex[:8]}",
"content": normalized_content,
"category": normalized_category,
"confidence": validated_confidence,
"createdAt": now,
"source": "manual",
}
)
updated_memory["facts"] = facts
if not _save_memory_to_file(updated_memory, agent_name):
raise OSError("Failed to save memory data after creating fact")
return updated_memory
def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]:
"""Delete a fact by its id and persist the updated memory data."""
memory_data = get_memory_data(agent_name)
facts = memory_data.get("facts", [])
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
if len(updated_facts) == len(facts):
raise KeyError(fact_id)
updated_memory = dict(memory_data)
updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name):
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
return updated_memory
def update_memory_fact(
fact_id: str,
content: str | None = None,
category: str | None = None,
confidence: float | None = None,
agent_name: str | None = None,
) -> dict[str, Any]:
"""Update an existing fact and persist the updated memory data."""
memory_data = get_memory_data(agent_name)
updated_memory = dict(memory_data)
updated_facts: list[dict[str, Any]] = []
found = False
for fact in memory_data.get("facts", []):
if fact.get("id") == fact_id:
found = True
updated_fact = dict(fact)
if content is not None:
normalized_content = content.strip()
if not normalized_content:
raise ValueError("content")
updated_fact["content"] = normalized_content
if category is not None:
updated_fact["category"] = category.strip() or "context"
if confidence is not None:
updated_fact["confidence"] = _validate_confidence(confidence)
updated_facts.append(updated_fact)
else:
updated_facts.append(fact)
if not found:
raise KeyError(fact_id)
updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name):
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
return updated_memory
def _extract_text(content: Any) -> str:
"""Extract plain text from LLM response content (str or list of content blocks).
Modern LLMs may return structured content as a list of blocks instead of a
plain string, e.g. [{"type": "text", "text": "..."}]. Using str() on such
content produces Python repr instead of the actual text, breaking JSON
parsing downstream.
String chunks are concatenated without separators to avoid corrupting
chunked JSON/text payloads. Dict-based text blocks are treated as full text
blocks and joined with newlines for readability.
"""
if isinstance(content, str):
return content
if isinstance(content, list):
pieces: list[str] = []
pending_str_parts: list[str] = []
def flush_pending_str_parts() -> None:
if pending_str_parts:
pieces.append("".join(pending_str_parts))
pending_str_parts.clear()
for block in content:
if isinstance(block, str):
pending_str_parts.append(block)
elif isinstance(block, dict):
flush_pending_str_parts()
text_val = block.get("text")
if isinstance(text_val, str):
pieces.append(text_val)
flush_pending_str_parts()
return "\n".join(pieces)
return str(content)
# Matches sentences that describe a file-upload *event* rather than general
# file-related work. Deliberately narrow to avoid removing legitimate facts
# such as "User works with CSV files" or "prefers PDF export".
_UPLOAD_SENTENCE_RE = re.compile(
r"[^.!?]*\b(?:"
r"upload(?:ed|ing)?(?:\s+\w+){0,3}\s+(?:file|files?|document|documents?|attachment|attachments?)"
r"|file\s+upload"
r"|/mnt/user-data/uploads/"
r"|<uploaded_files>"
r")[^.!?]*[.!?]?\s*",
re.IGNORECASE,
)
def _strip_upload_mentions_from_memory(memory_data: dict[str, Any]) -> dict[str, Any]:
"""Remove sentences about file uploads from all memory summaries and facts.
Uploaded files are session-scoped; persisting upload events in long-term
memory causes the agent to search for non-existent files in future sessions.
"""
# Scrub summaries in user/history sections
for section in ("user", "history"):
section_data = memory_data.get(section, {})
for _key, val in section_data.items():
if isinstance(val, dict) and "summary" in val:
cleaned = _UPLOAD_SENTENCE_RE.sub("", val["summary"]).strip()
cleaned = re.sub(r" +", " ", cleaned)
val["summary"] = cleaned
# Also remove any facts that describe upload events
facts = memory_data.get("facts", [])
if facts:
memory_data["facts"] = [f for f in facts if not _UPLOAD_SENTENCE_RE.search(f.get("content", ""))]
return memory_data
def _fact_content_key(content: Any) -> str | None:
if not isinstance(content, str):
return None
stripped = content.strip()
if not stripped:
return None
return stripped.casefold()
class MemoryUpdater:
"""Updates memory using LLM based on conversation context."""
def __init__(self, model_name: str | None = None):
"""Initialize the memory updater.
Args:
model_name: Optional model name to use. If None, uses config or default.
"""
self._model_name = model_name
def _get_model(self):
"""Get the model for memory updates."""
config = get_memory_config()
model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False)
def update_memory(
self,
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Update memory based on conversation messages.
Args:
messages: List of conversation messages.
thread_id: Optional thread ID for tracking source.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
Returns:
True if update was successful, False otherwise.
"""
config = get_memory_config()
if not config.enabled:
return False
if not messages:
return False
try:
# Get current memory
current_memory = get_memory_data(agent_name)
# Format conversation for prompt
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return False
# Build prompt
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
)
# Call LLM
model = self._get_model()
response = model.invoke(prompt)
response_text = _extract_text(response.content).strip()
# Parse response
# Remove markdown code blocks if present
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
# Apply updates
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
# Strip file-upload mentions from all summaries before saving.
# Uploaded files are session-scoped and won't exist in future sessions,
# so recording upload events in long-term memory causes the agent to
# try (and fail) to locate those files in subsequent conversations.
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
# Save
return get_memory_storage().save(updated_memory, agent_name)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
return False
except Exception as e:
logger.exception("Memory update failed: %s", e)
return False
def _apply_updates(
self,
current_memory: dict[str, Any],
update_data: dict[str, Any],
thread_id: str | None = None,
) -> dict[str, Any]:
"""Apply LLM-generated updates to memory.
Args:
current_memory: Current memory data.
update_data: Updates from LLM.
thread_id: Optional thread ID for tracking.
Returns:
Updated memory data.
"""
config = get_memory_config()
now = utc_now_iso_z()
# Update user sections
user_updates = update_data.get("user", {})
for section in ["workContext", "personalContext", "topOfMind"]:
section_data = user_updates.get(section, {})
if section_data.get("shouldUpdate") and section_data.get("summary"):
current_memory["user"][section] = {
"summary": section_data["summary"],
"updatedAt": now,
}
# Update history sections
history_updates = update_data.get("history", {})
for section in ["recentMonths", "earlierContext", "longTermBackground"]:
section_data = history_updates.get(section, {})
if section_data.get("shouldUpdate") and section_data.get("summary"):
current_memory["history"][section] = {
"summary": section_data["summary"],
"updatedAt": now,
}
# Remove facts
facts_to_remove = set(update_data.get("factsToRemove", []))
if facts_to_remove:
current_memory["facts"] = [f for f in current_memory.get("facts", []) if f.get("id") not in facts_to_remove]
# Add new facts
existing_fact_keys = {fact_key for fact_key in (_fact_content_key(fact.get("content")) for fact in current_memory.get("facts", [])) if fact_key is not None}
new_facts = update_data.get("newFacts", [])
for fact in new_facts:
confidence = fact.get("confidence", 0.5)
if confidence >= config.fact_confidence_threshold:
raw_content = fact.get("content", "")
if not isinstance(raw_content, str):
continue
normalized_content = raw_content.strip()
fact_key = _fact_content_key(normalized_content)
if fact_key is not None and fact_key in existing_fact_keys:
continue
fact_entry = {
"id": f"fact_{uuid.uuid4().hex[:8]}",
"content": normalized_content,
"category": fact.get("category", "context"),
"confidence": confidence,
"createdAt": now,
"source": thread_id or "unknown",
}
source_error = fact.get("sourceError")
if isinstance(source_error, str):
normalized_source_error = source_error.strip()
if normalized_source_error:
fact_entry["sourceError"] = normalized_source_error
current_memory["facts"].append(fact_entry)
if fact_key is not None:
existing_fact_keys.add(fact_key)
# Enforce max facts limit
if len(current_memory["facts"]) > config.max_facts:
# Sort by confidence and keep top ones
current_memory["facts"] = sorted(
current_memory["facts"],
key=lambda f: f.get("confidence", 0),
reverse=True,
)[: config.max_facts]
return current_memory
def update_memory_from_conversation(
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Convenience function to update memory from a conversation.
Args:
messages: List of conversation messages.
thread_id: Optional thread ID.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
Returns:
True if successful, False otherwise.
"""
updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)

View File

@@ -0,0 +1,191 @@
"""Middleware for intercepting clarification requests and presenting them to the user."""
import json
import logging
from collections.abc import Callable
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.graph import END
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
logger = logging.getLogger(__name__)
class ClarificationMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
pass
class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
"""Intercepts clarification tool calls and interrupts execution to present questions to the user.
When the model calls the `ask_clarification` tool, this middleware:
1. Intercepts the tool call before execution
2. Extracts the clarification question and metadata
3. Formats a user-friendly message
4. Returns a Command that interrupts execution and presents the question
5. Waits for user response before continuing
This replaces the tool-based approach where clarification continued the conversation flow.
"""
state_schema = ClarificationMiddlewareState
def _is_chinese(self, text: str) -> bool:
"""Check if text contains Chinese characters.
Args:
text: Text to check
Returns:
True if text contains Chinese characters
"""
return any("\u4e00" <= char <= "\u9fff" for char in text)
def _format_clarification_message(self, args: dict) -> str:
"""Format the clarification arguments into a user-friendly message.
Args:
args: The tool call arguments containing clarification details
Returns:
Formatted message string
"""
question = args.get("question", "")
clarification_type = args.get("clarification_type", "missing_info")
context = args.get("context")
options = args.get("options", [])
# Some models (e.g. Qwen3-Max) serialize array parameters as JSON strings
# instead of native arrays. Deserialize and normalize so `options`
# is always a list for the rendering logic below.
if isinstance(options, str):
try:
options = json.loads(options)
except (json.JSONDecodeError, TypeError):
options = [options]
if options is None:
options = []
elif not isinstance(options, list):
options = [options]
# Type-specific icons
type_icons = {
"missing_info": "",
"ambiguous_requirement": "🤔",
"approach_choice": "🔀",
"risk_confirmation": "⚠️",
"suggestion": "💡",
}
icon = type_icons.get(clarification_type, "")
# Build the message naturally
message_parts = []
# Add icon and question together for a more natural flow
if context:
# If there's context, present it first as background
message_parts.append(f"{icon} {context}")
message_parts.append(f"\n{question}")
else:
# Just the question with icon
message_parts.append(f"{icon} {question}")
# Add options in a cleaner format
if options and len(options) > 0:
message_parts.append("") # blank line for spacing
for i, option in enumerate(options, 1):
message_parts.append(f" {i}. {option}")
return "\n".join(message_parts)
def _handle_clarification(self, request: ToolCallRequest) -> Command:
"""Handle clarification request and return command to interrupt execution.
Args:
request: Tool call request
Returns:
Command that interrupts execution with the formatted clarification message
"""
# Extract clarification arguments
args = request.tool_call.get("args", {})
question = args.get("question", "")
logger.info("Intercepted clarification request")
logger.debug("Clarification question: %s", question)
# Format the clarification message
formatted_message = self._format_clarification_message(args)
# Get the tool call ID
tool_call_id = request.tool_call.get("id", "")
# Create a ToolMessage with the formatted question
# This will be added to the message history
tool_message = ToolMessage(
content=formatted_message,
tool_call_id=tool_call_id,
name="ask_clarification",
)
# Return a Command that:
# 1. Adds the formatted tool message
# 2. Interrupts execution by going to __end__
# Note: We don't add an extra AIMessage here - the frontend will detect
# and display ask_clarification tool messages directly
return Command(
update={"messages": [tool_message]},
goto=END,
)
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept ask_clarification tool calls and interrupt execution (sync version).
Args:
request: Tool call request
handler: Original tool execution handler
Returns:
Command that interrupts execution with the formatted clarification message
"""
# Check if this is an ask_clarification tool call
if request.tool_call.get("name") != "ask_clarification":
# Not a clarification call, execute normally
return handler(request)
return self._handle_clarification(request)
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept ask_clarification tool calls and interrupt execution (async version).
Args:
request: Tool call request
handler: Original tool execution handler (async)
Returns:
Command that interrupts execution with the formatted clarification message
"""
# Check if this is an ask_clarification tool call
if request.tool_call.get("name") != "ask_clarification":
# Not a clarification call, execute normally
return await handler(request)
return self._handle_clarification(request)

View File

@@ -0,0 +1,110 @@
"""Middleware to fix dangling tool calls in message history.
A dangling tool call occurs when an AIMessage contains tool_calls but there are
no corresponding ToolMessages in the history (e.g., due to user interruption or
request cancellation). This causes LLM errors due to incomplete message format.
This middleware intercepts the model call to detect and patch such gaps by
inserting synthetic ToolMessages with an error indicator immediately after the
AIMessage that made the tool calls, ensuring correct message ordering.
Note: Uses wrap_model_call instead of before_model to ensure patches are inserted
at the correct positions (immediately after each dangling AIMessage), not appended
to the end of the message list as before_model + add_messages reducer would do.
"""
import logging
from collections.abc import Awaitable, Callable
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import ToolMessage
logger = logging.getLogger(__name__)
class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
"""Inserts placeholder ToolMessages for dangling tool calls before model invocation.
Scans the message history for AIMessages whose tool_calls lack corresponding
ToolMessages, and injects synthetic error responses immediately after the
offending AIMessage so the LLM receives a well-formed conversation.
"""
def _build_patched_messages(self, messages: list) -> list | None:
"""Return a new message list with patches inserted at the correct positions.
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
a synthetic ToolMessage is inserted immediately after that AIMessage.
Returns None if no patches are needed.
"""
# Collect IDs of all existing ToolMessages
existing_tool_msg_ids: set[str] = set()
for msg in messages:
if isinstance(msg, ToolMessage):
existing_tool_msg_ids.add(msg.tool_call_id)
# Check if any patching is needed
needs_patch = False
for msg in messages:
if getattr(msg, "type", None) != "ai":
continue
for tc in getattr(msg, "tool_calls", None) or []:
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids:
needs_patch = True
break
if needs_patch:
break
if not needs_patch:
return None
# Build new list with patches inserted right after each dangling AIMessage
patched: list = []
patched_ids: set[str] = set()
patch_count = 0
for msg in messages:
patched.append(msg)
if getattr(msg, "type", None) != "ai":
continue
for tc in getattr(msg, "tool_calls", None) or []:
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
patched.append(
ToolMessage(
content="[Tool call was interrupted and did not return a result.]",
tool_call_id=tc_id,
name=tc.get("name", "unknown"),
status="error",
)
)
patched_ids.add(tc_id)
patch_count += 1
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
return patched
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
patched = self._build_patched_messages(request.messages)
if patched is not None:
request = request.override(messages=patched)
return handler(request)
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
patched = self._build_patched_messages(request.messages)
if patched is not None:
request = request.override(messages=patched)
return await handler(request)

View File

@@ -0,0 +1,60 @@
"""Middleware to filter deferred tool schemas from model binding.
When tool_search is enabled, MCP tools are registered in the DeferredToolRegistry
and passed to ToolNode for execution, but their schemas should NOT be sent to the
LLM via bind_tools (that's the whole point of deferral — saving context tokens).
This middleware intercepts wrap_model_call and removes deferred tools from
request.tools so that model.bind_tools only receives active tool schemas.
The agent discovers deferred tools at runtime via the tool_search tool.
"""
import logging
from collections.abc import Awaitable, Callable
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
logger = logging.getLogger(__name__)
class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
"""Remove deferred tools from request.tools before model binding.
ToolNode still holds all tools (including deferred) for execution routing,
but the LLM only sees active tool schemas — deferred tools are discoverable
via tool_search at runtime.
"""
def _filter_tools(self, request: ModelRequest) -> ModelRequest:
from deerflow.tools.builtins.tool_search import get_deferred_registry
registry = get_deferred_registry()
if not registry:
return request
deferred_names = {e.name for e in registry.entries}
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
if len(active_tools) < len(request.tools):
logger.debug(f"Filtered {len(request.tools) - len(active_tools)} deferred tool schema(s) from model binding")
return request.override(tools=active_tools)
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(self._filter_tools(request))
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
return await handler(self._filter_tools(request))

View File

@@ -0,0 +1,275 @@
"""LLM error handling middleware with retry/backoff and user-facing fallbacks."""
from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from email.utils import parsedate_to_datetime
from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import (
ModelCallResult,
ModelRequest,
ModelResponse,
)
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
logger = logging.getLogger(__name__)
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
_BUSY_PATTERNS = (
"server busy",
"temporarily unavailable",
"try again later",
"please retry",
"please try again",
"overloaded",
"high demand",
"rate limit",
"负载较高",
"服务繁忙",
"稍后重试",
"请稍后重试",
)
_QUOTA_PATTERNS = (
"insufficient_quota",
"quota",
"billing",
"credit",
"payment",
"余额不足",
"超出限额",
"额度不足",
"欠费",
)
_AUTH_PATTERNS = (
"authentication",
"unauthorized",
"invalid api key",
"invalid_api_key",
"permission",
"forbidden",
"access denied",
"无权",
"未授权",
)
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"""Retry transient LLM errors and surface graceful assistant messages."""
retry_max_attempts: int = 3
retry_base_delay_ms: int = 1000
retry_cap_delay_ms: int = 8000
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
detail = _extract_error_detail(exc)
lowered = detail.lower()
error_code = _extract_error_code(exc)
status_code = _extract_status_code(exc)
if _matches_any(lowered, _QUOTA_PATTERNS) or _matches_any(str(error_code).lower(), _QUOTA_PATTERNS):
return False, "quota"
if _matches_any(lowered, _AUTH_PATTERNS):
return False, "auth"
exc_name = exc.__class__.__name__
if exc_name in {
"APITimeoutError",
"APIConnectionError",
"InternalServerError",
}:
return True, "transient"
if status_code in _RETRIABLE_STATUS_CODES:
return True, "transient"
if _matches_any(lowered, _BUSY_PATTERNS):
return True, "busy"
return False, "generic"
def _build_retry_delay_ms(self, attempt: int, exc: BaseException) -> int:
retry_after = _extract_retry_after_ms(exc)
if retry_after is not None:
return retry_after
backoff = self.retry_base_delay_ms * (2 ** max(0, attempt - 1))
return min(backoff, self.retry_cap_delay_ms)
def _build_retry_message(self, attempt: int, wait_ms: int, reason: str) -> str:
seconds = max(1, round(wait_ms / 1000))
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
def _build_user_message(self, exc: BaseException, reason: str) -> str:
detail = _extract_error_detail(exc)
if reason == "quota":
return "The configured LLM provider rejected the request because the account is out of quota, billing is unavailable, or usage is restricted. Please fix the provider account and try again."
if reason == "auth":
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
if reason in {"busy", "transient"}:
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
return f"LLM request failed: {detail}"
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
try:
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer(
{
"type": "llm_retry",
"attempt": attempt,
"max_attempts": self.retry_max_attempts,
"wait_ms": wait_ms,
"reason": reason,
"message": self._build_retry_message(attempt, wait_ms, reason),
}
)
except Exception:
logger.debug("Failed to emit llm_retry event", exc_info=True)
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
attempt = 1
while True:
try:
return handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
if retriable and attempt < self.retry_max_attempts:
wait_ms = self._build_retry_delay_ms(attempt, exc)
logger.warning(
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
attempt,
self.retry_max_attempts,
wait_ms,
_extract_error_detail(exc),
)
self._emit_retry_event(attempt, wait_ms, reason)
time.sleep(wait_ms / 1000)
attempt += 1
continue
logger.warning(
"LLM call failed after %d attempt(s): %s",
attempt,
_extract_error_detail(exc),
exc_info=exc,
)
return AIMessage(content=self._build_user_message(exc, reason))
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
attempt = 1
while True:
try:
return await handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
if retriable and attempt < self.retry_max_attempts:
wait_ms = self._build_retry_delay_ms(attempt, exc)
logger.warning(
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
attempt,
self.retry_max_attempts,
wait_ms,
_extract_error_detail(exc),
)
self._emit_retry_event(attempt, wait_ms, reason)
await asyncio.sleep(wait_ms / 1000)
attempt += 1
continue
logger.warning(
"LLM call failed after %d attempt(s): %s",
attempt,
_extract_error_detail(exc),
exc_info=exc,
)
return AIMessage(content=self._build_user_message(exc, reason))
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
return any(pattern in detail for pattern in patterns)
def _extract_error_code(exc: BaseException) -> Any:
for attr in ("code", "error_code"):
value = getattr(exc, attr, None)
if value not in (None, ""):
return value
body = getattr(exc, "body", None)
if isinstance(body, dict):
error = body.get("error")
if isinstance(error, dict):
for key in ("code", "type"):
value = error.get(key)
if value not in (None, ""):
return value
return None
def _extract_status_code(exc: BaseException) -> int | None:
for attr in ("status_code", "status"):
value = getattr(exc, attr, None)
if isinstance(value, int):
return value
response = getattr(exc, "response", None)
status = getattr(response, "status_code", None)
return status if isinstance(status, int) else None
def _extract_retry_after_ms(exc: BaseException) -> int | None:
response = getattr(exc, "response", None)
headers = getattr(response, "headers", None)
if headers is None:
return None
raw = None
header_name = ""
for key in ("retry-after-ms", "Retry-After-Ms", "retry-after", "Retry-After"):
header_name = key
if hasattr(headers, "get"):
raw = headers.get(key)
if raw:
break
if not raw:
return None
try:
multiplier = 1 if "ms" in header_name.lower() else 1000
return max(0, int(float(raw) * multiplier))
except (TypeError, ValueError):
try:
target = parsedate_to_datetime(str(raw))
delta = target.timestamp() - time.time()
return max(0, int(delta * 1000))
except (TypeError, ValueError, OverflowError):
return None
def _extract_error_detail(exc: BaseException) -> str:
detail = str(exc).strip()
if detail:
return detail
message = getattr(exc, "message", None)
if isinstance(message, str) and message.strip():
return message.strip()
return exc.__class__.__name__

View File

@@ -0,0 +1,372 @@
"""Middleware to detect and break repetitive tool call loops.
P0 safety: prevents the agent from calling the same tool with the same
arguments indefinitely until the recursion limit kills the run.
Detection strategy:
1. After each model response, hash the tool calls (name + args).
2. Track recent hashes in a sliding window.
3. If the same hash appears >= warn_threshold times, inject a
"you are repeating yourself — wrap up" system message (once per hash).
4. If it appears >= hard_limit times, strip all tool_calls from the
response so the agent is forced to produce a final text answer.
"""
import hashlib
import json
import logging
import threading
from collections import OrderedDict, defaultdict
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor
_DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls
_DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls
_DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
_DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type
_DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
"""Normalize tool call args to a dict plus an optional fallback key.
Some providers serialize ``args`` as a JSON string instead of a dict.
We defensively parse those cases so loop detection does not crash while
still preserving a stable fallback key for non-dict payloads.
"""
if isinstance(raw_args, dict):
return raw_args, None
if isinstance(raw_args, str):
try:
parsed = json.loads(raw_args)
except (TypeError, ValueError, json.JSONDecodeError):
return {}, raw_args
if isinstance(parsed, dict):
return parsed, None
return {}, json.dumps(parsed, sort_keys=True, default=str)
if raw_args is None:
return {}, None
return {}, json.dumps(raw_args, sort_keys=True, default=str)
def _stable_tool_key(name: str, args: dict, fallback_key: str | None) -> str:
"""Derive a stable key from salient args without overfitting to noise."""
if name == "read_file" and fallback_key is None:
path = args.get("path") or ""
start_line = args.get("start_line")
end_line = args.get("end_line")
bucket_size = 200
try:
start_line = int(start_line) if start_line is not None else 1
except (TypeError, ValueError):
start_line = 1
try:
end_line = int(end_line) if end_line is not None else start_line
except (TypeError, ValueError):
end_line = start_line
start_line, end_line = sorted((start_line, end_line))
bucket_start = max(start_line, 1)
bucket_end = max(end_line, 1)
bucket_start = (bucket_start - 1) // bucket_size
bucket_end = (bucket_end - 1) // bucket_size
return f"{path}:{bucket_start}-{bucket_end}"
# write_file / str_replace are content-sensitive: same path may be updated
# with different payloads during iteration. Using only salient fields (path)
# can collapse distinct calls, so we hash full args to reduce false positives.
if name in {"write_file", "str_replace"}:
if fallback_key is not None:
return fallback_key
return json.dumps(args, sort_keys=True, default=str)
salient_fields = ("path", "url", "query", "command", "pattern", "glob", "cmd")
stable_args = {field: args[field] for field in salient_fields if args.get(field) is not None}
if stable_args:
return json.dumps(stable_args, sort_keys=True, default=str)
if fallback_key is not None:
return fallback_key
return json.dumps(args, sort_keys=True, default=str)
def _hash_tool_calls(tool_calls: list[dict]) -> str:
"""Deterministic hash of a set of tool calls (name + stable key).
This is intended to be order-independent: the same multiset of tool calls
should always produce the same hash, regardless of their input order.
"""
# Normalize each tool call to a stable (name, key) structure.
normalized: list[str] = []
for tc in tool_calls:
name = tc.get("name", "")
args, fallback_key = _normalize_tool_call_args(tc.get("args", {}))
key = _stable_tool_key(name, args, fallback_key)
normalized.append(f"{name}:{key}")
# Sort so permutations of the same multiset of calls yield the same ordering.
normalized.sort()
blob = json.dumps(normalized, sort_keys=True, default=str)
return hashlib.md5(blob.encode()).hexdigest()[:12]
_WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
_TOOL_FREQ_WARNING_MSG = (
"[LOOP DETECTED] You have called {tool_name} {count} times without producing a final answer. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
)
_HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far."
_TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times — exceeded the per-tool safety limit. Producing final answer with results collected so far."
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
"""Detects and breaks repetitive tool call loops.
Args:
warn_threshold: Number of identical tool call sets before injecting
a warning message. Default: 3.
hard_limit: Number of identical tool call sets before stripping
tool_calls entirely. Default: 5.
window_size: Size of the sliding window for tracking calls.
Default: 20.
max_tracked_threads: Maximum number of threads to track before
evicting the least recently used. Default: 100.
tool_freq_warn: Number of calls to the same tool *type* (regardless
of arguments) before injecting a frequency warning. Catches
cross-file read loops that hash-based detection misses.
Default: 30.
tool_freq_hard_limit: Number of calls to the same tool type before
forcing a stop. Default: 50.
"""
def __init__(
self,
warn_threshold: int = _DEFAULT_WARN_THRESHOLD,
hard_limit: int = _DEFAULT_HARD_LIMIT,
window_size: int = _DEFAULT_WINDOW_SIZE,
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
):
super().__init__()
self.warn_threshold = warn_threshold
self.hard_limit = hard_limit
self.window_size = window_size
self.max_tracked_threads = max_tracked_threads
self.tool_freq_warn = tool_freq_warn
self.tool_freq_hard_limit = tool_freq_hard_limit
self._lock = threading.Lock()
# Per-thread tracking using OrderedDict for LRU eviction
self._history: OrderedDict[str, list[str]] = OrderedDict()
self._warned: dict[str, set[str]] = defaultdict(set)
# Per-thread, per-tool-type cumulative call counts
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
def _get_thread_id(self, runtime: Runtime) -> str:
"""Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id:
return thread_id
return "default"
def _evict_if_needed(self) -> None:
"""Evict least recently used threads if over the limit.
Must be called while holding self._lock.
"""
while len(self._history) > self.max_tracked_threads:
evicted_id, _ = self._history.popitem(last=False)
self._warned.pop(evicted_id, None)
self._tool_freq.pop(evicted_id, None)
self._tool_freq_warned.pop(evicted_id, None)
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
"""Track tool calls and check for loops.
Two detection layers:
1. **Hash-based** (existing): catches identical tool call sets.
2. **Frequency-based** (new): catches the same *tool type* being
called many times with varying arguments (e.g. ``read_file``
on 40 different files).
Returns:
(warning_message_or_none, should_hard_stop)
"""
messages = state.get("messages", [])
if not messages:
return None, False
last_msg = messages[-1]
if getattr(last_msg, "type", None) != "ai":
return None, False
tool_calls = getattr(last_msg, "tool_calls", None)
if not tool_calls:
return None, False
thread_id = self._get_thread_id(runtime)
call_hash = _hash_tool_calls(tool_calls)
with self._lock:
# Touch / create entry (move to end for LRU)
if thread_id in self._history:
self._history.move_to_end(thread_id)
else:
self._history[thread_id] = []
self._evict_if_needed()
history = self._history[thread_id]
history.append(call_hash)
if len(history) > self.window_size:
history[:] = history[-self.window_size :]
count = history.count(call_hash)
tool_names = [tc.get("name", "?") for tc in tool_calls]
# --- Layer 1: hash-based (identical call sets) ---
if count >= self.hard_limit:
logger.error(
"Loop hard limit reached — forcing stop",
extra={
"thread_id": thread_id,
"call_hash": call_hash,
"count": count,
"tools": tool_names,
},
)
return _HARD_STOP_MSG, True
if count >= self.warn_threshold:
warned = self._warned[thread_id]
if call_hash not in warned:
warned.add(call_hash)
logger.warning(
"Repetitive tool calls detected — injecting warning",
extra={
"thread_id": thread_id,
"call_hash": call_hash,
"count": count,
"tools": tool_names,
},
)
return _WARNING_MSG, False
# --- Layer 2: per-tool-type frequency ---
freq = self._tool_freq[thread_id]
for tc in tool_calls:
name = tc.get("name", "")
if not name:
continue
freq[name] += 1
tc_count = freq[name]
if tc_count >= self.tool_freq_hard_limit:
logger.error(
"Tool frequency hard limit reached — forcing stop",
extra={
"thread_id": thread_id,
"tool_name": name,
"count": tc_count,
},
)
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
if tc_count >= self.tool_freq_warn:
warned = self._tool_freq_warned[thread_id]
if name not in warned:
warned.add(name)
logger.warning(
"Tool frequency warning — too many calls to same tool type",
extra={
"thread_id": thread_id,
"tool_name": name,
"count": tc_count,
},
)
return _TOOL_FREQ_WARNING_MSG.format(tool_name=name, count=tc_count), False
return None, False
@staticmethod
def _append_text(content: str | list | None, text: str) -> str | list:
"""Append *text* to AIMessage content, handling str, list, and None.
When content is a list of content blocks (e.g. Anthropic thinking mode),
we append a new ``{"type": "text", ...}`` block instead of concatenating
a string to a list, which would raise ``TypeError``.
"""
if content is None:
return text
if isinstance(content, list):
return [*content, {"type": "text", "text": f"\n\n{text}"}]
if isinstance(content, str):
return content + f"\n\n{text}"
# Fallback: coerce unexpected types to str to avoid TypeError
return str(content) + f"\n\n{text}"
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
warning, hard_stop = self._track_and_check(state, runtime)
if hard_stop:
# Strip tool_calls from the last AIMessage to force text output
messages = state.get("messages", [])
last_msg = messages[-1]
stripped_msg = last_msg.model_copy(
update={
"tool_calls": [],
"content": self._append_text(last_msg.content, warning),
}
)
return {"messages": [stripped_msg]}
if warning:
# Inject as HumanMessage instead of SystemMessage to avoid
# Anthropic's "multiple non-consecutive system messages" error.
# Anthropic models require system messages only at the start of
# the conversation; injecting one mid-conversation crashes
# langchain_anthropic's _format_messages(). HumanMessage works
# with all providers. See #1299.
return {"messages": [HumanMessage(content=warning)]}
return None
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
def reset(self, thread_id: str | None = None) -> None:
"""Clear tracking state. If thread_id given, clear only that thread."""
with self._lock:
if thread_id:
self._history.pop(thread_id, None)
self._warned.pop(thread_id, None)
self._tool_freq.pop(thread_id, None)
self._tool_freq_warned.pop(thread_id, None)
else:
self._history.clear()
self._warned.clear()
self._tool_freq.clear()
self._tool_freq_warned.clear()

View File

@@ -0,0 +1,248 @@
"""Middleware for memory mechanism."""
import logging
import re
from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config
logger = logging.getLogger(__name__)
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
_CORRECTION_PATTERNS = (
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
re.compile(r"\btry again\b", re.IGNORECASE),
re.compile(r"\bredo\b", re.IGNORECASE),
re.compile(r"不对"),
re.compile(r"你理解错了"),
re.compile(r"你理解有误"),
re.compile(r"重试"),
re.compile(r"重新来"),
re.compile(r"换一种"),
re.compile(r"改用"),
)
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
pass
def _extract_message_text(message: Any) -> str:
"""Extract plain text from message content for filtering and signal detection."""
content = getattr(message, "content", "")
if isinstance(content, list):
text_parts: list[str] = []
for part in content:
if isinstance(part, str):
text_parts.append(part)
elif isinstance(part, dict):
text_val = part.get("text")
if isinstance(text_val, str):
text_parts.append(text_val)
return " ".join(text_parts)
return str(content)
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
"""Filter messages to keep only user inputs and final assistant responses.
This filters out:
- Tool messages (intermediate tool call results)
- AI messages with tool_calls (intermediate steps, not final responses)
- The <uploaded_files> block injected by UploadsMiddleware into human messages
(file paths are session-scoped and must not persist in long-term memory).
The user's actual question is preserved; only turns whose content is entirely
the upload block (nothing remains after stripping) are dropped along with
their paired assistant response.
Only keeps:
- Human messages (with the ephemeral upload block removed)
- AI messages without tool_calls (final assistant responses), unless the
paired human turn was upload-only and had no real user text.
Args:
messages: List of all conversation messages.
Returns:
Filtered list containing only user inputs and final assistant responses.
"""
filtered = []
skip_next_ai = False
for msg in messages:
msg_type = getattr(msg, "type", None)
if msg_type == "human":
content_str = _extract_message_text(msg)
if "<uploaded_files>" in content_str:
# Strip the ephemeral upload block; keep the user's real question.
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
if not stripped:
# Nothing left — the entire turn was upload bookkeeping;
# skip it and the paired assistant response.
skip_next_ai = True
continue
# Rebuild the message with cleaned content so the user's question
# is still available for memory summarisation.
from copy import copy
clean_msg = copy(msg)
clean_msg.content = stripped
filtered.append(clean_msg)
skip_next_ai = False
else:
filtered.append(msg)
skip_next_ai = False
elif msg_type == "ai":
tool_calls = getattr(msg, "tool_calls", None)
if not tool_calls:
if skip_next_ai:
skip_next_ai = False
continue
filtered.append(msg)
# Skip tool messages and AI messages with tool_calls
return filtered
def detect_correction(messages: list[Any]) -> bool:
"""Detect explicit user corrections in recent conversation turns.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale corrections from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
return True
return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns.
Complements detect_correction() by identifying when the user confirms the
agent's approach was correct. This allows the memory system to record what
worked well, not just what went wrong.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale signals from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution.
This middleware:
1. After each agent execution, queues the conversation for memory update
2. Only includes user inputs and final assistant responses (ignores tool calls)
3. The queue uses debouncing to batch multiple updates together
4. Memory is updated asynchronously via LLM summarization
"""
state_schema = MemoryMiddlewareState
def __init__(self, agent_name: str | None = None):
"""Initialize the MemoryMiddleware.
Args:
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
"""
super().__init__()
self._agent_name = agent_name
@override
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
"""Queue conversation for memory update after agent completes.
Args:
state: The current agent state.
runtime: The runtime context.
Returns:
None (no state changes needed from this middleware).
"""
config = get_memory_config()
if not config.enabled:
return None
# Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
config_data = get_config()
thread_id = config_data.get("configurable", {}).get("thread_id")
if not thread_id:
logger.debug("No thread_id in context, skipping memory update")
return None
# Get messages from state
messages = state.get("messages", [])
if not messages:
logger.debug("No messages in state, skipping memory update")
return None
# Filter to only keep user inputs and final assistant responses
filtered_messages = _filter_messages_for_memory(messages)
# Only queue if there's meaningful conversation
# At minimum need one user message and one assistant response
user_messages = [m for m in filtered_messages if getattr(m, "type", None) == "human"]
assistant_messages = [m for m in filtered_messages if getattr(m, "type", None) == "ai"]
if not user_messages or not assistant_messages:
return None
# Queue the filtered conversation for memory update
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue()
queue.add(
thread_id=thread_id,
messages=filtered_messages,
agent_name=self._agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
return None

View File

@@ -0,0 +1,363 @@
"""SandboxAuditMiddleware - bash command security auditing."""
import json
import logging
import re
import shlex
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from typing import override
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
from deerflow.agents.thread_state import ThreadState
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Command classification rules
# ---------------------------------------------------------------------------
# Each pattern is compiled once at import time.
_HIGH_RISK_PATTERNS: list[re.Pattern[str]] = [
# --- original rules (retained) ---
re.compile(r"rm\s+-[^\s]*r[^\s]*\s+(/\*?|~/?\*?|/home\b|/root\b)\s*$"),
re.compile(r"dd\s+if="),
re.compile(r"mkfs"),
re.compile(r"cat\s+/etc/shadow"),
re.compile(r">+\s*/etc/"),
# --- pipe to sh/bash (generalised, replaces old curl|sh rule) ---
re.compile(r"\|\s*(ba)?sh\b"),
# --- command substitution (targeted only dangerous executables) ---
re.compile(r"[`$]\(?\s*(curl|wget|bash|sh|python|ruby|perl|base64)"),
# --- base64 decode piped to execution ---
re.compile(r"base64\s+.*-d.*\|"),
# --- overwrite system binaries ---
re.compile(r">+\s*(/usr/bin/|/bin/|/sbin/)"),
# --- overwrite shell startup files ---
re.compile(r">+\s*~/?\.(bashrc|profile|zshrc|bash_profile)"),
# --- process environment leakage ---
re.compile(r"/proc/[^/]+/environ"),
# --- dynamic linker hijack (one-step escalation) ---
re.compile(r"\b(LD_PRELOAD|LD_LIBRARY_PATH)\s*="),
# --- bash built-in networking (bypasses tool allowlists) ---
re.compile(r"/dev/tcp/"),
# --- fork bomb ---
re.compile(r"\S+\(\)\s*\{[^}]*\|\s*\S+\s*&"), # :(){ :|:& };:
re.compile(r"while\s+true.*&\s*done"), # while true; do bash & done
]
_MEDIUM_RISK_PATTERNS: list[re.Pattern[str]] = [
re.compile(r"chmod\s+777"),
re.compile(r"pip3?\s+install"),
re.compile(r"apt(-get)?\s+install"),
# sudo/su: no-op under Docker root; warn so LLM is aware
re.compile(r"\b(sudo|su)\b"),
# PATH modification: long attack chain, warn rather than block
re.compile(r"\bPATH\s*="),
]
def _split_compound_command(command: str) -> list[str]:
"""Split a compound command into sub-commands (quote-aware).
Scans the raw command string so unquoted shell control operators are
recognised even when they are not surrounded by whitespace
(e.g. ``safe;rm -rf /`` or ``rm -rf /&&echo ok``). Operators inside
quotes are ignored. If the command ends with an unclosed quote or a
dangling escape, return the whole command unchanged (fail-closed —
safer to classify the unsplit string than silently drop parts).
"""
parts: list[str] = []
current: list[str] = []
in_single_quote = False
in_double_quote = False
escaping = False
index = 0
while index < len(command):
char = command[index]
if escaping:
current.append(char)
escaping = False
index += 1
continue
if char == "\\" and not in_single_quote:
current.append(char)
escaping = True
index += 1
continue
if char == "'" and not in_double_quote:
in_single_quote = not in_single_quote
current.append(char)
index += 1
continue
if char == '"' and not in_single_quote:
in_double_quote = not in_double_quote
current.append(char)
index += 1
continue
if not in_single_quote and not in_double_quote:
if command.startswith("&&", index) or command.startswith("||", index):
part = "".join(current).strip()
if part:
parts.append(part)
current = []
index += 2
continue
if char == ";":
part = "".join(current).strip()
if part:
parts.append(part)
current = []
index += 1
continue
current.append(char)
index += 1
# Unclosed quote or dangling escape → fail-closed, return whole command
if in_single_quote or in_double_quote or escaping:
return [command]
part = "".join(current).strip()
if part:
parts.append(part)
return parts if parts else [command]
def _classify_single_command(command: str) -> str:
"""Classify a single (non-compound) command. Return 'block', 'warn', or 'pass'."""
normalized = " ".join(command.split())
for pattern in _HIGH_RISK_PATTERNS:
if pattern.search(normalized):
return "block"
# Also try shlex-parsed tokens for high-risk detection
try:
tokens = shlex.split(command)
joined = " ".join(tokens)
for pattern in _HIGH_RISK_PATTERNS:
if pattern.search(joined):
return "block"
except ValueError:
# shlex.split fails on unclosed quotes — treat as suspicious
return "block"
for pattern in _MEDIUM_RISK_PATTERNS:
if pattern.search(normalized):
return "warn"
return "pass"
def _classify_command(command: str) -> str:
"""Return 'block', 'warn', or 'pass'.
Strategy:
1. First scan the *whole* raw command against high-risk patterns. This
catches structural attacks like ``while true; do bash & done`` or
``:(){ :|:& };:`` that span multiple shell statements — splitting them
on ``;`` would destroy the pattern context.
2. Then split compound commands (e.g. ``cmd1 && cmd2 ; cmd3``) and
classify each sub-command independently. The most severe verdict wins.
"""
# Pass 1: whole-command high-risk scan (catches multi-statement patterns)
normalized = " ".join(command.split())
for pattern in _HIGH_RISK_PATTERNS:
if pattern.search(normalized):
return "block"
# Pass 2: per-sub-command classification
sub_commands = _split_compound_command(command)
worst = "pass"
for sub in sub_commands:
verdict = _classify_single_command(sub)
if verdict == "block":
return "block" # short-circuit: can't get worse
if verdict == "warn":
worst = "warn"
return worst
# ---------------------------------------------------------------------------
# Middleware
# ---------------------------------------------------------------------------
class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
"""Bash command security auditing middleware.
For every ``bash`` tool call:
1. **Command classification**: regex + shlex analysis grades commands as
high-risk (block), medium-risk (warn), or safe (pass).
2. **Audit log**: every bash call is recorded as a structured JSON entry
via the standard logger (visible in langgraph.log).
High-risk commands (e.g. ``rm -rf /``, ``curl url | bash``) are blocked:
the handler is not called and an error ``ToolMessage`` is returned so the
agent loop can continue gracefully.
Medium-risk commands (e.g. ``pip install``, ``chmod 777``) are executed
normally; a warning is appended to the tool result so the LLM is aware.
"""
state_schema = ThreadState
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _get_thread_id(self, request: ToolCallRequest) -> str | None:
runtime = request.runtime # ToolRuntime; may be None-like in tests
if runtime is None:
return None
ctx = getattr(runtime, "context", None) or {}
thread_id = ctx.get("thread_id") if isinstance(ctx, dict) else None
if thread_id is None:
cfg = getattr(runtime, "config", None) or {}
thread_id = cfg.get("configurable", {}).get("thread_id")
return thread_id
_AUDIT_COMMAND_LIMIT = 200
def _write_audit(self, thread_id: str | None, command: str, verdict: str, *, truncate: bool = False) -> None:
audited_command = command
if truncate and len(command) > self._AUDIT_COMMAND_LIMIT:
audited_command = f"{command[: self._AUDIT_COMMAND_LIMIT]}... ({len(command)} chars)"
record = {
"timestamp": datetime.now(UTC).isoformat(),
"thread_id": thread_id or "unknown",
"command": audited_command,
"verdict": verdict,
}
logger.info("[SandboxAudit] %s", json.dumps(record, ensure_ascii=False))
def _build_block_message(self, request: ToolCallRequest, reason: str) -> ToolMessage:
tool_call_id = str(request.tool_call.get("id") or "missing_id")
return ToolMessage(
content=f"Command blocked: {reason}. Please use a safer alternative approach.",
tool_call_id=tool_call_id,
name="bash",
status="error",
)
def _append_warn_to_result(self, result: ToolMessage | Command, command: str) -> ToolMessage | Command:
"""Append a warning note to the tool result for medium-risk commands."""
if not isinstance(result, ToolMessage):
return result
warning = f"\n\n⚠️ Warning: `{command}` is a medium-risk command that may modify the runtime environment."
if isinstance(result.content, list):
new_content = list(result.content) + [{"type": "text", "text": warning}]
else:
new_content = str(result.content) + warning
return ToolMessage(
content=new_content,
tool_call_id=result.tool_call_id,
name=result.name,
status=result.status,
)
# ------------------------------------------------------------------
# Input sanitisation
# ------------------------------------------------------------------
# Normal bash commands rarely exceed a few hundred characters. 10 000 is
# well above any legitimate use case yet a tiny fraction of Linux ARG_MAX.
# Anything longer is almost certainly a payload injection or base64-encoded
# attack string.
_MAX_COMMAND_LENGTH = 10_000
def _validate_input(self, command: str) -> str | None:
"""Return ``None`` if *command* is acceptable, else a rejection reason."""
if not command.strip():
return "empty command"
if len(command) > self._MAX_COMMAND_LENGTH:
return "command too long"
if "\x00" in command:
return "null byte detected"
return None
# ------------------------------------------------------------------
# Core logic (shared between sync and async paths)
# ------------------------------------------------------------------
def _pre_process(self, request: ToolCallRequest) -> tuple[str, str | None, str, str | None]:
"""
Returns (command, thread_id, verdict, reject_reason).
verdict is 'block', 'warn', or 'pass'.
reject_reason is non-None only for input sanitisation rejections.
"""
args = request.tool_call.get("args", {})
raw_command = args.get("command")
command = raw_command if isinstance(raw_command, str) else ""
thread_id = self._get_thread_id(request)
# ① input sanitisation — reject malformed input before regex analysis
reject_reason = self._validate_input(command)
if reject_reason:
self._write_audit(thread_id, command, "block", truncate=True)
logger.warning("[SandboxAudit] INVALID INPUT thread=%s reason=%s", thread_id, reject_reason)
return command, thread_id, "block", reject_reason
# ② classify command
verdict = _classify_command(command)
# ③ audit log
self._write_audit(thread_id, command, verdict)
if verdict == "block":
logger.warning("[SandboxAudit] BLOCKED thread=%s cmd=%r", thread_id, command)
elif verdict == "warn":
logger.warning("[SandboxAudit] WARN (medium-risk) thread=%s cmd=%r", thread_id, command)
return command, thread_id, verdict, None
# ------------------------------------------------------------------
# wrap_tool_call hooks
# ------------------------------------------------------------------
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
if request.tool_call.get("name") != "bash":
return handler(request)
command, _, verdict, reject_reason = self._pre_process(request)
if verdict == "block":
reason = reject_reason or "security violation detected"
return self._build_block_message(request, reason)
result = handler(request)
if verdict == "warn":
result = self._append_warn_to_result(result, command)
return result
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
if request.tool_call.get("name") != "bash":
return await handler(request)
command, _, verdict, reject_reason = self._pre_process(request)
if verdict == "block":
reason = reject_reason or "security violation detected"
return self._build_block_message(request, reason)
result = await handler(request)
if verdict == "warn":
result = self._append_warn_to_result(result, command)
return result

View File

@@ -0,0 +1,75 @@
"""Middleware to enforce maximum concurrent subagent tool calls per model response."""
import logging
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
logger = logging.getLogger(__name__)
# Valid range for max_concurrent_subagents
MIN_SUBAGENT_LIMIT = 2
MAX_SUBAGENT_LIMIT = 4
def _clamp_subagent_limit(value: int) -> int:
"""Clamp subagent limit to valid range [2, 4]."""
return max(MIN_SUBAGENT_LIMIT, min(MAX_SUBAGENT_LIMIT, value))
class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
"""Truncates excess 'task' tool calls from a single model response.
When an LLM generates more than max_concurrent parallel task tool calls
in one response, this middleware keeps only the first max_concurrent and
discards the rest. This is more reliable than prompt-based limits.
Args:
max_concurrent: Maximum number of concurrent subagent calls allowed.
Defaults to MAX_CONCURRENT_SUBAGENTS (3). Clamped to [2, 4].
"""
def __init__(self, max_concurrent: int = MAX_CONCURRENT_SUBAGENTS):
super().__init__()
self.max_concurrent = _clamp_subagent_limit(max_concurrent)
def _truncate_task_calls(self, state: AgentState) -> dict | None:
messages = state.get("messages", [])
if not messages:
return None
last_msg = messages[-1]
if getattr(last_msg, "type", None) != "ai":
return None
tool_calls = getattr(last_msg, "tool_calls", None)
if not tool_calls:
return None
# Count task tool calls
task_indices = [i for i, tc in enumerate(tool_calls) if tc.get("name") == "task"]
if len(task_indices) <= self.max_concurrent:
return None
# Build set of indices to drop (excess task calls beyond the limit)
indices_to_drop = set(task_indices[self.max_concurrent :])
truncated_tool_calls = [tc for i, tc in enumerate(tool_calls) if i not in indices_to_drop]
dropped_count = len(indices_to_drop)
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
return {"messages": [updated_msg]}
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._truncate_task_calls(state)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._truncate_task_calls(state)

View File

@@ -0,0 +1,99 @@
import logging
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.paths import Paths, get_paths
logger = logging.getLogger(__name__)
class ThreadDataMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
thread_data: NotRequired[ThreadDataState | None]
class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
"""Create thread data directories for each thread execution.
Creates the following directory structure:
- {base_dir}/threads/{thread_id}/user-data/workspace
- {base_dir}/threads/{thread_id}/user-data/uploads
- {base_dir}/threads/{thread_id}/user-data/outputs
Lifecycle Management:
- With lazy_init=True (default): Only compute paths, directories created on-demand
- With lazy_init=False: Eagerly create directories in before_agent()
"""
state_schema = ThreadDataMiddlewareState
def __init__(self, base_dir: str | None = None, lazy_init: bool = True):
"""Initialize the middleware.
Args:
base_dir: Base directory for thread data. Defaults to Paths resolution.
lazy_init: If True, defer directory creation until needed.
If False, create directories eagerly in before_agent().
Default is True for optimal performance.
"""
super().__init__()
self._paths = Paths(base_dir) if base_dir else get_paths()
self._lazy_init = lazy_init
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
"""Get the paths for a thread's data directories.
Args:
thread_id: The thread ID.
Returns:
Dictionary with workspace_path, uploads_path, and outputs_path.
"""
return {
"workspace_path": str(self._paths.sandbox_work_dir(thread_id)),
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)),
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)),
}
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
"""Create the thread data directories.
Args:
thread_id: The thread ID.
Returns:
Dictionary with the created directory paths.
"""
self._paths.ensure_thread_dirs(thread_id)
return self._get_thread_paths(thread_id)
@override
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
context = runtime.context or {}
thread_id = context.get("thread_id")
if thread_id is None:
config = get_config()
thread_id = config.get("configurable", {}).get("thread_id")
if thread_id is None:
raise ValueError("Thread ID is required in runtime context or config.configurable")
if self._lazy_init:
# Lazy initialization: only compute paths, don't create directories
paths = self._get_thread_paths(thread_id)
else:
# Eager initialization: create directories immediately
paths = self._create_thread_directories(thread_id)
logger.debug("Created thread data directories for thread %s", thread_id)
return {
"thread_data": {
**paths,
}
}

View File

@@ -0,0 +1,138 @@
"""Middleware for automatic thread title generation."""
import logging
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
class TitleMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
title: NotRequired[str | None]
class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
"""Automatically generate a title for the thread after the first user message."""
state_schema = TitleMiddlewareState
def _normalize_content(self, content: object) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [self._normalize_content(item) for item in content]
return "\n".join(part for part in parts if part)
if isinstance(content, dict):
text_value = content.get("text")
if isinstance(text_value, str):
return text_value
nested_content = content.get("content")
if nested_content is not None:
return self._normalize_content(nested_content)
return ""
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
"""Check if we should generate a title for this thread."""
config = get_title_config()
if not config.enabled:
return False
# Check if thread already has a title in state
if state.get("title"):
return False
# Check if this is the first turn (has at least one user message and one assistant response)
messages = state.get("messages", [])
if len(messages) < 2:
return False
# Count user and assistant messages
user_messages = [m for m in messages if m.type == "human"]
assistant_messages = [m for m in messages if m.type == "ai"]
# Generate title after first complete exchange
return len(user_messages) == 1 and len(assistant_messages) >= 1
def _build_title_prompt(self, state: TitleMiddlewareState) -> tuple[str, str]:
"""Extract user/assistant messages and build the title prompt.
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
"""
config = get_title_config()
messages = state.get("messages", [])
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
user_msg = self._normalize_content(user_msg_content)
assistant_msg = self._normalize_content(assistant_msg_content)
prompt = config.prompt_template.format(
max_words=config.max_words,
user_msg=user_msg[:500],
assistant_msg=assistant_msg[:500],
)
return prompt, user_msg
def _parse_title(self, content: object) -> str:
"""Normalize model output into a clean title string."""
config = get_title_config()
title_content = self._normalize_content(content)
title = title_content.strip().strip('"').strip("'")
return title[: config.max_chars] if len(title) > config.max_chars else title
def _fallback_title(self, user_msg: str) -> str:
config = get_title_config()
fallback_chars = min(config.max_chars, 50)
if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..."
return user_msg if user_msg else "New Conversation"
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Generate a local fallback title without blocking on an LLM call."""
if not self._should_generate_title(state):
return None
_, user_msg = self._build_title_prompt(state)
return {"title": self._fallback_title(user_msg)}
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Generate a title asynchronously and fall back locally on failure."""
if not self._should_generate_title(state):
return None
config = get_title_config()
prompt, user_msg = self._build_title_prompt(state)
try:
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
else:
model = create_chat_model(thinking_enabled=False)
response = await model.ainvoke(prompt)
title = self._parse_title(response.content)
if title:
return {"title": title}
except Exception:
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
return {"title": self._fallback_title(user_msg)}
@override
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
return self._generate_title_result(state)
@override
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
return await self._agenerate_title_result(state)

View File

@@ -0,0 +1,100 @@
"""Middleware that extends TodoListMiddleware with context-loss detection.
When the message history is truncated (e.g., by SummarizationMiddleware), the
original `write_todos` tool call and its ToolMessage can be scrolled out of the
active context window. This middleware detects that situation and injects a
reminder message so the model still knows about the outstanding todo list.
"""
from __future__ import annotations
from typing import Any, override
from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime
def _todos_in_messages(messages: list[Any]) -> bool:
"""Return True if any AIMessage in *messages* contains a write_todos tool call."""
for msg in messages:
if isinstance(msg, AIMessage) and msg.tool_calls:
for tc in msg.tool_calls:
if tc.get("name") == "write_todos":
return True
return False
def _reminder_in_messages(messages: list[Any]) -> bool:
"""Return True if a todo_reminder HumanMessage is already present in *messages*."""
for msg in messages:
if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_reminder":
return True
return False
def _format_todos(todos: list[Todo]) -> str:
"""Format a list of Todo items into a human-readable string."""
lines: list[str] = []
for todo in todos:
status = todo.get("status", "pending")
content = todo.get("content", "")
lines.append(f"- [{status}] {content}")
return "\n".join(lines)
class TodoMiddleware(TodoListMiddleware):
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
When the original `write_todos` tool call has been truncated from the message
history (e.g., after summarization), the model loses awareness of the current
todo list. This middleware detects that gap in `before_model` / `abefore_model`
and injects a reminder message so the model can continue tracking progress.
"""
@override
def before_model(
self,
state: PlanningState,
runtime: Runtime, # noqa: ARG002
) -> dict[str, Any] | None:
"""Inject a todo-list reminder when write_todos has left the context window."""
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
if not todos:
return None
messages = state.get("messages") or []
if _todos_in_messages(messages):
# write_todos is still visible in context — nothing to do.
return None
if _reminder_in_messages(messages):
# A reminder was already injected and hasn't been truncated yet.
return None
# The todo list exists in state but the original write_todos call is gone.
# Inject a reminder as a HumanMessage so the model stays aware.
formatted = _format_todos(todos)
reminder = HumanMessage(
name="todo_reminder",
content=(
"<system_reminder>\n"
"Your todo list from earlier is no longer visible in the current context window, "
"but it is still active. Here is the current state:\n\n"
f"{formatted}\n\n"
"Continue tracking and updating this todo list as you work. "
"Call `write_todos` whenever the status of any item changes.\n"
"</system_reminder>"
),
)
return {"messages": [reminder]}
@override
async def abefore_model(
self,
state: PlanningState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async version of before_model."""
return self.before_model(state, runtime)

View File

@@ -0,0 +1,37 @@
"""Middleware for logging LLM token usage."""
import logging
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
class TokenUsageMiddleware(AgentMiddleware):
"""Logs token usage from model response usage_metadata."""
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._log_usage(state)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._log_usage(state)
def _log_usage(self, state: AgentState) -> None:
messages = state.get("messages", [])
if not messages:
return None
last = messages[-1]
usage = getattr(last, "usage_metadata", None)
if usage:
logger.info(
"LLM token usage: input=%s output=%s total=%s",
usage.get("input_tokens", "?"),
usage.get("output_tokens", "?"),
usage.get("total_tokens", "?"),
)
return None

View File

@@ -0,0 +1,143 @@
"""Tool error handling middleware and shared runtime middleware builders."""
import logging
from collections.abc import Awaitable, Callable
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.errors import GraphBubbleUp
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
logger = logging.getLogger(__name__)
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"""Convert tool exceptions into error ToolMessages so the run can continue."""
def _build_error_message(self, request: ToolCallRequest, exc: Exception) -> ToolMessage:
tool_name = str(request.tool_call.get("name") or "unknown_tool")
tool_call_id = str(request.tool_call.get("id") or _MISSING_TOOL_CALL_ID)
detail = str(exc).strip() or exc.__class__.__name__
if len(detail) > 500:
detail = detail[:497] + "..."
content = f"Error: Tool '{tool_name}' failed with {exc.__class__.__name__}: {detail}. Continue with available context, or choose an alternative tool."
return ToolMessage(
content=content,
tool_call_id=tool_call_id,
name=tool_name,
status="error",
)
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
try:
return handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
logger.exception("Tool execution failed (sync): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
return self._build_error_message(request, exc)
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
try:
return await handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
logger.exception("Tool execution failed (async): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
return self._build_error_message(request, exc)
def _build_runtime_middlewares(
*,
include_uploads: bool,
include_dangling_tool_call_patch: bool,
lazy_init: bool = True,
) -> list[AgentMiddleware]:
"""Build shared base middlewares for agent execution."""
from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from deerflow.sandbox.middleware import SandboxMiddleware
middlewares: list[AgentMiddleware] = [
ThreadDataMiddleware(lazy_init=lazy_init),
SandboxMiddleware(lazy_init=lazy_init),
]
if include_uploads:
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
middlewares.insert(1, UploadsMiddleware())
if include_dangling_tool_call_patch:
from deerflow.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
middlewares.append(DanglingToolCallMiddleware())
middlewares.append(LLMErrorHandlingMiddleware())
# Guardrail middleware (if configured)
from deerflow.config.guardrails_config import get_guardrails_config
guardrails_config = get_guardrails_config()
if guardrails_config.enabled and guardrails_config.provider:
import inspect
from deerflow.guardrails.middleware import GuardrailMiddleware
from deerflow.reflection import resolve_variable
provider_cls = resolve_variable(guardrails_config.provider.use)
provider_kwargs = dict(guardrails_config.provider.config) if guardrails_config.provider.config else {}
# Pass framework hint if the provider accepts it (e.g. for config discovery).
# Built-in providers like AllowlistProvider don't need it, so only inject
# when the constructor accepts 'framework' or '**kwargs'.
if "framework" not in provider_kwargs:
try:
sig = inspect.signature(provider_cls.__init__)
if "framework" in sig.parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
provider_kwargs["framework"] = "deerflow"
except (ValueError, TypeError):
pass
provider = provider_cls(**provider_kwargs)
middlewares.append(GuardrailMiddleware(provider, fail_closed=guardrails_config.fail_closed, passport=guardrails_config.passport))
from deerflow.agents.middlewares.sandbox_audit_middleware import SandboxAuditMiddleware
middlewares.append(SandboxAuditMiddleware())
middlewares.append(ToolErrorHandlingMiddleware())
return middlewares
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
return _build_runtime_middlewares(
include_uploads=True,
include_dangling_tool_call_patch=True,
lazy_init=lazy_init,
)
def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
return _build_runtime_middlewares(
include_uploads=False,
include_dangling_tool_call_patch=True,
lazy_init=lazy_init,
)

View File

@@ -0,0 +1,293 @@
"""Middleware to inject uploaded files information into agent context."""
import logging
from pathlib import Path
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.config.paths import Paths, get_paths
from deerflow.utils.file_conversion import extract_outline
logger = logging.getLogger(__name__)
_OUTLINE_PREVIEW_LINES = 5
def _extract_outline_for_file(file_path: Path) -> tuple[list[dict], list[str]]:
"""Return the document outline and fallback preview for *file_path*.
Looks for a sibling ``<stem>.md`` file produced by the upload conversion
pipeline.
Returns:
(outline, preview) where:
- outline: list of ``{title, line}`` dicts (plus optional sentinel).
Empty when no headings are found or no .md exists.
- preview: first few non-empty lines of the .md, used as a content
anchor when outline is empty so the agent has some context.
Empty when outline is non-empty (no fallback needed).
"""
md_path = file_path.with_suffix(".md")
if not md_path.is_file():
return [], []
outline = extract_outline(md_path)
if outline:
logger.debug("Extracted %d outline entries from %s", len(outline), file_path.name)
return outline, []
# outline is empty — read the first few non-empty lines as a content preview
preview: list[str] = []
try:
with md_path.open(encoding="utf-8") as f:
for line in f:
stripped = line.strip()
if stripped:
preview.append(stripped)
if len(preview) >= _OUTLINE_PREVIEW_LINES:
break
except Exception:
logger.debug("Failed to read preview lines from %s", md_path, exc_info=True)
return [], preview
class UploadsMiddlewareState(AgentState):
"""State schema for uploads middleware."""
uploaded_files: NotRequired[list[dict] | None]
class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
"""Middleware to inject uploaded files information into the agent context.
Reads file metadata from the current message's additional_kwargs.files
(set by the frontend after upload) and prepends an <uploaded_files> block
to the last human message so the model knows which files are available.
"""
state_schema = UploadsMiddlewareState
def __init__(self, base_dir: str | None = None):
"""Initialize the middleware.
Args:
base_dir: Base directory for thread data. Defaults to Paths resolution.
"""
super().__init__()
self._paths = Paths(base_dir) if base_dir else get_paths()
def _format_file_entry(self, file: dict, lines: list[str]) -> None:
"""Append a single file entry (name, size, path, optional outline) to lines."""
size_kb = file["size"] / 1024
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
lines.append(f"- {file['filename']} ({size_str})")
lines.append(f" Path: {file['path']}")
outline = file.get("outline") or []
if outline:
truncated = outline[-1].get("truncated", False)
visible = [e for e in outline if not e.get("truncated")]
lines.append(" Document outline (use `read_file` with line ranges to read sections):")
for entry in visible:
lines.append(f" L{entry['line']}: {entry['title']}")
if truncated:
lines.append(f" ... (showing first {len(visible)} headings; use `read_file` to explore further)")
else:
preview = file.get("outline_preview") or []
if preview:
lines.append(" No structural headings detected. Document begins with:")
for text in preview:
lines.append(f" > {text}")
lines.append(" Use `grep` to search for keywords (e.g. `grep(pattern='keyword', path='/mnt/user-data/uploads/')`).")
lines.append("")
def _create_files_message(self, new_files: list[dict], historical_files: list[dict]) -> str:
"""Create a formatted message listing uploaded files.
Args:
new_files: Files uploaded in the current message.
historical_files: Files uploaded in previous messages.
Each file dict may contain an optional ``outline`` key — a list of
``{title, line}`` dicts extracted from the converted Markdown file.
Returns:
Formatted string inside <uploaded_files> tags.
"""
lines = ["<uploaded_files>"]
lines.append("The following files were uploaded in this message:")
lines.append("")
if new_files:
for file in new_files:
self._format_file_entry(file, lines)
else:
lines.append("(empty)")
lines.append("")
if historical_files:
lines.append("The following files were uploaded in previous messages and are still available:")
lines.append("")
for file in historical_files:
self._format_file_entry(file, lines)
lines.append("To work with these files:")
lines.append("- Read from the file first — use the outline line numbers and `read_file` to locate relevant sections.")
lines.append("- Use `grep` to search for keywords when you are not sure which section to look at")
lines.append(" (e.g. `grep(pattern='revenue', path='/mnt/user-data/uploads/')`).")
lines.append("- Use `glob` to find files by name pattern")
lines.append(" (e.g. `glob(pattern='**/*.md', path='/mnt/user-data/uploads/')`).")
lines.append("- Only fall back to web search if the file content is clearly insufficient to answer the question.")
lines.append("</uploaded_files>")
return "\n".join(lines)
def _files_from_kwargs(self, message: HumanMessage, uploads_dir: Path | None = None) -> list[dict] | None:
"""Extract file info from message additional_kwargs.files.
The frontend sends uploaded file metadata in additional_kwargs.files
after a successful upload. Each entry has: filename, size (bytes),
path (virtual path), status.
Args:
message: The human message to inspect.
uploads_dir: Physical uploads directory used to verify file existence.
When provided, entries whose files no longer exist are skipped.
Returns:
List of file dicts with virtual paths, or None if the field is absent or empty.
"""
kwargs_files = (message.additional_kwargs or {}).get("files")
if not isinstance(kwargs_files, list) or not kwargs_files:
return None
files = []
for f in kwargs_files:
if not isinstance(f, dict):
continue
filename = f.get("filename") or ""
if not filename or Path(filename).name != filename:
continue
if uploads_dir is not None and not (uploads_dir / filename).is_file():
continue
files.append(
{
"filename": filename,
"size": int(f.get("size") or 0),
"path": f"/mnt/user-data/uploads/{filename}",
"extension": Path(filename).suffix,
}
)
return files if files else None
@override
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None:
"""Inject uploaded files information before agent execution.
New files come from the current message's additional_kwargs.files.
Historical files are scanned from the thread's uploads directory,
excluding the new ones.
Prepends <uploaded_files> context to the last human message content.
The original additional_kwargs (including files metadata) is preserved
on the updated message so the frontend can read it from the stream.
Args:
state: Current agent state.
runtime: Runtime context containing thread_id.
Returns:
State updates including uploaded files list.
"""
messages = list(state.get("messages", []))
if not messages:
return None
last_message_index = len(messages) - 1
last_message = messages[last_message_index]
if not isinstance(last_message, HumanMessage):
return None
# Resolve uploads directory for existence checks
thread_id = (runtime.context or {}).get("thread_id")
if thread_id is None:
try:
from langgraph.config import get_config
thread_id = get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
pass # get_config() raises outside a runnable context (e.g. unit tests)
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files
new_files = self._files_from_kwargs(last_message, uploads_dir) or []
# Collect historical files from the uploads directory (all except the new ones)
new_filenames = {f["filename"] for f in new_files}
historical_files: list[dict] = []
if uploads_dir and uploads_dir.exists():
for file_path in sorted(uploads_dir.iterdir()):
if file_path.is_file() and file_path.name not in new_filenames:
stat = file_path.stat()
outline, preview = _extract_outline_for_file(file_path)
historical_files.append(
{
"filename": file_path.name,
"size": stat.st_size,
"path": f"/mnt/user-data/uploads/{file_path.name}",
"extension": file_path.suffix,
"outline": outline,
"outline_preview": preview,
}
)
# Attach outlines to new files as well
if uploads_dir:
for file in new_files:
phys_path = uploads_dir / file["filename"]
outline, preview = _extract_outline_for_file(phys_path)
file["outline"] = outline
file["outline_preview"] = preview
if not new_files and not historical_files:
return None
logger.debug(f"New files: {[f['filename'] for f in new_files]}, historical: {[f['filename'] for f in historical_files]}")
# Create files message and prepend to the last human message content
files_message = self._create_files_message(new_files, historical_files)
# Extract original content - handle both string and list formats
original_content = last_message.content
if isinstance(original_content, str):
# Simple case: string content, just prepend files message
updated_content = f"{files_message}\n\n{original_content}"
elif isinstance(original_content, list):
# Complex case: list content (multimodal), preserve all blocks
# Prepend files message as the first text block
files_block = {"type": "text", "text": f"{files_message}\n\n"}
# Keep all original blocks (including images)
updated_content = [files_block, *original_content]
else:
# Other types, preserve as-is
updated_content = original_content
# Create new message with combined content.
# Preserve additional_kwargs (including files metadata) so the frontend
# can read structured file info from the streamed message.
updated_message = HumanMessage(
content=updated_content,
id=last_message.id,
additional_kwargs=last_message.additional_kwargs,
)
messages[last_message_index] = updated_message
return {
"uploaded_files": new_files,
"messages": messages,
}

View File

@@ -0,0 +1,222 @@
"""Middleware for injecting image details into conversation before LLM call."""
import logging
from typing import override
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadState
logger = logging.getLogger(__name__)
class ViewImageMiddlewareState(ThreadState):
"""Reuse the thread state so reducer-backed keys keep their annotations."""
class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]):
"""Injects image details as a human message before LLM calls when view_image tools have completed.
This middleware:
1. Runs before each LLM call
2. Checks if the last assistant message contains view_image tool calls
3. Verifies all tool calls in that message have been completed (have corresponding ToolMessages)
4. If conditions are met, creates a human message with all viewed image details (including base64 data)
5. Adds the message to state so the LLM can see and analyze the images
This enables the LLM to automatically receive and analyze images that were loaded via view_image tool,
without requiring explicit user prompts to describe the images.
"""
state_schema = ViewImageMiddlewareState
def _get_last_assistant_message(self, messages: list) -> AIMessage | None:
"""Get the last assistant message from the message list.
Args:
messages: List of messages
Returns:
Last AIMessage or None if not found
"""
for msg in reversed(messages):
if isinstance(msg, AIMessage):
return msg
return None
def _has_view_image_tool(self, message: AIMessage) -> bool:
"""Check if the assistant message contains view_image tool calls.
Args:
message: Assistant message to check
Returns:
True if message contains view_image tool calls
"""
if not hasattr(message, "tool_calls") or not message.tool_calls:
return False
return any(tool_call.get("name") == "view_image" for tool_call in message.tool_calls)
def _all_tools_completed(self, messages: list, assistant_msg: AIMessage) -> bool:
"""Check if all tool calls in the assistant message have been completed.
Args:
messages: List of all messages
assistant_msg: The assistant message containing tool calls
Returns:
True if all tool calls have corresponding ToolMessages
"""
if not hasattr(assistant_msg, "tool_calls") or not assistant_msg.tool_calls:
return False
# Get all tool call IDs from the assistant message
tool_call_ids = {tool_call.get("id") for tool_call in assistant_msg.tool_calls if tool_call.get("id")}
# Find the index of the assistant message
try:
assistant_idx = messages.index(assistant_msg)
except ValueError:
return False
# Get all ToolMessages after the assistant message
completed_tool_ids = set()
for msg in messages[assistant_idx + 1 :]:
if isinstance(msg, ToolMessage) and msg.tool_call_id:
completed_tool_ids.add(msg.tool_call_id)
# Check if all tool calls have been completed
return tool_call_ids.issubset(completed_tool_ids)
def _create_image_details_message(self, state: ViewImageMiddlewareState) -> list[str | dict]:
"""Create a formatted message with all viewed image details.
Args:
state: Current state containing viewed_images
Returns:
List of content blocks (text and images) for the HumanMessage
"""
viewed_images = state.get("viewed_images", {})
if not viewed_images:
# Return a properly formatted text block, not a plain string array
return [{"type": "text", "text": "No images have been viewed."}]
# Build the message with image information
content_blocks: list[str | dict] = [{"type": "text", "text": "Here are the images you've viewed:"}]
for image_path, image_data in viewed_images.items():
mime_type = image_data.get("mime_type", "unknown")
base64_data = image_data.get("base64", "")
# Add text description
content_blocks.append({"type": "text", "text": f"\n- **{image_path}** ({mime_type})"})
# Add the actual image data so LLM can "see" it
if base64_data:
content_blocks.append(
{
"type": "image_url",
"image_url": {"url": f"data:{mime_type};base64,{base64_data}"},
}
)
return content_blocks
def _should_inject_image_message(self, state: ViewImageMiddlewareState) -> bool:
"""Determine if we should inject an image details message.
Args:
state: Current state
Returns:
True if we should inject the message
"""
messages = state.get("messages", [])
if not messages:
return False
# Get the last assistant message
last_assistant_msg = self._get_last_assistant_message(messages)
if not last_assistant_msg:
return False
# Check if it has view_image tool calls
if not self._has_view_image_tool(last_assistant_msg):
return False
# Check if all tools have been completed
if not self._all_tools_completed(messages, last_assistant_msg):
return False
# Check if we've already added an image details message
# Look for a human message after the last assistant message that contains image details
assistant_idx = messages.index(last_assistant_msg)
for msg in messages[assistant_idx + 1 :]:
if isinstance(msg, HumanMessage):
content_str = str(msg.content)
if "Here are the images you've viewed" in content_str or "Here are the details of the images you've viewed" in content_str:
# Already added, don't add again
return False
return True
def _inject_image_message(self, state: ViewImageMiddlewareState) -> dict | None:
"""Internal helper to inject image details message.
Args:
state: Current state
Returns:
State update with additional human message, or None if no update needed
"""
if not self._should_inject_image_message(state):
return None
# Create the image details message with text and image content
image_content = self._create_image_details_message(state)
# Create a new human message with mixed content (text + images)
human_msg = HumanMessage(content=image_content)
logger.debug("Injecting image details message with images before LLM call")
# Return state update with the new message
return {"messages": [human_msg]}
@override
def before_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
"""Inject image details message before LLM call if view_image tools have completed (sync version).
This runs before each LLM call, checking if the previous turn included view_image
tool calls that have all completed. If so, it injects a human message with the image
details so the LLM can see and analyze the images.
Args:
state: Current state
runtime: Runtime context (unused but required by interface)
Returns:
State update with additional human message, or None if no update needed
"""
return self._inject_image_message(state)
@override
async def abefore_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
"""Inject image details message before LLM call if view_image tools have completed (async version).
This runs before each LLM call, checking if the previous turn included view_image
tool calls that have all completed. If so, it injects a human message with the image
details so the LLM can see and analyze the images.
Args:
state: Current state
runtime: Runtime context (unused but required by interface)
Returns:
State update with additional human message, or None if no update needed
"""
return self._inject_image_message(state)

View File

@@ -0,0 +1,55 @@
from typing import Annotated, NotRequired, TypedDict
from langchain.agents import AgentState
class SandboxState(TypedDict):
sandbox_id: NotRequired[str | None]
class ThreadDataState(TypedDict):
workspace_path: NotRequired[str | None]
uploads_path: NotRequired[str | None]
outputs_path: NotRequired[str | None]
class ViewedImageData(TypedDict):
base64: str
mime_type: str
def merge_artifacts(existing: list[str] | None, new: list[str] | None) -> list[str]:
"""Reducer for artifacts list - merges and deduplicates artifacts."""
if existing is None:
return new or []
if new is None:
return existing
# Use dict.fromkeys to deduplicate while preserving order
return list(dict.fromkeys(existing + new))
def merge_viewed_images(existing: dict[str, ViewedImageData] | None, new: dict[str, ViewedImageData] | None) -> dict[str, ViewedImageData]:
"""Reducer for viewed_images dict - merges image dictionaries.
Special case: If new is an empty dict {}, it clears the existing images.
This allows middlewares to clear the viewed_images state after processing.
"""
if existing is None:
return new or {}
if new is None:
return existing
# Special case: empty dict means clear all viewed images
if len(new) == 0:
return {}
# Merge dictionaries, new values override existing ones for same keys
return {**existing, **new}
class ThreadState(AgentState):
sandbox: NotRequired[SandboxState | None]
thread_data: NotRequired[ThreadDataState | None]
title: NotRequired[str | None]
artifacts: Annotated[list[str], merge_artifacts]
todos: NotRequired[list | None]
uploaded_files: NotRequired[list[dict] | None]
viewed_images: Annotated[dict[str, ViewedImageData], merge_viewed_images] # image_path -> {base64, mime_type}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,39 @@
"""Hard-disable shim for native (unhardened) community web tools.
In this hardened DeerFlow build, the only allowed web access surface is
``deerflow.community.searx.tools`` (web_search, web_fetch, image_search).
The legacy providers below have been deliberately stubbed out and will
raise on import so that misconfiguration cannot silently fall back to
unsanitized output:
- ddg_search (DuckDuckGo)
- tavily (Tavily)
- exa (Exa)
- firecrawl (Firecrawl)
- jina_ai (Jina Reader)
- infoquest (InfoQuest)
- image_search (DDG image fallback)
If you really need one of these back, undo the change in the matching
``community/<name>/tools.py`` and audit the call site for prompt-injection
hardening first.
"""
from __future__ import annotations
class NativeWebToolDisabledError(RuntimeError):
"""Raised when a hard-disabled native web tool is imported or invoked."""
_MESSAGE_TEMPLATE = (
"Native web tool '{provider}' is disabled in this hardened DeerFlow build. "
"Use 'deerflow.community.searx.tools' (web_search / web_fetch / image_search) instead. "
"If you really need '{provider}', re-enable it in "
"deerflow/community/{provider}/tools.py and harden it first."
)
def reject_native_provider(provider: str) -> None:
"""Raise a clear error pointing the operator at the searx replacement."""
raise NativeWebToolDisabledError(_MESSAGE_TEMPLATE.format(provider=provider))

View File

@@ -0,0 +1,15 @@
from .aio_sandbox import AioSandbox
from .aio_sandbox_provider import AioSandboxProvider
from .backend import SandboxBackend
from .local_backend import LocalContainerBackend
from .remote_backend import RemoteSandboxBackend
from .sandbox_info import SandboxInfo
__all__ = [
"AioSandbox",
"AioSandboxProvider",
"LocalContainerBackend",
"RemoteSandboxBackend",
"SandboxBackend",
"SandboxInfo",
]

View File

@@ -0,0 +1,232 @@
import base64
import logging
import shlex
import threading
import uuid
from agent_sandbox import Sandbox as AioSandboxClient
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.search import GrepMatch, path_matches, should_ignore_path, truncate_line
logger = logging.getLogger(__name__)
_ERROR_OBSERVATION_SIGNATURE = "'ErrorObservation' object has no attribute 'exit_code'"
class AioSandbox(Sandbox):
"""Sandbox implementation using the agent-infra/sandbox Docker container.
This sandbox connects to a running AIO sandbox container via HTTP API.
A threading lock serializes shell commands to prevent concurrent requests
from corrupting the container's single persistent session (see #1433).
"""
def __init__(self, id: str, base_url: str, home_dir: str | None = None):
"""Initialize the AIO sandbox.
Args:
id: Unique identifier for this sandbox instance.
base_url: URL of the sandbox API (e.g., http://localhost:8080).
home_dir: Home directory inside the sandbox. If None, will be fetched from the sandbox.
"""
super().__init__(id)
self._base_url = base_url
self._client = AioSandboxClient(base_url=base_url, timeout=600)
self._home_dir = home_dir
self._lock = threading.Lock()
@property
def base_url(self) -> str:
return self._base_url
@property
def home_dir(self) -> str:
"""Get the home directory inside the sandbox."""
if self._home_dir is None:
context = self._client.sandbox.get_context()
self._home_dir = context.home_dir
return self._home_dir
def execute_command(self, command: str) -> str:
"""Execute a shell command in the sandbox.
Uses a lock to serialize concurrent requests. The AIO sandbox
container maintains a single persistent shell session that
corrupts when hit with concurrent exec_command calls (returns
``ErrorObservation`` instead of real output). If corruption is
detected despite the lock (e.g. multiple processes sharing a
sandbox), the command is retried on a fresh session.
Args:
command: The command to execute.
Returns:
The output of the command.
"""
with self._lock:
try:
result = self._client.shell.exec_command(command=command)
output = result.data.output if result.data else ""
if output and _ERROR_OBSERVATION_SIGNATURE in output:
logger.warning("ErrorObservation detected in sandbox output, retrying with a fresh session")
fresh_id = str(uuid.uuid4())
result = self._client.shell.exec_command(command=command, id=fresh_id)
output = result.data.output if result.data else ""
return output if output else "(no output)"
except Exception as e:
logger.error(f"Failed to execute command in sandbox: {e}")
return f"Error: {e}"
def read_file(self, path: str) -> str:
"""Read the content of a file in the sandbox.
Args:
path: The absolute path of the file to read.
Returns:
The content of the file.
"""
try:
result = self._client.file.read_file(file=path)
return result.data.content if result.data else ""
except Exception as e:
logger.error(f"Failed to read file in sandbox: {e}")
return f"Error: {e}"
def list_dir(self, path: str, max_depth: int = 2) -> list[str]:
"""List the contents of a directory in the sandbox.
Args:
path: The absolute path of the directory to list.
max_depth: The maximum depth to traverse. Default is 2.
Returns:
The contents of the directory.
"""
with self._lock:
try:
result = self._client.shell.exec_command(command=f"find {shlex.quote(path)} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500")
output = result.data.output if result.data else ""
if output:
return [line.strip() for line in output.strip().split("\n") if line.strip()]
return []
except Exception as e:
logger.error(f"Failed to list directory in sandbox: {e}")
return []
def write_file(self, path: str, content: str, append: bool = False) -> None:
"""Write content to a file in the sandbox.
Args:
path: The absolute path of the file to write to.
content: The text content to write to the file.
append: Whether to append the content to the file.
"""
with self._lock:
try:
if append:
existing = self.read_file(path)
if not existing.startswith("Error:"):
content = existing + content
self._client.file.write_file(file=path, content=content)
except Exception as e:
logger.error(f"Failed to write file in sandbox: {e}")
raise
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
if not include_dirs:
result = self._client.file.find_files(path=path, glob=pattern)
files = result.data.files if result.data and result.data.files else []
filtered = [file_path for file_path in files if not should_ignore_path(file_path)]
truncated = len(filtered) > max_results
return filtered[:max_results], truncated
result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
entries = result.data.files if result.data and result.data.files else []
matches: list[str] = []
root_path = path.rstrip("/") or "/"
root_prefix = root_path if root_path == "/" else f"{root_path}/"
for entry in entries:
if entry.path != root_path and not entry.path.startswith(root_prefix):
continue
if should_ignore_path(entry.path):
continue
rel_path = entry.path[len(root_path) :].lstrip("/")
if path_matches(pattern, rel_path):
matches.append(entry.path)
if len(matches) >= max_results:
return matches, True
return matches, False
def grep(
self,
path: str,
pattern: str,
*,
glob: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = 100,
) -> tuple[list[GrepMatch], bool]:
import re as _re
regex_source = _re.escape(pattern) if literal else pattern
# Validate the pattern locally so an invalid regex raises re.error
# (caught by grep_tool's except re.error handler) rather than a
# generic remote API error.
_re.compile(regex_source, 0 if case_sensitive else _re.IGNORECASE)
regex = regex_source if case_sensitive else f"(?i){regex_source}"
if glob is not None:
find_result = self._client.file.find_files(path=path, glob=glob)
candidate_paths = find_result.data.files if find_result.data and find_result.data.files else []
else:
list_result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
entries = list_result.data.files if list_result.data and list_result.data.files else []
candidate_paths = [entry.path for entry in entries if not entry.is_directory]
matches: list[GrepMatch] = []
truncated = False
for file_path in candidate_paths:
if should_ignore_path(file_path):
continue
search_result = self._client.file.search_in_file(file=file_path, regex=regex)
data = search_result.data
if data is None:
continue
line_numbers = data.line_numbers or []
matched_lines = data.matches or []
for line_number, line in zip(line_numbers, matched_lines):
matches.append(
GrepMatch(
path=file_path,
line_number=line_number if isinstance(line_number, int) else 0,
line=truncate_line(line),
)
)
if len(matches) >= max_results:
truncated = True
return matches, truncated
return matches, truncated
def update_file(self, path: str, content: bytes) -> None:
"""Update a file with binary content in the sandbox.
Args:
path: The absolute path of the file to update.
content: The binary content to write to the file.
"""
with self._lock:
try:
base64_content = base64.b64encode(content).decode("utf-8")
self._client.file.write_file(file=path, content=base64_content, encoding="base64")
except Exception as e:
logger.error(f"Failed to update file in sandbox: {e}")
raise

View File

@@ -0,0 +1,694 @@
"""AIO Sandbox Provider — orchestrates sandbox lifecycle with pluggable backends.
This provider composes:
- SandboxBackend: how sandboxes are provisioned (local container vs remote/K8s)
The provider itself handles:
- In-process caching for fast repeated access
- Idle timeout management
- Graceful shutdown with signal handling
- Mount computation (thread-specific, skills)
"""
import atexit
import hashlib
import logging
import os
import signal
import threading
import time
import uuid
try:
import fcntl
except ImportError: # pragma: no cover - Windows fallback
fcntl = None # type: ignore[assignment]
import msvcrt
from deerflow.config import get_app_config
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider
from .aio_sandbox import AioSandbox
from .backend import SandboxBackend, wait_for_sandbox_ready
from .local_backend import LocalContainerBackend
from .remote_backend import RemoteSandboxBackend
from .sandbox_info import SandboxInfo
logger = logging.getLogger(__name__)
# Default configuration
DEFAULT_IMAGE = "enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest"
DEFAULT_PORT = 8080
DEFAULT_CONTAINER_PREFIX = "deer-flow-sandbox"
DEFAULT_IDLE_TIMEOUT = 600 # 10 minutes in seconds
DEFAULT_REPLICAS = 3 # Maximum concurrent sandbox containers
IDLE_CHECK_INTERVAL = 60 # Check every 60 seconds
def _lock_file_exclusive(lock_file) -> None:
if fcntl is not None:
fcntl.flock(lock_file, fcntl.LOCK_EX)
return
lock_file.seek(0)
msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1)
def _unlock_file(lock_file) -> None:
if fcntl is not None:
fcntl.flock(lock_file, fcntl.LOCK_UN)
return
lock_file.seek(0)
msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
class AioSandboxProvider(SandboxProvider):
"""Sandbox provider that manages containers running the AIO sandbox.
Architecture:
This provider composes a SandboxBackend (how to provision), enabling:
- Local Docker/Apple Container mode (auto-start containers)
- Remote/K8s mode (connect to pre-existing sandbox URL)
Configuration options in config.yaml under sandbox:
use: deerflow.community.aio_sandbox:AioSandboxProvider
image: <container image>
port: 8080 # Base port for local containers
container_prefix: deer-flow-sandbox
idle_timeout: 600 # Idle timeout in seconds (0 to disable)
replicas: 3 # Max concurrent sandbox containers (LRU eviction when exceeded)
mounts: # Volume mounts for local containers
- host_path: /path/on/host
container_path: /path/in/container
read_only: false
environment: # Environment variables for containers
NODE_ENV: production
API_KEY: $MY_API_KEY
"""
def __init__(self):
self._lock = threading.Lock()
self._sandboxes: dict[str, AioSandbox] = {} # sandbox_id -> AioSandbox instance
self._sandbox_infos: dict[str, SandboxInfo] = {} # sandbox_id -> SandboxInfo (for destroy)
self._thread_sandboxes: dict[str, str] = {} # thread_id -> sandbox_id
self._thread_locks: dict[str, threading.Lock] = {} # thread_id -> in-process lock
self._last_activity: dict[str, float] = {} # sandbox_id -> last activity timestamp
# Warm pool: released sandboxes whose containers are still running.
# Maps sandbox_id -> (SandboxInfo, release_timestamp).
# Containers here can be reclaimed quickly (no cold-start) or destroyed
# when replicas capacity is exhausted.
self._warm_pool: dict[str, tuple[SandboxInfo, float]] = {}
self._shutdown_called = False
self._idle_checker_stop = threading.Event()
self._idle_checker_thread: threading.Thread | None = None
self._config = self._load_config()
self._backend: SandboxBackend = self._create_backend()
# Register shutdown handler
atexit.register(self.shutdown)
self._register_signal_handlers()
# Reconcile orphaned containers from previous process lifecycles
self._reconcile_orphans()
# Start idle checker if enabled
if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0:
self._start_idle_checker()
# ── Factory methods ──────────────────────────────────────────────────
def _create_backend(self) -> SandboxBackend:
"""Create the appropriate backend based on configuration.
Selection logic (checked in order):
1. ``provisioner_url`` set → RemoteSandboxBackend (provisioner mode)
Provisioner dynamically creates Pods + Services in k3s.
2. Default → LocalContainerBackend (local mode)
Local provider manages container lifecycle directly (start/stop).
"""
provisioner_url = self._config.get("provisioner_url")
if provisioner_url:
logger.info(f"Using remote sandbox backend with provisioner at {provisioner_url}")
return RemoteSandboxBackend(provisioner_url=provisioner_url)
logger.info("Using local container sandbox backend")
return LocalContainerBackend(
image=self._config["image"],
base_port=self._config["port"],
container_prefix=self._config["container_prefix"],
config_mounts=self._config["mounts"],
environment=self._config["environment"],
)
# ── Configuration ────────────────────────────────────────────────────
def _load_config(self) -> dict:
"""Load sandbox configuration from app config."""
config = get_app_config()
sandbox_config = config.sandbox
idle_timeout = getattr(sandbox_config, "idle_timeout", None)
replicas = getattr(sandbox_config, "replicas", None)
return {
"image": sandbox_config.image or DEFAULT_IMAGE,
"port": sandbox_config.port or DEFAULT_PORT,
"container_prefix": sandbox_config.container_prefix or DEFAULT_CONTAINER_PREFIX,
"idle_timeout": idle_timeout if idle_timeout is not None else DEFAULT_IDLE_TIMEOUT,
"replicas": replicas if replicas is not None else DEFAULT_REPLICAS,
"mounts": sandbox_config.mounts or [],
"environment": self._resolve_env_vars(sandbox_config.environment or {}),
# provisioner URL for dynamic pod management (e.g. http://provisioner:8002)
"provisioner_url": getattr(sandbox_config, "provisioner_url", None) or "",
}
@staticmethod
def _resolve_env_vars(env_config: dict[str, str]) -> dict[str, str]:
"""Resolve environment variable references (values starting with $)."""
resolved = {}
for key, value in env_config.items():
if isinstance(value, str) and value.startswith("$"):
env_name = value[1:]
resolved[key] = os.environ.get(env_name, "")
else:
resolved[key] = str(value)
return resolved
# ── Startup reconciliation ────────────────────────────────────────────
def _reconcile_orphans(self) -> None:
"""Reconcile orphaned containers left by previous process lifecycles.
On startup, enumerate all running containers matching our prefix
and adopt them all into the warm pool. The idle checker will reclaim
containers that nobody re-acquires within ``idle_timeout``.
All containers are adopted unconditionally because we cannot
distinguish "orphaned" from "actively used by another process"
based on age alone — ``idle_timeout`` represents inactivity, not
uptime. Adopting into the warm pool and letting the idle checker
decide avoids destroying containers that a concurrent process may
still be using.
This closes the fundamental gap where in-memory state loss (process
restart, crash, SIGKILL) leaves Docker containers running forever.
"""
try:
running = self._backend.list_running()
except Exception as e:
logger.warning(f"Failed to enumerate running containers during startup reconciliation: {e}")
return
if not running:
return
current_time = time.time()
adopted = 0
for info in running:
age = current_time - info.created_at if info.created_at > 0 else float("inf")
# Single lock acquisition per container: atomic check-and-insert.
# Avoids a TOCTOU window between the "already tracked?" check and
# the warm-pool insert.
with self._lock:
if info.sandbox_id in self._sandboxes or info.sandbox_id in self._warm_pool:
continue
self._warm_pool[info.sandbox_id] = (info, current_time)
adopted += 1
logger.info(f"Adopted container {info.sandbox_id} into warm pool (age: {age:.0f}s)")
logger.info(f"Startup reconciliation complete: {adopted} adopted into warm pool, {len(running)} total found")
# ── Deterministic ID ─────────────────────────────────────────────────
@staticmethod
def _deterministic_sandbox_id(thread_id: str) -> str:
"""Generate a deterministic sandbox ID from a thread ID.
Ensures all processes derive the same sandbox_id for a given thread,
enabling cross-process sandbox discovery without shared memory.
"""
return hashlib.sha256(thread_id.encode()).hexdigest()[:8]
# ── Mount helpers ────────────────────────────────────────────────────
def _get_extra_mounts(self, thread_id: str | None) -> list[tuple[str, str, bool]]:
"""Collect all extra mounts for a sandbox (thread-specific + skills)."""
mounts: list[tuple[str, str, bool]] = []
if thread_id:
mounts.extend(self._get_thread_mounts(thread_id))
logger.info(f"Adding thread mounts for thread {thread_id}: {mounts}")
skills_mount = self._get_skills_mount()
if skills_mount:
mounts.append(skills_mount)
logger.info(f"Adding skills mount: {skills_mount}")
return mounts
@staticmethod
def _get_thread_mounts(thread_id: str) -> list[tuple[str, str, bool]]:
"""Get volume mounts for a thread's data directories.
Creates directories if they don't exist (lazy initialization).
Mount sources use host_base_dir so that when running inside Docker with a
mounted Docker socket (DooD), the host Docker daemon can resolve the paths.
"""
paths = get_paths()
paths.ensure_thread_dirs(thread_id)
return [
(paths.host_sandbox_work_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
(paths.host_sandbox_uploads_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
(paths.host_sandbox_outputs_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
# ACP workspace: read-only inside the sandbox (lead agent reads results;
# the ACP subprocess writes from the host side, not from within the container).
(paths.host_acp_workspace_dir(thread_id), "/mnt/acp-workspace", True),
]
@staticmethod
def _get_skills_mount() -> tuple[str, str, bool] | None:
"""Get the skills directory mount configuration.
Mount source uses DEER_FLOW_HOST_SKILLS_PATH when running inside Docker (DooD)
so the host Docker daemon can resolve the path.
"""
try:
config = get_app_config()
skills_path = config.skills.get_skills_path()
container_path = config.skills.container_path
if skills_path.exists():
# When running inside Docker with DooD, use host-side skills path.
host_skills = os.environ.get("DEER_FLOW_HOST_SKILLS_PATH") or str(skills_path)
return (host_skills, container_path, True) # Read-only for security
except Exception as e:
logger.warning(f"Could not setup skills mount: {e}")
return None
# ── Idle timeout management ──────────────────────────────────────────
def _start_idle_checker(self) -> None:
"""Start the background thread that checks for idle sandboxes."""
self._idle_checker_thread = threading.Thread(
target=self._idle_checker_loop,
name="sandbox-idle-checker",
daemon=True,
)
self._idle_checker_thread.start()
logger.info(f"Started idle checker thread (timeout: {self._config.get('idle_timeout', DEFAULT_IDLE_TIMEOUT)}s)")
def _idle_checker_loop(self) -> None:
idle_timeout = self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT)
while not self._idle_checker_stop.wait(timeout=IDLE_CHECK_INTERVAL):
try:
self._cleanup_idle_sandboxes(idle_timeout)
except Exception as e:
logger.error(f"Error in idle checker loop: {e}")
def _cleanup_idle_sandboxes(self, idle_timeout: float) -> None:
current_time = time.time()
active_to_destroy = []
warm_to_destroy: list[tuple[str, SandboxInfo]] = []
with self._lock:
# Active sandboxes: tracked via _last_activity
for sandbox_id, last_activity in self._last_activity.items():
idle_duration = current_time - last_activity
if idle_duration > idle_timeout:
active_to_destroy.append(sandbox_id)
logger.info(f"Sandbox {sandbox_id} idle for {idle_duration:.1f}s, marking for destroy")
# Warm pool: tracked via release_timestamp stored in _warm_pool
for sandbox_id, (info, release_ts) in list(self._warm_pool.items()):
warm_duration = current_time - release_ts
if warm_duration > idle_timeout:
warm_to_destroy.append((sandbox_id, info))
del self._warm_pool[sandbox_id]
logger.info(f"Warm-pool sandbox {sandbox_id} idle for {warm_duration:.1f}s, marking for destroy")
# Destroy active sandboxes (re-verify still idle before acting)
for sandbox_id in active_to_destroy:
try:
# Re-verify the sandbox is still idle under the lock before destroying.
# Between the snapshot above and here, the sandbox may have been
# re-acquired (last_activity updated) or already released/destroyed.
with self._lock:
last_activity = self._last_activity.get(sandbox_id)
if last_activity is None:
# Already released or destroyed by another path — skip.
logger.info(f"Sandbox {sandbox_id} already gone before idle destroy, skipping")
continue
if (time.time() - last_activity) < idle_timeout:
# Re-acquired (activity updated) since the snapshot — skip.
logger.info(f"Sandbox {sandbox_id} was re-acquired before idle destroy, skipping")
continue
logger.info(f"Destroying idle sandbox {sandbox_id}")
self.destroy(sandbox_id)
except Exception as e:
logger.error(f"Failed to destroy idle sandbox {sandbox_id}: {e}")
# Destroy warm-pool sandboxes (already removed from _warm_pool under lock above)
for sandbox_id, info in warm_to_destroy:
try:
self._backend.destroy(info)
logger.info(f"Destroyed idle warm-pool sandbox {sandbox_id}")
except Exception as e:
logger.error(f"Failed to destroy idle warm-pool sandbox {sandbox_id}: {e}")
# ── Signal handling ──────────────────────────────────────────────────
def _register_signal_handlers(self) -> None:
"""Register signal handlers for graceful shutdown.
Handles SIGTERM, SIGINT, and SIGHUP (terminal close) to ensure
sandbox containers are cleaned up even when the user closes the terminal.
"""
self._original_sigterm = signal.getsignal(signal.SIGTERM)
self._original_sigint = signal.getsignal(signal.SIGINT)
self._original_sighup = signal.getsignal(signal.SIGHUP) if hasattr(signal, "SIGHUP") else None
def signal_handler(signum, frame):
self.shutdown()
if signum == signal.SIGTERM:
original = self._original_sigterm
elif hasattr(signal, "SIGHUP") and signum == signal.SIGHUP:
original = self._original_sighup
else:
original = self._original_sigint
if callable(original):
original(signum, frame)
elif original == signal.SIG_DFL:
signal.signal(signum, signal.SIG_DFL)
signal.raise_signal(signum)
try:
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
if hasattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, signal_handler)
except ValueError:
logger.debug("Could not register signal handlers (not main thread)")
# ── Thread locking (in-process) ──────────────────────────────────────
def _get_thread_lock(self, thread_id: str) -> threading.Lock:
"""Get or create an in-process lock for a specific thread_id."""
with self._lock:
if thread_id not in self._thread_locks:
self._thread_locks[thread_id] = threading.Lock()
return self._thread_locks[thread_id]
# ── Core: acquire / get / release / shutdown ─────────────────────────
def acquire(self, thread_id: str | None = None) -> str:
"""Acquire a sandbox environment and return its ID.
For the same thread_id, this method will return the same sandbox_id
across multiple turns, multiple processes, and (with shared storage)
multiple pods.
Thread-safe with both in-process and cross-process locking.
Args:
thread_id: Optional thread ID for thread-specific configurations.
Returns:
The ID of the acquired sandbox environment.
"""
if thread_id:
thread_lock = self._get_thread_lock(thread_id)
with thread_lock:
return self._acquire_internal(thread_id)
else:
return self._acquire_internal(thread_id)
def _acquire_internal(self, thread_id: str | None) -> str:
"""Internal sandbox acquisition with two-layer consistency.
Layer 1: In-process cache (fastest, covers same-process repeated access)
Layer 2: Backend discovery (covers containers started by other processes;
sandbox_id is deterministic from thread_id so no shared state file
is needed — any process can derive the same container name)
"""
# ── Layer 1: In-process cache (fast path) ──
if thread_id:
with self._lock:
if thread_id in self._thread_sandboxes:
existing_id = self._thread_sandboxes[thread_id]
if existing_id in self._sandboxes:
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}")
self._last_activity[existing_id] = time.time()
return existing_id
else:
del self._thread_sandboxes[thread_id]
# Deterministic ID for thread-specific, random for anonymous
sandbox_id = self._deterministic_sandbox_id(thread_id) if thread_id else str(uuid.uuid4())[:8]
# ── Layer 1.5: Warm pool (container still running, no cold-start) ──
if thread_id:
with self._lock:
if sandbox_id in self._warm_pool:
info, _ = self._warm_pool.pop(sandbox_id)
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
self._sandboxes[sandbox_id] = sandbox
self._sandbox_infos[sandbox_id] = info
self._last_activity[sandbox_id] = time.time()
self._thread_sandboxes[thread_id] = sandbox_id
logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
return sandbox_id
# ── Layer 2: Backend discovery + create (protected by cross-process lock) ──
# Use a file lock so that two processes racing to create the same sandbox
# for the same thread_id serialize here: the second process will discover
# the container started by the first instead of hitting a name-conflict.
if thread_id:
return self._discover_or_create_with_lock(thread_id, sandbox_id)
return self._create_sandbox(thread_id, sandbox_id)
def _discover_or_create_with_lock(self, thread_id: str, sandbox_id: str) -> str:
"""Discover an existing sandbox or create a new one under a cross-process file lock.
The file lock serializes concurrent sandbox creation for the same thread_id
across multiple processes, preventing container-name conflicts.
"""
paths = get_paths()
paths.ensure_thread_dirs(thread_id)
lock_path = paths.thread_dir(thread_id) / f"{sandbox_id}.lock"
with open(lock_path, "a", encoding="utf-8") as lock_file:
locked = False
try:
_lock_file_exclusive(lock_file)
locked = True
# Re-check in-process caches under the file lock in case another
# thread in this process won the race while we were waiting.
with self._lock:
if thread_id in self._thread_sandboxes:
existing_id = self._thread_sandboxes[thread_id]
if existing_id in self._sandboxes:
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id} (post-lock check)")
self._last_activity[existing_id] = time.time()
return existing_id
if sandbox_id in self._warm_pool:
info, _ = self._warm_pool.pop(sandbox_id)
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
self._sandboxes[sandbox_id] = sandbox
self._sandbox_infos[sandbox_id] = info
self._last_activity[sandbox_id] = time.time()
self._thread_sandboxes[thread_id] = sandbox_id
logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id} (post-lock check)")
return sandbox_id
# Backend discovery: another process may have created the container.
discovered = self._backend.discover(sandbox_id)
if discovered is not None:
sandbox = AioSandbox(id=discovered.sandbox_id, base_url=discovered.sandbox_url)
with self._lock:
self._sandboxes[discovered.sandbox_id] = sandbox
self._sandbox_infos[discovered.sandbox_id] = discovered
self._last_activity[discovered.sandbox_id] = time.time()
self._thread_sandboxes[thread_id] = discovered.sandbox_id
logger.info(f"Discovered existing sandbox {discovered.sandbox_id} for thread {thread_id} at {discovered.sandbox_url}")
return discovered.sandbox_id
return self._create_sandbox(thread_id, sandbox_id)
finally:
if locked:
_unlock_file(lock_file)
def _evict_oldest_warm(self) -> str | None:
"""Destroy the oldest container in the warm pool to free capacity.
Returns:
The evicted sandbox_id, or None if warm pool is empty.
"""
with self._lock:
if not self._warm_pool:
return None
oldest_id = min(self._warm_pool, key=lambda sid: self._warm_pool[sid][1])
info, _ = self._warm_pool.pop(oldest_id)
try:
self._backend.destroy(info)
logger.info(f"Destroyed warm-pool sandbox {oldest_id}")
except Exception as e:
logger.error(f"Failed to destroy warm-pool sandbox {oldest_id}: {e}")
return None
return oldest_id
def _create_sandbox(self, thread_id: str | None, sandbox_id: str) -> str:
"""Create a new sandbox via the backend.
Args:
thread_id: Optional thread ID.
sandbox_id: The sandbox ID to use.
Returns:
The sandbox_id.
Raises:
RuntimeError: If sandbox creation or readiness check fails.
"""
extra_mounts = self._get_extra_mounts(thread_id)
# Enforce replicas: only warm-pool containers count toward eviction budget.
# Active sandboxes are in use by live threads and must not be forcibly stopped.
replicas = self._config.get("replicas", DEFAULT_REPLICAS)
with self._lock:
total = len(self._sandboxes) + len(self._warm_pool)
if total >= replicas:
evicted = self._evict_oldest_warm()
if evicted:
logger.info(f"Evicted warm-pool sandbox {evicted} to stay within replicas={replicas}")
else:
# All slots are occupied by active sandboxes — proceed anyway and log.
# The replicas limit is a soft cap; we never forcibly stop a container
# that is actively serving a thread.
logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit")
info = self._backend.create(thread_id, sandbox_id, extra_mounts=extra_mounts or None)
# Wait for sandbox to be ready
if not wait_for_sandbox_ready(info.sandbox_url, timeout=60):
self._backend.destroy(info)
raise RuntimeError(f"Sandbox {sandbox_id} failed to become ready within timeout at {info.sandbox_url}")
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
with self._lock:
self._sandboxes[sandbox_id] = sandbox
self._sandbox_infos[sandbox_id] = info
self._last_activity[sandbox_id] = time.time()
if thread_id:
self._thread_sandboxes[thread_id] = sandbox_id
logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
return sandbox_id
def get(self, sandbox_id: str) -> Sandbox | None:
"""Get a sandbox by ID. Updates last activity timestamp.
Args:
sandbox_id: The ID of the sandbox.
Returns:
The sandbox instance if found, None otherwise.
"""
with self._lock:
sandbox = self._sandboxes.get(sandbox_id)
if sandbox is not None:
self._last_activity[sandbox_id] = time.time()
return sandbox
def release(self, sandbox_id: str) -> None:
"""Release a sandbox from active use into the warm pool.
The container is kept running so it can be reclaimed quickly by the same
thread on its next turn without a cold-start. The container will only be
stopped when the replicas limit forces eviction or during shutdown.
Args:
sandbox_id: The ID of the sandbox to release.
"""
info = None
thread_ids_to_remove: list[str] = []
with self._lock:
self._sandboxes.pop(sandbox_id, None)
info = self._sandbox_infos.pop(sandbox_id, None)
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids_to_remove:
del self._thread_sandboxes[tid]
self._last_activity.pop(sandbox_id, None)
# Park in warm pool — container keeps running
if info and sandbox_id not in self._warm_pool:
self._warm_pool[sandbox_id] = (info, time.time())
logger.info(f"Released sandbox {sandbox_id} to warm pool (container still running)")
def destroy(self, sandbox_id: str) -> None:
"""Destroy a sandbox: stop the container and free all resources.
Unlike release(), this actually stops the container. Use this for
explicit cleanup, capacity-driven eviction, or shutdown.
Args:
sandbox_id: The ID of the sandbox to destroy.
"""
info = None
thread_ids_to_remove: list[str] = []
with self._lock:
self._sandboxes.pop(sandbox_id, None)
info = self._sandbox_infos.pop(sandbox_id, None)
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids_to_remove:
del self._thread_sandboxes[tid]
self._last_activity.pop(sandbox_id, None)
# Also pull from warm pool if it was parked there
if info is None and sandbox_id in self._warm_pool:
info, _ = self._warm_pool.pop(sandbox_id)
else:
self._warm_pool.pop(sandbox_id, None)
if info:
self._backend.destroy(info)
logger.info(f"Destroyed sandbox {sandbox_id}")
def shutdown(self) -> None:
"""Shutdown all sandboxes. Thread-safe and idempotent."""
with self._lock:
if self._shutdown_called:
return
self._shutdown_called = True
sandbox_ids = list(self._sandboxes.keys())
warm_items = list(self._warm_pool.items())
self._warm_pool.clear()
# Stop idle checker
self._idle_checker_stop.set()
if self._idle_checker_thread is not None and self._idle_checker_thread.is_alive():
self._idle_checker_thread.join(timeout=5)
logger.info("Stopped idle checker thread")
logger.info(f"Shutting down {len(sandbox_ids)} active + {len(warm_items)} warm-pool sandbox(es)")
for sandbox_id in sandbox_ids:
try:
self.destroy(sandbox_id)
except Exception as e:
logger.error(f"Failed to destroy sandbox {sandbox_id} during shutdown: {e}")
for sandbox_id, (info, _) in warm_items:
try:
self._backend.destroy(info)
logger.info(f"Destroyed warm-pool sandbox {sandbox_id} during shutdown")
except Exception as e:
logger.error(f"Failed to destroy warm-pool sandbox {sandbox_id} during shutdown: {e}")

View File

@@ -0,0 +1,114 @@
"""Abstract base class for sandbox provisioning backends."""
from __future__ import annotations
import logging
import time
from abc import ABC, abstractmethod
import requests
from .sandbox_info import SandboxInfo
logger = logging.getLogger(__name__)
def wait_for_sandbox_ready(sandbox_url: str, timeout: int = 30) -> bool:
"""Poll sandbox health endpoint until ready or timeout.
Args:
sandbox_url: URL of the sandbox (e.g. http://k3s:30001).
timeout: Maximum time to wait in seconds.
Returns:
True if sandbox is ready, False otherwise.
"""
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(f"{sandbox_url}/v1/sandbox", timeout=5)
if response.status_code == 200:
return True
except requests.exceptions.RequestException:
pass
time.sleep(1)
return False
class SandboxBackend(ABC):
"""Abstract base for sandbox provisioning backends.
Two implementations:
- LocalContainerBackend: starts Docker/Apple Container locally, manages ports
- RemoteSandboxBackend: connects to a pre-existing URL (K8s service, external)
"""
@abstractmethod
def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
"""Create/provision a new sandbox.
Args:
thread_id: Thread ID for which the sandbox is being created. Useful for backends that want to organize sandboxes by thread.
sandbox_id: Deterministic sandbox identifier.
extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples.
Ignored by backends that don't manage containers (e.g., remote).
Returns:
SandboxInfo with connection details.
"""
...
@abstractmethod
def destroy(self, info: SandboxInfo) -> None:
"""Destroy/cleanup a sandbox and release its resources.
Args:
info: The sandbox metadata to destroy.
"""
...
@abstractmethod
def is_alive(self, info: SandboxInfo) -> bool:
"""Quick check whether a sandbox is still alive.
This should be a lightweight check (e.g., container inspect)
rather than a full health check.
Args:
info: The sandbox metadata to check.
Returns:
True if the sandbox appears to be alive.
"""
...
@abstractmethod
def discover(self, sandbox_id: str) -> SandboxInfo | None:
"""Try to discover an existing sandbox by its deterministic ID.
Used for cross-process recovery: when another process started a sandbox,
this process can discover it by the deterministic container name or URL.
Args:
sandbox_id: The deterministic sandbox ID to look for.
Returns:
SandboxInfo if found and healthy, None otherwise.
"""
...
def list_running(self) -> list[SandboxInfo]:
"""Enumerate all running sandboxes managed by this backend.
Used for startup reconciliation: when the process restarts, it needs
to discover containers started by previous processes so they can be
adopted into the warm pool or destroyed if idle too long.
The default implementation returns an empty list, which is correct
for backends that don't manage local containers (e.g., RemoteSandboxBackend
delegates lifecycle to the provisioner which handles its own cleanup).
Returns:
A list of SandboxInfo for all currently running sandboxes.
"""
return []

View File

@@ -0,0 +1,530 @@
"""Local container backend for sandbox provisioning.
Manages sandbox containers using Docker or Apple Container on the local machine.
Handles container lifecycle, port allocation, and cross-process container discovery.
"""
from __future__ import annotations
import json
import logging
import os
import subprocess
from datetime import datetime
from deerflow.utils.network import get_free_port, release_port
from .backend import SandboxBackend, wait_for_sandbox_ready
from .sandbox_info import SandboxInfo
logger = logging.getLogger(__name__)
def _parse_docker_timestamp(raw: str) -> float:
"""Parse Docker's ISO 8601 timestamp into a Unix epoch float.
Docker returns timestamps with nanosecond precision and a trailing ``Z``
(e.g. ``2026-04-08T01:22:50.123456789Z``). Python's ``fromisoformat``
accepts at most microseconds and (pre-3.11) does not accept ``Z``, so the
string is normalized before parsing. Returns ``0.0`` on empty input or
parse failure so callers can use ``0.0`` as a sentinel for "unknown age".
"""
if not raw:
return 0.0
try:
s = raw.strip()
if "." in s:
dot_pos = s.index(".")
tz_start = dot_pos + 1
while tz_start < len(s) and s[tz_start].isdigit():
tz_start += 1
frac = s[dot_pos + 1 : tz_start][:6] # truncate to microseconds
tz_suffix = s[tz_start:]
s = s[: dot_pos + 1] + frac + tz_suffix
if s.endswith("Z"):
s = s[:-1] + "+00:00"
return datetime.fromisoformat(s).timestamp()
except (ValueError, TypeError) as e:
logger.debug(f"Could not parse docker timestamp {raw!r}: {e}")
return 0.0
def _extract_host_port(inspect_entry: dict, container_port: int) -> int | None:
"""Extract the host port mapped to ``container_port/tcp`` from a docker inspect entry.
Returns None if the container has no port mapping for that port.
"""
try:
ports = (inspect_entry.get("NetworkSettings") or {}).get("Ports") or {}
bindings = ports.get(f"{container_port}/tcp") or []
if bindings:
host_port = bindings[0].get("HostPort")
if host_port:
return int(host_port)
except (ValueError, TypeError, AttributeError):
pass
return None
def _format_container_mount(runtime: str, host_path: str, container_path: str, read_only: bool) -> list[str]:
"""Format a bind-mount argument for the selected runtime.
Docker's ``-v host:container`` syntax is ambiguous for Windows drive-letter
paths like ``D:/...`` because ``:`` is both the drive separator and the
volume separator. Use ``--mount type=bind,...`` for Docker to avoid that
parsing ambiguity. Apple Container keeps using ``-v``.
"""
if runtime == "docker":
mount_spec = f"type=bind,src={host_path},dst={container_path}"
if read_only:
mount_spec += ",readonly"
return ["--mount", mount_spec]
mount_spec = f"{host_path}:{container_path}"
if read_only:
mount_spec += ":ro"
return ["-v", mount_spec]
class LocalContainerBackend(SandboxBackend):
"""Backend that manages sandbox containers locally using Docker or Apple Container.
On macOS, automatically prefers Apple Container if available, otherwise falls back to Docker.
On other platforms, uses Docker.
Features:
- Deterministic container naming for cross-process discovery
- Port allocation with thread-safe utilities
- Container lifecycle management (start/stop with --rm)
- Support for volume mounts and environment variables
"""
def __init__(
self,
*,
image: str,
base_port: int,
container_prefix: str,
config_mounts: list,
environment: dict[str, str],
):
"""Initialize the local container backend.
Args:
image: Container image to use.
base_port: Base port number to start searching for free ports.
container_prefix: Prefix for container names (e.g., "deer-flow-sandbox").
config_mounts: Volume mount configurations from config (list of VolumeMountConfig).
environment: Environment variables to inject into containers.
"""
self._image = image
self._base_port = base_port
self._container_prefix = container_prefix
self._config_mounts = config_mounts
self._environment = environment
self._runtime = self._detect_runtime()
@property
def runtime(self) -> str:
"""The detected container runtime ("docker" or "container")."""
return self._runtime
def _detect_runtime(self) -> str:
"""Detect which container runtime to use.
On macOS, prefer Apple Container if available, otherwise fall back to Docker.
On other platforms, use Docker.
Returns:
"container" for Apple Container, "docker" for Docker.
"""
import platform
if platform.system() == "Darwin":
try:
result = subprocess.run(
["container", "--version"],
capture_output=True,
text=True,
check=True,
timeout=5,
)
logger.info(f"Detected Apple Container: {result.stdout.strip()}")
return "container"
except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired):
logger.info("Apple Container not available, falling back to Docker")
return "docker"
# ── SandboxBackend interface ──────────────────────────────────────────
def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
"""Start a new container and return its connection info.
Args:
thread_id: Thread ID for which the sandbox is being created. Useful for backends that want to organize sandboxes by thread.
sandbox_id: Deterministic sandbox identifier (used in container name).
extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples.
Returns:
SandboxInfo with container details.
Raises:
RuntimeError: If the container fails to start.
"""
container_name = f"{self._container_prefix}-{sandbox_id}"
# Retry loop: if Docker rejects the port (e.g. a stale container still
# holds the binding after a process restart), skip that port and try the
# next one. The socket-bind check in get_free_port mirrors Docker's
# 0.0.0.0 bind, but Docker's port-release can be slightly asynchronous,
# so a reactive fallback here ensures we always make progress.
_next_start = self._base_port
container_id: str | None = None
port: int = 0
for _attempt in range(10):
port = get_free_port(start_port=_next_start)
try:
container_id = self._start_container(container_name, port, extra_mounts)
break
except RuntimeError as exc:
release_port(port)
err = str(exc)
err_lower = err.lower()
# Port already bound: skip this port and retry with the next one.
if "port is already allocated" in err or "address already in use" in err_lower:
logger.warning(f"Port {port} rejected by Docker (already allocated), retrying with next port")
_next_start = port + 1
continue
# Container-name conflict: another process may have already started
# the deterministic sandbox container for this sandbox_id. Try to
# discover and adopt the existing container instead of failing.
if "is already in use by container" in err_lower or "conflict. the container name" in err_lower:
logger.warning(f"Container name {container_name} already in use, attempting to discover existing sandbox instance")
existing = self.discover(sandbox_id)
if existing is not None:
return existing
raise
else:
raise RuntimeError("Could not start sandbox container: all candidate ports are already allocated by Docker")
# When running inside Docker (DooD), sandbox containers are reachable via
# host.docker.internal rather than localhost (they run on the host daemon).
sandbox_host = os.environ.get("DEER_FLOW_SANDBOX_HOST", "localhost")
return SandboxInfo(
sandbox_id=sandbox_id,
sandbox_url=f"http://{sandbox_host}:{port}",
container_name=container_name,
container_id=container_id,
)
def destroy(self, info: SandboxInfo) -> None:
"""Stop the container and release its port."""
# Prefer container_id, fall back to container_name (both accepted by docker stop).
# This ensures containers discovered via list_running() (which only has the name)
# can also be stopped.
stop_target = info.container_id or info.container_name
if stop_target:
self._stop_container(stop_target)
# Extract port from sandbox_url for release
try:
from urllib.parse import urlparse
port = urlparse(info.sandbox_url).port
if port:
release_port(port)
except Exception:
pass
def is_alive(self, info: SandboxInfo) -> bool:
"""Check if the container is still running (lightweight, no HTTP)."""
if info.container_name:
return self._is_container_running(info.container_name)
return False
def discover(self, sandbox_id: str) -> SandboxInfo | None:
"""Discover an existing container by its deterministic name.
Checks if a container with the expected name is running, retrieves its
port, and verifies it responds to health checks.
Args:
sandbox_id: The deterministic sandbox ID (determines container name).
Returns:
SandboxInfo if container found and healthy, None otherwise.
"""
container_name = f"{self._container_prefix}-{sandbox_id}"
if not self._is_container_running(container_name):
return None
port = self._get_container_port(container_name)
if port is None:
return None
sandbox_host = os.environ.get("DEER_FLOW_SANDBOX_HOST", "localhost")
sandbox_url = f"http://{sandbox_host}:{port}"
if not wait_for_sandbox_ready(sandbox_url, timeout=5):
return None
return SandboxInfo(
sandbox_id=sandbox_id,
sandbox_url=sandbox_url,
container_name=container_name,
)
def list_running(self) -> list[SandboxInfo]:
"""Enumerate all running containers matching the configured prefix.
Uses a single ``docker ps`` call to list container names, then a
single batched ``docker inspect`` call to retrieve creation timestamp
and port mapping for all containers at once. Total subprocess calls:
2 (down from 2N+1 in the naive per-container approach).
Note: Docker's ``--filter name=`` performs *substring* matching,
so a secondary ``startswith`` check is applied to ensure only
containers with the exact prefix are included.
Containers without port mappings are still included (with empty
sandbox_url) so that startup reconciliation can adopt orphans
regardless of their port state.
"""
# Step 1: enumerate container names via docker ps
try:
result = subprocess.run(
[
self._runtime,
"ps",
"--filter",
f"name={self._container_prefix}-",
"--format",
"{{.Names}}",
],
capture_output=True,
text=True,
timeout=10,
)
if result.returncode != 0:
stderr = (result.stderr or "").strip()
logger.warning(
"Failed to list running containers with %s ps (returncode=%s, stderr=%s)",
self._runtime,
result.returncode,
stderr or "<empty>",
)
return []
if not result.stdout.strip():
return []
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError, OSError) as e:
logger.warning(f"Failed to list running containers: {e}")
return []
# Filter to names matching our exact prefix (docker filter is substring-based)
container_names = [name.strip() for name in result.stdout.strip().splitlines() if name.strip().startswith(self._container_prefix + "-")]
if not container_names:
return []
# Step 2: batched docker inspect — single subprocess call for all containers
inspections = self._batch_inspect(container_names)
infos: list[SandboxInfo] = []
sandbox_host = os.environ.get("DEER_FLOW_SANDBOX_HOST", "localhost")
for container_name in container_names:
data = inspections.get(container_name)
if data is None:
# Container disappeared between ps and inspect, or inspect failed
continue
created_at, host_port = data
sandbox_id = container_name[len(self._container_prefix) + 1 :]
sandbox_url = f"http://{sandbox_host}:{host_port}" if host_port else ""
infos.append(
SandboxInfo(
sandbox_id=sandbox_id,
sandbox_url=sandbox_url,
container_name=container_name,
created_at=created_at,
)
)
logger.info(f"Found {len(infos)} running sandbox container(s)")
return infos
def _batch_inspect(self, container_names: list[str]) -> dict[str, tuple[float, int | None]]:
"""Batch-inspect containers in a single subprocess call.
Returns a mapping of ``container_name -> (created_at, host_port)``.
Missing containers or parse failures are silently dropped from the result.
"""
if not container_names:
return {}
try:
result = subprocess.run(
[self._runtime, "inspect", *container_names],
capture_output=True,
text=True,
timeout=15,
)
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError, OSError) as e:
logger.warning(f"Failed to batch-inspect containers: {e}")
return {}
if result.returncode != 0:
stderr = (result.stderr or "").strip()
logger.warning(
"Failed to batch-inspect containers with %s inspect (returncode=%s, stderr=%s)",
self._runtime,
result.returncode,
stderr or "<empty>",
)
return {}
try:
payload = json.loads(result.stdout or "[]")
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse docker inspect output as JSON: {e}")
return {}
out: dict[str, tuple[float, int | None]] = {}
for entry in payload:
# ``Name`` is prefixed with ``/`` in the docker inspect response
name = (entry.get("Name") or "").lstrip("/")
if not name:
continue
created_at = _parse_docker_timestamp(entry.get("Created", ""))
host_port = _extract_host_port(entry, 8080)
out[name] = (created_at, host_port)
return out
# ── Container operations ─────────────────────────────────────────────
def _start_container(
self,
container_name: str,
port: int,
extra_mounts: list[tuple[str, str, bool]] | None = None,
) -> str:
"""Start a new container.
Args:
container_name: Name for the container.
port: Host port to map to container port 8080.
extra_mounts: Additional volume mounts.
Returns:
The container ID.
Raises:
RuntimeError: If container fails to start.
"""
cmd = [self._runtime, "run"]
# Docker-specific security options
if self._runtime == "docker":
cmd.extend(["--security-opt", "seccomp=unconfined"])
cmd.extend(
[
"--rm",
"-d",
"-p",
f"{port}:8080",
"--name",
container_name,
]
)
# Environment variables
for key, value in self._environment.items():
cmd.extend(["-e", f"{key}={value}"])
# Config-level volume mounts
for mount in self._config_mounts:
cmd.extend(
_format_container_mount(
self._runtime,
mount.host_path,
mount.container_path,
mount.read_only,
)
)
# Extra mounts (thread-specific, skills, etc.)
if extra_mounts:
for host_path, container_path, read_only in extra_mounts:
cmd.extend(
_format_container_mount(
self._runtime,
host_path,
container_path,
read_only,
)
)
cmd.append(self._image)
logger.info(f"Starting container using {self._runtime}: {' '.join(cmd)}")
try:
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
container_id = result.stdout.strip()
logger.info(f"Started container {container_name} (ID: {container_id}) using {self._runtime}")
return container_id
except subprocess.CalledProcessError as e:
logger.error(f"Failed to start container using {self._runtime}: {e.stderr}")
raise RuntimeError(f"Failed to start sandbox container: {e.stderr}")
def _stop_container(self, container_id: str) -> None:
"""Stop a container (--rm ensures automatic removal)."""
try:
subprocess.run(
[self._runtime, "stop", container_id],
capture_output=True,
text=True,
check=True,
)
logger.info(f"Stopped container {container_id} using {self._runtime}")
except subprocess.CalledProcessError as e:
logger.warning(f"Failed to stop container {container_id}: {e.stderr}")
def _is_container_running(self, container_name: str) -> bool:
"""Check if a named container is currently running.
This enables cross-process container discovery — any process can detect
containers started by another process via the deterministic container name.
"""
try:
result = subprocess.run(
[self._runtime, "inspect", "-f", "{{.State.Running}}", container_name],
capture_output=True,
text=True,
timeout=5,
)
return result.returncode == 0 and result.stdout.strip().lower() == "true"
except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
return False
def _get_container_port(self, container_name: str) -> int | None:
"""Get the host port of a running container.
Args:
container_name: The container name to inspect.
Returns:
The host port mapped to container port 8080, or None if not found.
"""
try:
result = subprocess.run(
[self._runtime, "port", container_name, "8080"],
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0 and result.stdout.strip():
# Output format: "0.0.0.0:PORT" or ":::PORT"
port_str = result.stdout.strip().split(":")[-1]
return int(port_str)
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, ValueError):
pass
return None

View File

@@ -0,0 +1,156 @@
"""Remote sandbox backend — delegates Pod lifecycle to the provisioner service.
The provisioner dynamically creates per-sandbox-id Pods + NodePort Services
in k3s. The backend accesses sandbox pods directly via ``k3s:{NodePort}``.
Architecture:
┌────────────┐ HTTP ┌─────────────┐ K8s API ┌──────────┐
│ this file │ ──────▸ │ provisioner │ ────────▸ │ k3s │
│ (backend) │ │ :8002 │ │ :6443 │
└────────────┘ └─────────────┘ └─────┬────┘
│ creates
┌─────────────┐ ┌─────▼──────┐
│ backend │ ────────▸ │ sandbox │
│ │ direct │ Pod(s) │
└─────────────┘ k3s:NPort └────────────┘
"""
from __future__ import annotations
import logging
import requests
from .backend import SandboxBackend
from .sandbox_info import SandboxInfo
logger = logging.getLogger(__name__)
class RemoteSandboxBackend(SandboxBackend):
"""Backend that delegates sandbox lifecycle to the provisioner service.
All Pod creation, destruction, and discovery are handled by the
provisioner. This backend is a thin HTTP client.
Typical config.yaml::
sandbox:
use: deerflow.community.aio_sandbox:AioSandboxProvider
provisioner_url: http://provisioner:8002
"""
def __init__(self, provisioner_url: str):
"""Initialize with the provisioner service URL.
Args:
provisioner_url: URL of the provisioner service
(e.g., ``http://provisioner:8002``).
"""
self._provisioner_url = provisioner_url.rstrip("/")
@property
def provisioner_url(self) -> str:
return self._provisioner_url
# ── SandboxBackend interface ──────────────────────────────────────────
def create(
self,
thread_id: str,
sandbox_id: str,
extra_mounts: list[tuple[str, str, bool]] | None = None,
) -> SandboxInfo:
"""Create a sandbox Pod + Service via the provisioner.
Calls ``POST /api/sandboxes`` which creates a dedicated Pod +
NodePort Service in k3s.
"""
return self._provisioner_create(thread_id, sandbox_id, extra_mounts)
def destroy(self, info: SandboxInfo) -> None:
"""Destroy a sandbox Pod + Service via the provisioner."""
self._provisioner_destroy(info.sandbox_id)
def is_alive(self, info: SandboxInfo) -> bool:
"""Check whether the sandbox Pod is running."""
return self._provisioner_is_alive(info.sandbox_id)
def discover(self, sandbox_id: str) -> SandboxInfo | None:
"""Discover an existing sandbox via the provisioner.
Calls ``GET /api/sandboxes/{sandbox_id}`` and returns info if
the Pod exists.
"""
return self._provisioner_discover(sandbox_id)
# ── Provisioner API calls ─────────────────────────────────────────────
def _provisioner_create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
"""POST /api/sandboxes → create Pod + Service."""
try:
resp = requests.post(
f"{self._provisioner_url}/api/sandboxes",
json={
"sandbox_id": sandbox_id,
"thread_id": thread_id,
},
timeout=30,
)
resp.raise_for_status()
data = resp.json()
logger.info(f"Provisioner created sandbox {sandbox_id}: sandbox_url={data['sandbox_url']}")
return SandboxInfo(
sandbox_id=sandbox_id,
sandbox_url=data["sandbox_url"],
)
except requests.RequestException as exc:
logger.error(f"Provisioner create failed for {sandbox_id}: {exc}")
raise RuntimeError(f"Provisioner create failed: {exc}") from exc
def _provisioner_destroy(self, sandbox_id: str) -> None:
"""DELETE /api/sandboxes/{sandbox_id} → destroy Pod + Service."""
try:
resp = requests.delete(
f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
timeout=15,
)
if resp.ok:
logger.info(f"Provisioner destroyed sandbox {sandbox_id}")
else:
logger.warning(f"Provisioner destroy returned {resp.status_code}: {resp.text}")
except requests.RequestException as exc:
logger.warning(f"Provisioner destroy failed for {sandbox_id}: {exc}")
def _provisioner_is_alive(self, sandbox_id: str) -> bool:
"""GET /api/sandboxes/{sandbox_id} → check Pod phase."""
try:
resp = requests.get(
f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
timeout=10,
)
if resp.ok:
data = resp.json()
return data.get("status") == "Running"
return False
except requests.RequestException:
return False
def _provisioner_discover(self, sandbox_id: str) -> SandboxInfo | None:
"""GET /api/sandboxes/{sandbox_id} → discover existing sandbox."""
try:
resp = requests.get(
f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
timeout=10,
)
if resp.status_code == 404:
return None
resp.raise_for_status()
data = resp.json()
return SandboxInfo(
sandbox_id=sandbox_id,
sandbox_url=data["sandbox_url"],
)
except requests.RequestException as exc:
logger.debug(f"Provisioner discover failed for {sandbox_id}: {exc}")
return None

View File

@@ -0,0 +1,41 @@
"""Sandbox metadata for cross-process discovery and state persistence."""
from __future__ import annotations
import time
from dataclasses import dataclass, field
@dataclass
class SandboxInfo:
"""Persisted sandbox metadata that enables cross-process discovery.
This dataclass holds all the information needed to reconnect to an
existing sandbox from a different process (e.g., gateway vs langgraph,
multiple workers, or across K8s pods with shared storage).
"""
sandbox_id: str
sandbox_url: str # e.g. http://localhost:8080 or http://k3s:30001
container_name: str | None = None # Only for local container backend
container_id: str | None = None # Only for local container backend
created_at: float = field(default_factory=time.time)
def to_dict(self) -> dict:
return {
"sandbox_id": self.sandbox_id,
"sandbox_url": self.sandbox_url,
"container_name": self.container_name,
"container_id": self.container_id,
"created_at": self.created_at,
}
@classmethod
def from_dict(cls, data: dict) -> SandboxInfo:
return cls(
sandbox_id=data["sandbox_id"],
sandbox_url=data.get("sandbox_url", data.get("base_url", "")),
container_name=data.get("container_name"),
container_id=data.get("container_id"),
created_at=data.get("created_at", time.time()),
)

View File

@@ -0,0 +1,3 @@
from .tools import web_search_tool
__all__ = ["web_search_tool"]

View File

@@ -0,0 +1,5 @@
"""DISABLED: native web tool. See deerflow.community._disabled_native."""
from deerflow.community._disabled_native import reject_native_provider
reject_native_provider("ddg_search")

View File

@@ -0,0 +1,5 @@
"""DISABLED: native web tool. See deerflow.community._disabled_native."""
from deerflow.community._disabled_native import reject_native_provider
reject_native_provider("exa")

View File

@@ -0,0 +1,5 @@
"""DISABLED: native web tool. See deerflow.community._disabled_native."""
from deerflow.community._disabled_native import reject_native_provider
reject_native_provider("firecrawl")

View File

@@ -0,0 +1,3 @@
from .tools import image_search_tool
__all__ = ["image_search_tool"]

View File

@@ -0,0 +1,5 @@
"""DISABLED: native web tool. See deerflow.community._disabled_native."""
from deerflow.community._disabled_native import reject_native_provider
reject_native_provider("image_search")

View File

@@ -0,0 +1,5 @@
"""DISABLED: native InfoQuest client. See deerflow.community._disabled_native."""
from deerflow.community._disabled_native import reject_native_provider
reject_native_provider("infoquest")

View File

@@ -0,0 +1,5 @@
"""DISABLED: native web tool. See deerflow.community._disabled_native."""
from deerflow.community._disabled_native import reject_native_provider
reject_native_provider("infoquest")

View File

@@ -0,0 +1,5 @@
"""DISABLED: native Jina client. See deerflow.community._disabled_native."""
from deerflow.community._disabled_native import reject_native_provider
reject_native_provider("jina_ai")

View File

@@ -0,0 +1,5 @@
"""DISABLED: native web tool. See deerflow.community._disabled_native."""
from deerflow.community._disabled_native import reject_native_provider
reject_native_provider("jina_ai")

View File

@@ -0,0 +1,12 @@
"""Hardened SearX-backed web tools (search, fetch, image search).
All results are passed through the deerflow.security pipeline:
- HTML cleaning (strip script/style/iframe/etc.)
- Unicode sanitizer (zero-width chars, control chars, PUA, tag chars)
- Content delimiter wrapping (semantic boundary for the LLM)
These tools are the ONLY web access surface in this hardened build.
The legacy community web providers (ddg_search, tavily, exa, firecrawl,
jina_ai, infoquest, image_search) are deliberately disabled — see
deerflow/community/_disabled_native.py.
"""

View File

@@ -0,0 +1,160 @@
"""Hardened SearX web search, web fetch, and image search tools.
Every external response is sanitized and wrapped in security delimiters
before being returned to the LLM. See deerflow.security for the pipeline.
"""
from __future__ import annotations
from urllib.parse import quote
import httpx
from langchain.tools import tool
from deerflow.config import get_app_config
from deerflow.security.content_delimiter import wrap_untrusted_content
from deerflow.security.html_cleaner import extract_secure_text
from deerflow.security.sanitizer import sanitizer
DEFAULT_SEARX_URL = "http://localhost:8888"
DEFAULT_TIMEOUT = 30.0
DEFAULT_USER_AGENT = "DeerFlow-Hardened/1.0 (+searx)"
def _tool_extra(name: str) -> dict:
"""Read the model_extra dict for a tool config entry, defensively."""
cfg = get_app_config().get_tool_config(name)
if cfg is None:
return {}
return getattr(cfg, "model_extra", {}) or {}
def _searx_url(tool_name: str = "web_search") -> str:
return _tool_extra(tool_name).get("searx_url", DEFAULT_SEARX_URL)
def _http_get(url: str, params: dict, timeout: float = DEFAULT_TIMEOUT) -> dict:
"""GET a SearX endpoint and return parsed JSON. Raises on transport/HTTP error."""
with httpx.Client(headers={"User-Agent": DEFAULT_USER_AGENT}) as client:
response = client.get(url, params=params, timeout=timeout)
response.raise_for_status()
return response.json()
@tool("web_search", parse_docstring=True)
def web_search_tool(query: str, max_results: int = 10) -> str:
"""Search the web via the private hardened SearX instance.
All results are sanitized against prompt-injection vectors and
wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>> markers.
Args:
query: Search keywords.
max_results: Maximum results to return (capped by config).
"""
extra = _tool_extra("web_search")
cap = int(extra.get("max_results", 10))
searx_url = extra.get("searx_url", DEFAULT_SEARX_URL)
limit = max(1, min(int(max_results), cap))
try:
data = _http_get(
f"{searx_url}/search",
{"q": quote(query), "format": "json"},
)
except Exception as exc:
return wrap_untrusted_content({"error": f"Search failed: {exc}"})
results = []
for item in data.get("results", [])[:limit]:
results.append(
{
"title": sanitizer.sanitize(item.get("title", ""), max_length=200),
"url": item.get("url", ""),
"content": sanitizer.sanitize(item.get("content", ""), max_length=500),
}
)
return wrap_untrusted_content(
{
"query": query,
"total_results": len(results),
"results": results,
}
)
@tool("web_fetch", parse_docstring=True)
async def web_fetch_tool(url: str, max_chars: int = 10000) -> str:
"""Fetch a web page and return sanitized visible text.
Dangerous HTML elements (script, style, iframe, form, ...) are stripped,
invisible Unicode is removed, and the result is wrapped in security markers.
Only call this for URLs returned by web_search or supplied directly by the
user — do not invent URLs.
Args:
url: Absolute URL to fetch (must include scheme).
max_chars: Maximum number of characters to return.
"""
extra = _tool_extra("web_fetch")
cap = int(extra.get("max_chars", max_chars))
limit = max(256, min(int(max_chars), cap))
try:
async with httpx.AsyncClient(
headers={"User-Agent": DEFAULT_USER_AGENT},
follow_redirects=True,
) as client:
response = await client.get(url, timeout=DEFAULT_TIMEOUT)
response.raise_for_status()
html = response.text
except Exception as exc:
return wrap_untrusted_content({"error": f"Fetch failed: {exc}", "url": url})
raw_text = extract_secure_text(html)
clean_text = sanitizer.sanitize(raw_text, max_length=limit)
return wrap_untrusted_content({"url": url, "content": clean_text})
@tool("image_search", parse_docstring=True)
def image_search_tool(query: str, max_results: int = 5) -> str:
"""Search for images via the private hardened SearX instance.
Returns sanitized title/url pairs (no inline image data). Wrapped in
security delimiters.
Args:
query: Image search keywords.
max_results: Maximum number of images to return.
"""
extra = _tool_extra("image_search")
cap = int(extra.get("max_results", 5))
searx_url = extra.get("searx_url", _searx_url("web_search"))
limit = max(1, min(int(max_results), cap))
try:
data = _http_get(
f"{searx_url}/search",
{"q": quote(query), "format": "json", "categories": "images"},
)
except Exception as exc:
return wrap_untrusted_content({"error": f"Image search failed: {exc}"})
results = []
for item in data.get("results", [])[:limit]:
results.append(
{
"title": sanitizer.sanitize(item.get("title", ""), max_length=200),
"url": item.get("url", ""),
"thumbnail": item.get("thumbnail_src") or item.get("img_src", ""),
}
)
return wrap_untrusted_content(
{
"query": query,
"total_results": len(results),
"results": results,
}
)

View File

@@ -0,0 +1,5 @@
"""DISABLED: native web tool. See deerflow.community._disabled_native."""
from deerflow.community._disabled_native import reject_native_provider
reject_native_provider("tavily")

View File

@@ -0,0 +1,30 @@
from .app_config import get_app_config
from .extensions_config import ExtensionsConfig, get_extensions_config
from .memory_config import MemoryConfig, get_memory_config
from .paths import Paths, get_paths
from .skill_evolution_config import SkillEvolutionConfig
from .skills_config import SkillsConfig
from .tracing_config import (
get_enabled_tracing_providers,
get_explicitly_enabled_tracing_providers,
get_tracing_config,
is_tracing_enabled,
validate_enabled_tracing_providers,
)
__all__ = [
"get_app_config",
"SkillEvolutionConfig",
"Paths",
"get_paths",
"SkillsConfig",
"ExtensionsConfig",
"get_extensions_config",
"MemoryConfig",
"get_memory_config",
"get_tracing_config",
"get_explicitly_enabled_tracing_providers",
"get_enabled_tracing_providers",
"is_tracing_enabled",
"validate_enabled_tracing_providers",
]

View File

@@ -0,0 +1,51 @@
"""ACP (Agent Client Protocol) agent configuration loaded from config.yaml."""
import logging
from collections.abc import Mapping
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class ACPAgentConfig(BaseModel):
"""Configuration for a single ACP-compatible agent."""
command: str = Field(description="Command to launch the ACP agent subprocess")
args: list[str] = Field(default_factory=list, description="Additional command arguments")
env: dict[str, str] = Field(default_factory=dict, description="Environment variables to inject into the agent subprocess. Values starting with $ are resolved from host environment variables.")
description: str = Field(description="Description of the agent's capabilities (shown in tool description)")
model: str | None = Field(default=None, description="Model hint passed to the agent (optional)")
auto_approve_permissions: bool = Field(
default=False,
description=(
"When True, DeerFlow automatically approves all ACP permission requests from this agent "
"(allow_once preferred over allow_always). When False (default), all permission requests "
"are denied — the agent must be configured to operate without requesting permissions."
),
)
_acp_agents: dict[str, ACPAgentConfig] = {}
def get_acp_agents() -> dict[str, ACPAgentConfig]:
"""Get the currently configured ACP agents.
Returns:
Mapping of agent name -> ACPAgentConfig. Empty dict if no ACP agents are configured.
"""
return _acp_agents
def load_acp_config_from_dict(config_dict: Mapping[str, Mapping[str, object]] | None) -> None:
"""Load ACP agent configuration from a dictionary (typically from config.yaml).
Args:
config_dict: Mapping of agent name -> config fields.
"""
global _acp_agents
if config_dict is None:
config_dict = {}
_acp_agents = {name: ACPAgentConfig(**cfg) for name, cfg in config_dict.items()}
logger.info("ACP config loaded: %d agent(s): %s", len(_acp_agents), list(_acp_agents.keys()))

View File

@@ -0,0 +1,125 @@
"""Configuration and loaders for custom agents."""
import logging
import re
from typing import Any
import yaml
from pydantic import BaseModel
from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__)
SOUL_FILENAME = "SOUL.md"
AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
class AgentConfig(BaseModel):
"""Configuration for a custom agent."""
name: str
description: str = ""
model: str | None = None
tool_groups: list[str] | None = None
# skills controls which skills are loaded into the agent's prompt:
# - None (or omitted): load all enabled skills (default fallback behavior)
# - [] (explicit empty list): disable all skills
# - ["skill1", "skill2"]: load only the specified skills
skills: list[str] | None = None
def load_agent_config(name: str | None) -> AgentConfig | None:
"""Load the custom or default agent's config from its directory.
Args:
name: The agent name.
Returns:
AgentConfig instance.
Raises:
FileNotFoundError: If the agent directory or config.yaml does not exist.
ValueError: If config.yaml cannot be parsed.
"""
if name is None:
return None
if not AGENT_NAME_PATTERN.match(name):
raise ValueError(f"Invalid agent name '{name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
agent_dir = get_paths().agent_dir(name)
config_file = agent_dir / "config.yaml"
if not agent_dir.exists():
raise FileNotFoundError(f"Agent directory not found: {agent_dir}")
if not config_file.exists():
raise FileNotFoundError(f"Agent config not found: {config_file}")
try:
with open(config_file, encoding="utf-8") as f:
data: dict[str, Any] = yaml.safe_load(f) or {}
except yaml.YAMLError as e:
raise ValueError(f"Failed to parse agent config {config_file}: {e}") from e
# Ensure name is set from directory name if not in file
if "name" not in data:
data["name"] = name
# Strip unknown fields before passing to Pydantic (e.g. legacy prompt_file)
known_fields = set(AgentConfig.model_fields.keys())
data = {k: v for k, v in data.items() if k in known_fields}
return AgentConfig(**data)
def load_agent_soul(agent_name: str | None) -> str | None:
"""Read the SOUL.md file for a custom agent, if it exists.
SOUL.md defines the agent's personality, values, and behavioral guardrails.
It is injected into the lead agent's system prompt as additional context.
Args:
agent_name: The name of the agent or None for the default agent.
Returns:
The SOUL.md content as a string, or None if the file does not exist.
"""
agent_dir = get_paths().agent_dir(agent_name) if agent_name else get_paths().base_dir
soul_path = agent_dir / SOUL_FILENAME
if not soul_path.exists():
return None
content = soul_path.read_text(encoding="utf-8").strip()
return content or None
def list_custom_agents() -> list[AgentConfig]:
"""Scan the agents directory and return all valid custom agents.
Returns:
List of AgentConfig for each valid agent directory found.
"""
agents_dir = get_paths().agents_dir
if not agents_dir.exists():
return []
agents: list[AgentConfig] = []
for entry in sorted(agents_dir.iterdir()):
if not entry.is_dir():
continue
config_file = entry / "config.yaml"
if not config_file.exists():
logger.debug(f"Skipping {entry.name}: no config.yaml")
continue
try:
agent_cfg = load_agent_config(entry.name)
agents.append(agent_cfg)
except Exception as e:
logger.warning(f"Skipping agent '{entry.name}': {e}")
return agents

View File

@@ -0,0 +1,379 @@
import logging
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self
import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
from deerflow.config.skills_config import SkillsConfig
from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict
from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict
from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict
from deerflow.config.title_config import TitleConfig, load_title_config_from_dict
from deerflow.config.token_usage_config import TokenUsageConfig
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
load_dotenv()
logger = logging.getLogger(__name__)
def _default_config_candidates() -> tuple[Path, ...]:
"""Return deterministic config.yaml locations without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4]
repo_root = backend_dir.parent
return (backend_dir / "config.yaml", repo_root / "config.yaml")
class AppConfig(BaseModel):
"""Config for the DeerFlow application"""
log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration")
models: list[ModelConfig] = Field(default_factory=list, description="Available models")
sandbox: SandboxConfig = Field(description="Sandbox configuration")
tools: list[ToolConfig] = Field(default_factory=list, description="Available tools")
tool_groups: list[ToolGroupConfig] = Field(default_factory=list, description="Available tool groups")
skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
skill_evolution: SkillEvolutionConfig = Field(default_factory=SkillEvolutionConfig, description="Agent-managed skill evolution configuration")
extensions: ExtensionsConfig = Field(default_factory=ExtensionsConfig, description="Extensions configuration (MCP servers and skills state)")
tool_search: ToolSearchConfig = Field(default_factory=ToolSearchConfig, description="Tool search / deferred loading configuration")
title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration")
summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
model_config = ConfigDict(extra="allow", frozen=False)
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
@classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path:
"""Resolve the config file path.
Priority:
1. If provided `config_path` argument, use it.
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`.
"""
if config_path:
path = Path(config_path)
if not Path.exists(path):
raise FileNotFoundError(f"Config file specified by param `config_path` not found at {path}")
return path
elif os.getenv("DEER_FLOW_CONFIG_PATH"):
path = Path(os.getenv("DEER_FLOW_CONFIG_PATH"))
if not Path.exists(path):
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
return path
else:
for path in _default_config_candidates():
if path.exists():
return path
raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")
@classmethod
def from_file(cls, config_path: str | None = None) -> Self:
"""Load config from YAML file.
See `resolve_config_path` for more details.
Args:
config_path: Path to the config file.
Returns:
AppConfig: The loaded config.
"""
resolved_path = cls.resolve_config_path(config_path)
with open(resolved_path, encoding="utf-8") as f:
config_data = yaml.safe_load(f) or {}
# Check config version before processing
cls._check_config_version(config_data, resolved_path)
config_data = cls.resolve_env_variables(config_data)
# Load title config if present
if "title" in config_data:
load_title_config_from_dict(config_data["title"])
# Load summarization config if present
if "summarization" in config_data:
load_summarization_config_from_dict(config_data["summarization"])
# Load memory config if present
if "memory" in config_data:
load_memory_config_from_dict(config_data["memory"])
# Load subagents config if present
if "subagents" in config_data:
load_subagents_config_from_dict(config_data["subagents"])
# Load tool_search config if present
if "tool_search" in config_data:
load_tool_search_config_from_dict(config_data["tool_search"])
# Load guardrails config if present
if "guardrails" in config_data:
load_guardrails_config_from_dict(config_data["guardrails"])
# Load checkpointer config if present
if "checkpointer" in config_data:
load_checkpointer_config_from_dict(config_data["checkpointer"])
# Load stream bridge config if present
if "stream_bridge" in config_data:
load_stream_bridge_config_from_dict(config_data["stream_bridge"])
# Always refresh ACP agent config so removed entries do not linger across reloads.
load_acp_config_from_dict(config_data.get("acp_agents", {}))
# Load extensions config separately (it's in a different file)
extensions_config = ExtensionsConfig.from_file()
config_data["extensions"] = extensions_config.model_dump()
result = cls.model_validate(config_data)
return result
@classmethod
def _check_config_version(cls, config_data: dict, config_path: Path) -> None:
"""Check if the user's config.yaml is outdated compared to config.example.yaml.
Emits a warning if the user's config_version is lower than the example's.
Missing config_version is treated as version 0 (pre-versioning).
"""
try:
user_version = int(config_data.get("config_version", 0))
except (TypeError, ValueError):
user_version = 0
# Find config.example.yaml by searching config.yaml's directory and its parents
example_path = None
search_dir = config_path.parent
for _ in range(5): # search up to 5 levels
candidate = search_dir / "config.example.yaml"
if candidate.exists():
example_path = candidate
break
parent = search_dir.parent
if parent == search_dir:
break
search_dir = parent
if example_path is None:
return
try:
with open(example_path, encoding="utf-8") as f:
example_data = yaml.safe_load(f)
raw = example_data.get("config_version", 0) if example_data else 0
try:
example_version = int(raw)
except (TypeError, ValueError):
example_version = 0
except Exception:
return
if user_version < example_version:
logger.warning(
"Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to merge new fields into your config.",
user_version,
example_version,
)
@classmethod
def resolve_env_variables(cls, config: Any) -> Any:
"""Recursively resolve environment variables in the config.
Environment variables are resolved using the `os.getenv` function. Example: $OPENAI_API_KEY
Args:
config: The config to resolve environment variables in.
Returns:
The config with environment variables resolved.
"""
if isinstance(config, str):
if config.startswith("$"):
env_value = os.getenv(config[1:])
if env_value is None:
raise ValueError(f"Environment variable {config[1:]} not found for config value {config}")
return env_value
return config
elif isinstance(config, dict):
return {k: cls.resolve_env_variables(v) for k, v in config.items()}
elif isinstance(config, list):
return [cls.resolve_env_variables(item) for item in config]
return config
def get_model_config(self, name: str) -> ModelConfig | None:
"""Get the model config by name.
Args:
name: The name of the model to get the config for.
Returns:
The model config if found, otherwise None.
"""
return next((model for model in self.models if model.name == name), None)
def get_tool_config(self, name: str) -> ToolConfig | None:
"""Get the tool config by name.
Args:
name: The name of the tool to get the config for.
Returns:
The tool config if found, otherwise None.
"""
return next((tool for tool in self.tools if tool.name == name), None)
def get_tool_group_config(self, name: str) -> ToolGroupConfig | None:
"""Get the tool group config by name.
Args:
name: The name of the tool group to get the config for.
Returns:
The tool group config if found, otherwise None.
"""
return next((group for group in self.tool_groups if group.name == name), None)
_app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
def _get_config_mtime(config_path: Path) -> float | None:
"""Get the modification time of a config file if it exists."""
try:
return config_path.stat().st_mtime
except OSError:
return None
def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
"""Load config from disk and refresh cache metadata."""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
resolved_path = AppConfig.resolve_config_path(config_path)
_app_config = AppConfig.from_file(str(resolved_path))
_app_config_path = resolved_path
_app_config_mtime = _get_config_mtime(resolved_path)
_app_config_is_custom = False
return _app_config
def get_app_config() -> AppConfig:
"""Get the DeerFlow config instance.
Returns a cached singleton instance and automatically reloads it when the
underlying config file path or modification time changes. Use
`reload_app_config()` to force a reload, or `reset_app_config()` to clear
the cache.
"""
global _app_config, _app_config_path, _app_config_mtime
runtime_override = _current_app_config.get()
if runtime_override is not None:
return runtime_override
if _app_config is not None and _app_config_is_custom:
return _app_config
resolved_path = AppConfig.resolve_config_path()
current_mtime = _get_config_mtime(resolved_path)
should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
if should_reload:
if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
logger.info(
"Config file has been modified (mtime: %s -> %s), reloading AppConfig",
_app_config_mtime,
current_mtime,
)
_load_and_cache_app_config(str(resolved_path))
return _app_config
def reload_app_config(config_path: str | None = None) -> AppConfig:
"""Reload the config from file and update the cached instance.
This is useful when the config file has been modified and you want
to pick up the changes without restarting the application.
Args:
config_path: Optional path to config file. If not provided,
uses the default resolution strategy.
Returns:
The newly loaded AppConfig instance.
"""
return _load_and_cache_app_config(config_path)
def reset_app_config() -> None:
"""Reset the cached config instance.
This clears the singleton cache, causing the next call to
`get_app_config()` to reload from file. Useful for testing
or when switching between different configurations.
"""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = None
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = False
def set_app_config(config: AppConfig) -> None:
"""Set a custom config instance.
This allows injecting a custom or mock config for testing purposes.
Args:
config: The AppConfig instance to use.
"""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = config
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = True
def peek_current_app_config() -> AppConfig | None:
"""Return the runtime-scoped AppConfig override, if one is active."""
return _current_app_config.get()
def push_current_app_config(config: AppConfig) -> None:
"""Push a runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
_current_app_config_stack.set(stack + (_current_app_config.get(),))
_current_app_config.set(config)
def pop_current_app_config() -> None:
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
if not stack:
_current_app_config.set(None)
return
previous = stack[-1]
_current_app_config_stack.set(stack[:-1])
_current_app_config.set(previous)

View File

@@ -0,0 +1,46 @@
"""Configuration for LangGraph checkpointer."""
from typing import Literal
from pydantic import BaseModel, Field
CheckpointerType = Literal["memory", "sqlite", "postgres"]
class CheckpointerConfig(BaseModel):
"""Configuration for LangGraph state persistence checkpointer."""
type: CheckpointerType = Field(
description="Checkpointer backend type. "
"'memory' is in-process only (lost on restart). "
"'sqlite' persists to a local file (requires langgraph-checkpoint-sqlite). "
"'postgres' persists to PostgreSQL (requires langgraph-checkpoint-postgres)."
)
connection_string: str | None = Field(
default=None,
description="Connection string for sqlite (file path) or postgres (DSN). "
"Required for sqlite and postgres types. "
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
)
# Global configuration instance — None means no checkpointer is configured.
_checkpointer_config: CheckpointerConfig | None = None
def get_checkpointer_config() -> CheckpointerConfig | None:
"""Get the current checkpointer configuration, or None if not configured."""
return _checkpointer_config
def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
"""Set the checkpointer configuration."""
global _checkpointer_config
_checkpointer_config = config
def load_checkpointer_config_from_dict(config_dict: dict) -> None:
"""Load checkpointer configuration from a dictionary."""
global _checkpointer_config
_checkpointer_config = CheckpointerConfig(**config_dict)

View File

@@ -0,0 +1,256 @@
"""Unified extensions configuration for MCP servers and skills."""
import json
import os
from pathlib import Path
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
class McpOAuthConfig(BaseModel):
"""OAuth configuration for an MCP server (HTTP/SSE transports)."""
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
token_url: str = Field(description="OAuth token endpoint URL")
grant_type: Literal["client_credentials", "refresh_token"] = Field(
default="client_credentials",
description="OAuth grant type",
)
client_id: str | None = Field(default=None, description="OAuth client ID")
client_secret: str | None = Field(default=None, description="OAuth client secret")
refresh_token: str | None = Field(default=None, description="OAuth refresh token (for refresh_token grant)")
scope: str | None = Field(default=None, description="OAuth scope")
audience: str | None = Field(default=None, description="OAuth audience (provider-specific)")
token_field: str = Field(default="access_token", description="Field name containing access token in token response")
token_type_field: str = Field(default="token_type", description="Field name containing token type in token response")
expires_in_field: str = Field(default="expires_in", description="Field name containing expiry (seconds) in token response")
default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response")
refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry")
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
model_config = ConfigDict(extra="allow")
class McpServerConfig(BaseModel):
"""Configuration for a single MCP server."""
enabled: bool = Field(default=True, description="Whether this MCP server is enabled")
type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'")
command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)")
args: list[str] = Field(default_factory=list, description="Arguments to pass to the command (for stdio type)")
env: dict[str, str] = Field(default_factory=dict, description="Environment variables for the MCP server")
url: str | None = Field(default=None, description="URL of the MCP server (for sse or http type)")
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)")
description: str = Field(default="", description="Human-readable description of what this MCP server provides")
model_config = ConfigDict(extra="allow")
class SkillStateConfig(BaseModel):
"""Configuration for a single skill's state."""
enabled: bool = Field(default=True, description="Whether this skill is enabled")
class ExtensionsConfig(BaseModel):
"""Unified configuration for MCP servers and skills."""
mcp_servers: dict[str, McpServerConfig] = Field(
default_factory=dict,
description="Map of MCP server name to configuration",
alias="mcpServers",
)
skills: dict[str, SkillStateConfig] = Field(
default_factory=dict,
description="Map of skill name to state configuration",
)
model_config = ConfigDict(extra="allow", populate_by_name=True)
@classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path | None:
"""Resolve the extensions config file path.
Priority:
1. If provided `config_path` argument, use it.
2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it.
3. Otherwise, check for `extensions_config.json` in the current directory, then in the parent directory.
4. For backward compatibility, also check for `mcp_config.json` if `extensions_config.json` is not found.
5. If not found, return None (extensions are optional).
Args:
config_path: Optional path to extensions config file.
Resolution order:
1. If provided `config_path` argument, use it.
2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it.
3. Otherwise, search backend/repository-root defaults for
`extensions_config.json`, then legacy `mcp_config.json`.
Returns:
Path to the extensions config file if found, otherwise None.
"""
if config_path:
path = Path(config_path)
if not path.exists():
raise FileNotFoundError(f"Extensions config file specified by param `config_path` not found at {path}")
return path
elif os.getenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH"):
path = Path(os.getenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH"))
if not path.exists():
raise FileNotFoundError(f"Extensions config file specified by environment variable `DEER_FLOW_EXTENSIONS_CONFIG_PATH` not found at {path}")
return path
else:
backend_dir = Path(__file__).resolve().parents[4]
repo_root = backend_dir.parent
for path in (
backend_dir / "extensions_config.json",
repo_root / "extensions_config.json",
backend_dir / "mcp_config.json",
repo_root / "mcp_config.json",
):
if path.exists():
return path
# Extensions are optional, so return None if not found
return None
@classmethod
def from_file(cls, config_path: str | None = None) -> "ExtensionsConfig":
"""Load extensions config from JSON file.
See `resolve_config_path` for more details.
Args:
config_path: Path to the extensions config file.
Returns:
ExtensionsConfig: The loaded config, or empty config if file not found.
"""
resolved_path = cls.resolve_config_path(config_path)
if resolved_path is None:
# Return empty config if extensions config file is not found
return cls(mcp_servers={}, skills={})
try:
with open(resolved_path, encoding="utf-8") as f:
config_data = json.load(f)
cls.resolve_env_variables(config_data)
return cls.model_validate(config_data)
except json.JSONDecodeError as e:
raise ValueError(f"Extensions config file at {resolved_path} is not valid JSON: {e}") from e
except Exception as e:
raise RuntimeError(f"Failed to load extensions config from {resolved_path}: {e}") from e
@classmethod
def resolve_env_variables(cls, config: dict[str, Any]) -> dict[str, Any]:
"""Recursively resolve environment variables in the config.
Environment variables are resolved using the `os.getenv` function. Example: $OPENAI_API_KEY
Args:
config: The config to resolve environment variables in.
Returns:
The config with environment variables resolved.
"""
for key, value in config.items():
if isinstance(value, str):
if value.startswith("$"):
env_value = os.getenv(value[1:])
if env_value is None:
# Unresolved placeholder — store empty string so downstream
# consumers (e.g. MCP servers) don't receive the literal "$VAR"
# token as an actual environment value.
config[key] = ""
else:
config[key] = env_value
else:
config[key] = value
elif isinstance(value, dict):
config[key] = cls.resolve_env_variables(value)
elif isinstance(value, list):
config[key] = [cls.resolve_env_variables(item) if isinstance(item, dict) else item for item in value]
return config
def get_enabled_mcp_servers(self) -> dict[str, McpServerConfig]:
"""Get only the enabled MCP servers.
Returns:
Dictionary of enabled MCP servers.
"""
return {name: config for name, config in self.mcp_servers.items() if config.enabled}
def is_skill_enabled(self, skill_name: str, skill_category: str) -> bool:
"""Check if a skill is enabled.
Args:
skill_name: Name of the skill
skill_category: Category of the skill
Returns:
True if enabled, False otherwise
"""
skill_config = self.skills.get(skill_name)
if skill_config is None:
# Default to enable for public & custom skill
return skill_category in ("public", "custom")
return skill_config.enabled
_extensions_config: ExtensionsConfig | None = None
def get_extensions_config() -> ExtensionsConfig:
"""Get the extensions config instance.
Returns a cached singleton instance. Use `reload_extensions_config()` to reload
from file, or `reset_extensions_config()` to clear the cache.
Returns:
The cached ExtensionsConfig instance.
"""
global _extensions_config
if _extensions_config is None:
_extensions_config = ExtensionsConfig.from_file()
return _extensions_config
def reload_extensions_config(config_path: str | None = None) -> ExtensionsConfig:
"""Reload the extensions config from file and update the cached instance.
This is useful when the config file has been modified and you want
to pick up the changes without restarting the application.
Args:
config_path: Optional path to extensions config file. If not provided,
uses the default resolution strategy.
Returns:
The newly loaded ExtensionsConfig instance.
"""
global _extensions_config
_extensions_config = ExtensionsConfig.from_file(config_path)
return _extensions_config
def reset_extensions_config() -> None:
"""Reset the cached extensions config instance.
This clears the singleton cache, causing the next call to
`get_extensions_config()` to reload from file. Useful for testing
or when switching between different configurations.
"""
global _extensions_config
_extensions_config = None
def set_extensions_config(config: ExtensionsConfig) -> None:
"""Set a custom extensions config instance.
This allows injecting a custom or mock config for testing purposes.
Args:
config: The ExtensionsConfig instance to use.
"""
global _extensions_config
_extensions_config = config

View File

@@ -0,0 +1,48 @@
"""Configuration for pre-tool-call authorization."""
from pydantic import BaseModel, Field
class GuardrailProviderConfig(BaseModel):
"""Configuration for a guardrail provider."""
use: str = Field(description="Class path (e.g. 'deerflow.guardrails.builtin:AllowlistProvider')")
config: dict = Field(default_factory=dict, description="Provider-specific settings passed as kwargs")
class GuardrailsConfig(BaseModel):
"""Configuration for pre-tool-call authorization.
When enabled, every tool call passes through the configured provider
before execution. The provider receives tool name, arguments, and the
agent's passport reference, and returns an allow/deny decision.
"""
enabled: bool = Field(default=False, description="Enable guardrail middleware")
fail_closed: bool = Field(default=True, description="Block tool calls if provider errors")
passport: str | None = Field(default=None, description="OAP passport path or hosted agent ID")
provider: GuardrailProviderConfig | None = Field(default=None, description="Guardrail provider configuration")
_guardrails_config: GuardrailsConfig | None = None
def get_guardrails_config() -> GuardrailsConfig:
"""Get the guardrails config, returning defaults if not loaded."""
global _guardrails_config
if _guardrails_config is None:
_guardrails_config = GuardrailsConfig()
return _guardrails_config
def load_guardrails_config_from_dict(data: dict) -> GuardrailsConfig:
"""Load guardrails config from a dict (called during AppConfig loading)."""
global _guardrails_config
_guardrails_config = GuardrailsConfig.model_validate(data)
return _guardrails_config
def reset_guardrails_config() -> None:
"""Reset the cached config instance. Used in tests to prevent singleton leaks."""
global _guardrails_config
_guardrails_config = None

View File

@@ -0,0 +1,82 @@
"""Configuration for memory mechanism."""
from pydantic import BaseModel, Field
class MemoryConfig(BaseModel):
"""Configuration for global memory mechanism."""
enabled: bool = Field(
default=True,
description="Whether to enable memory mechanism",
)
storage_path: str = Field(
default="",
description=(
"Path to store memory data. "
"If empty, defaults to `{base_dir}/memory.json` (see Paths.memory_file). "
"Absolute paths are used as-is. "
"Relative paths are resolved against `Paths.base_dir` "
"(not the backend working directory). "
"Note: if you previously set this to `.deer-flow/memory.json`, "
"the file will now be resolved as `{base_dir}/.deer-flow/memory.json`; "
"migrate existing data or use an absolute path to preserve the old location."
),
)
storage_class: str = Field(
default="deerflow.agents.memory.storage.FileMemoryStorage",
description="The class path for memory storage provider",
)
debounce_seconds: int = Field(
default=30,
ge=1,
le=300,
description="Seconds to wait before processing queued updates (debounce)",
)
model_name: str | None = Field(
default=None,
description="Model name to use for memory updates (None = use default model)",
)
max_facts: int = Field(
default=100,
ge=10,
le=500,
description="Maximum number of facts to store",
)
fact_confidence_threshold: float = Field(
default=0.7,
ge=0.0,
le=1.0,
description="Minimum confidence threshold for storing facts",
)
injection_enabled: bool = Field(
default=True,
description="Whether to inject memory into system prompt",
)
max_injection_tokens: int = Field(
default=2000,
ge=100,
le=8000,
description="Maximum tokens to use for memory injection",
)
# Global configuration instance
_memory_config: MemoryConfig = MemoryConfig()
def get_memory_config() -> MemoryConfig:
"""Get the current memory configuration."""
return _memory_config
def set_memory_config(config: MemoryConfig) -> None:
"""Set the memory configuration."""
global _memory_config
_memory_config = config
def load_memory_config_from_dict(config_dict: dict) -> None:
"""Load memory configuration from a dictionary."""
global _memory_config
_memory_config = MemoryConfig(**config_dict)

View File

@@ -0,0 +1,41 @@
from pydantic import BaseModel, ConfigDict, Field
class ModelConfig(BaseModel):
"""Config section for a model"""
name: str = Field(..., description="Unique name for the model")
display_name: str | None = Field(..., default_factory=lambda: None, description="Display name for the model")
description: str | None = Field(..., default_factory=lambda: None, description="Description for the model")
use: str = Field(
...,
description="Class path of the model provider(e.g. langchain_openai.ChatOpenAI)",
)
model: str = Field(..., description="Model name")
model_config = ConfigDict(extra="allow")
use_responses_api: bool | None = Field(
default=None,
description="Whether to route OpenAI ChatOpenAI calls through the /v1/responses API",
)
output_version: str | None = Field(
default=None,
description="Structured output version for OpenAI responses content, e.g. responses/v1",
)
supports_thinking: bool = Field(default_factory=lambda: False, description="Whether the model supports thinking")
supports_reasoning_effort: bool = Field(default_factory=lambda: False, description="Whether the model supports reasoning effort")
when_thinking_enabled: dict | None = Field(
default_factory=lambda: None,
description="Extra settings to be passed to the model when thinking is enabled",
)
when_thinking_disabled: dict | None = Field(
default_factory=lambda: None,
description="Extra settings to be passed to the model when thinking is disabled",
)
supports_vision: bool = Field(default_factory=lambda: False, description="Whether the model supports vision/image inputs")
thinking: dict | None = Field(
default_factory=lambda: None,
description=(
"Thinking settings for the model. If provided, these settings will be passed to the model when thinking is enabled. "
"This is a shortcut for `when_thinking_enabled` and will be merged with `when_thinking_enabled` if both are provided."
),
)

View File

@@ -0,0 +1,306 @@
import os
import re
import shutil
from pathlib import Path, PureWindowsPath
# Virtual path prefix seen by agents inside the sandbox
VIRTUAL_PATH_PREFIX = "/mnt/user-data"
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
def _default_local_base_dir() -> Path:
"""Return the repo-local DeerFlow state directory without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4]
return backend_dir / ".deer-flow"
def _validate_thread_id(thread_id: str) -> str:
"""Validate a thread ID before using it in filesystem paths."""
if not _SAFE_THREAD_ID_RE.match(thread_id):
raise ValueError(f"Invalid thread_id {thread_id!r}: only alphanumeric characters, hyphens, and underscores are allowed.")
return thread_id
def _join_host_path(base: str, *parts: str) -> str:
"""Join host filesystem path segments while preserving native style.
Docker Desktop on Windows expects bind mount sources to stay in Windows
path form (for example ``C:\\repo\\backend\\.deer-flow``). Using
``Path(base) / ...`` on a POSIX host can accidentally rewrite those paths
with mixed separators, so this helper preserves the original style.
"""
if not parts:
return base
if re.match(r"^[A-Za-z]:[\\/]", base) or base.startswith("\\\\") or "\\" in base:
result = PureWindowsPath(base)
for part in parts:
result /= part
return str(result)
result = Path(base)
for part in parts:
result /= part
return str(result)
def join_host_path(base: str, *parts: str) -> str:
"""Join host filesystem path segments while preserving native style."""
return _join_host_path(base, *parts)
class Paths:
"""
Centralized path configuration for DeerFlow application data.
Directory layout (host side):
{base_dir}/
├── memory.json
├── USER.md <-- global user profile (injected into all agents)
├── agents/
│ └── {agent_name}/
│ ├── config.yaml
│ ├── SOUL.md <-- agent personality/identity (injected alongside lead prompt)
│ └── memory.json
└── threads/
└── {thread_id}/
└── user-data/ <-- mounted as /mnt/user-data/ inside sandbox
├── workspace/ <-- /mnt/user-data/workspace/
├── uploads/ <-- /mnt/user-data/uploads/
└── outputs/ <-- /mnt/user-data/outputs/
BaseDir resolution (in priority order):
1. Constructor argument `base_dir`
2. DEER_FLOW_HOME environment variable
3. Repo-local fallback derived from this module path: `{backend_dir}/.deer-flow`
"""
def __init__(self, base_dir: str | Path | None = None) -> None:
self._base_dir = Path(base_dir).resolve() if base_dir is not None else None
@property
def host_base_dir(self) -> Path:
"""Host-visible base dir for Docker volume mount sources.
When running inside Docker with a mounted Docker socket (DooD), the Docker
daemon runs on the host and resolves mount paths against the host filesystem.
Set DEER_FLOW_HOST_BASE_DIR to the host-side path that corresponds to this
container's base_dir so that sandbox container volume mounts work correctly.
Falls back to base_dir when the env var is not set (native/local execution).
"""
if env := os.getenv("DEER_FLOW_HOST_BASE_DIR"):
return Path(env)
return self.base_dir
def _host_base_dir_str(self) -> str:
"""Return the host base dir as a raw string for bind mounts."""
if env := os.getenv("DEER_FLOW_HOST_BASE_DIR"):
return env
return str(self.base_dir)
@property
def base_dir(self) -> Path:
"""Root directory for all application data."""
if self._base_dir is not None:
return self._base_dir
if env_home := os.getenv("DEER_FLOW_HOME"):
return Path(env_home).resolve()
return _default_local_base_dir()
@property
def memory_file(self) -> Path:
"""Path to the persisted memory file: `{base_dir}/memory.json`."""
return self.base_dir / "memory.json"
@property
def user_md_file(self) -> Path:
"""Path to the global user profile file: `{base_dir}/USER.md`."""
return self.base_dir / "USER.md"
@property
def agents_dir(self) -> Path:
"""Root directory for all custom agents: `{base_dir}/agents/`."""
return self.base_dir / "agents"
def agent_dir(self, name: str) -> Path:
"""Directory for a specific agent: `{base_dir}/agents/{name}/`."""
return self.agents_dir / name.lower()
def agent_memory_file(self, name: str) -> Path:
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
return self.agent_dir(name) / "memory.json"
def thread_dir(self, thread_id: str) -> Path:
"""
Host path for a thread's data: `{base_dir}/threads/{thread_id}/`
This directory contains a `user-data/` subdirectory that is mounted
as `/mnt/user-data/` inside the sandbox.
Raises:
ValueError: If `thread_id` contains unsafe characters (path separators
or `..`) that could cause directory traversal.
"""
return self.base_dir / "threads" / _validate_thread_id(thread_id)
def sandbox_work_dir(self, thread_id: str) -> Path:
"""
Host path for the agent's workspace directory.
Host: `{base_dir}/threads/{thread_id}/user-data/workspace/`
Sandbox: `/mnt/user-data/workspace/`
"""
return self.thread_dir(thread_id) / "user-data" / "workspace"
def sandbox_uploads_dir(self, thread_id: str) -> Path:
"""
Host path for user-uploaded files.
Host: `{base_dir}/threads/{thread_id}/user-data/uploads/`
Sandbox: `/mnt/user-data/uploads/`
"""
return self.thread_dir(thread_id) / "user-data" / "uploads"
def sandbox_outputs_dir(self, thread_id: str) -> Path:
"""
Host path for agent-generated artifacts.
Host: `{base_dir}/threads/{thread_id}/user-data/outputs/`
Sandbox: `/mnt/user-data/outputs/`
"""
return self.thread_dir(thread_id) / "user-data" / "outputs"
def acp_workspace_dir(self, thread_id: str) -> Path:
"""
Host path for the ACP workspace of a specific thread.
Host: `{base_dir}/threads/{thread_id}/acp-workspace/`
Sandbox: `/mnt/acp-workspace/`
Each thread gets its own isolated ACP workspace so that concurrent
sessions cannot read each other's ACP agent outputs.
"""
return self.thread_dir(thread_id) / "acp-workspace"
def sandbox_user_data_dir(self, thread_id: str) -> Path:
"""
Host path for the user-data root.
Host: `{base_dir}/threads/{thread_id}/user-data/`
Sandbox: `/mnt/user-data/`
"""
return self.thread_dir(thread_id) / "user-data"
def host_thread_dir(self, thread_id: str) -> str:
"""Host path for a thread directory, preserving Windows path syntax."""
return _join_host_path(self._host_base_dir_str(), "threads", _validate_thread_id(thread_id))
def host_sandbox_user_data_dir(self, thread_id: str) -> str:
"""Host path for a thread's user-data root."""
return _join_host_path(self.host_thread_dir(thread_id), "user-data")
def host_sandbox_work_dir(self, thread_id: str) -> str:
"""Host path for the workspace mount source."""
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "workspace")
def host_sandbox_uploads_dir(self, thread_id: str) -> str:
"""Host path for the uploads mount source."""
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "uploads")
def host_sandbox_outputs_dir(self, thread_id: str) -> str:
"""Host path for the outputs mount source."""
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "outputs")
def host_acp_workspace_dir(self, thread_id: str) -> str:
"""Host path for the ACP workspace mount source."""
return _join_host_path(self.host_thread_dir(thread_id), "acp-workspace")
def ensure_thread_dirs(self, thread_id: str) -> None:
"""Create all standard sandbox directories for a thread.
Directories are created with mode 0o777 so that sandbox containers
(which may run as a different UID than the host backend process) can
write to the volume-mounted paths without "Permission denied" errors.
The explicit chmod() call is necessary because Path.mkdir(mode=...) is
subject to the process umask and may not yield the intended permissions.
Includes the ACP workspace directory so it can be volume-mounted into
the sandbox container at ``/mnt/acp-workspace`` even before the first
ACP agent invocation.
"""
for d in [
self.sandbox_work_dir(thread_id),
self.sandbox_uploads_dir(thread_id),
self.sandbox_outputs_dir(thread_id),
self.acp_workspace_dir(thread_id),
]:
d.mkdir(parents=True, exist_ok=True)
d.chmod(0o777)
def delete_thread_dir(self, thread_id: str) -> None:
"""Delete all persisted data for a thread.
The operation is idempotent: missing thread directories are ignored.
"""
thread_dir = self.thread_dir(thread_id)
if thread_dir.exists():
shutil.rmtree(thread_dir)
def resolve_virtual_path(self, thread_id: str, virtual_path: str) -> Path:
"""Resolve a sandbox virtual path to the actual host filesystem path.
Args:
thread_id: The thread ID.
virtual_path: Virtual path as seen inside the sandbox, e.g.
``/mnt/user-data/outputs/report.pdf``.
Leading slashes are stripped before matching.
Returns:
The resolved absolute host filesystem path.
Raises:
ValueError: If the path does not start with the expected virtual
prefix or a path-traversal attempt is detected.
"""
stripped = virtual_path.lstrip("/")
prefix = VIRTUAL_PATH_PREFIX.lstrip("/")
# Require an exact segment-boundary match to avoid prefix confusion
# (e.g. reject paths like "mnt/user-dataX/...").
if stripped != prefix and not stripped.startswith(prefix + "/"):
raise ValueError(f"Path must start with /{prefix}")
relative = stripped[len(prefix) :].lstrip("/")
base = self.sandbox_user_data_dir(thread_id).resolve()
actual = (base / relative).resolve()
try:
actual.relative_to(base)
except ValueError:
raise ValueError("Access denied: path traversal detected")
return actual
# ── Singleton ────────────────────────────────────────────────────────────
_paths: Paths | None = None
def get_paths() -> Paths:
"""Return the global Paths singleton (lazy-initialized)."""
global _paths
if _paths is None:
_paths = Paths()
return _paths
def resolve_path(path: str) -> Path:
"""Resolve *path* to an absolute ``Path``.
Relative paths are resolved relative to the application base directory.
Absolute paths are returned as-is (after normalisation).
"""
p = Path(path)
if not p.is_absolute():
p = get_paths().base_dir / path
return p.resolve()

View File

@@ -0,0 +1,83 @@
from pydantic import BaseModel, ConfigDict, Field
class VolumeMountConfig(BaseModel):
"""Configuration for a volume mount."""
host_path: str = Field(..., description="Path on the host machine")
container_path: str = Field(..., description="Path inside the container")
read_only: bool = Field(default=False, description="Whether the mount is read-only")
class SandboxConfig(BaseModel):
"""Config section for a sandbox.
Common options:
use: Class path of the sandbox provider (required)
allow_host_bash: Enable host-side bash execution for LocalSandboxProvider.
Dangerous and intended only for fully trusted local workflows.
AioSandboxProvider specific options:
image: Docker image to use (default: enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest)
port: Base port for sandbox containers (default: 8080)
replicas: Maximum number of concurrent sandbox containers (default: 3). When the limit is reached the least-recently-used sandbox is evicted to make room.
container_prefix: Prefix for container names (default: deer-flow-sandbox)
idle_timeout: Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.
mounts: List of volume mounts to share directories with the container
environment: Environment variables to inject into the container (values starting with $ are resolved from host env)
"""
use: str = Field(
...,
description="Class path of the sandbox provider (e.g. deerflow.sandbox.local:LocalSandboxProvider)",
)
allow_host_bash: bool = Field(
default=False,
description="Allow the bash tool to execute directly on the host when using LocalSandboxProvider. Dangerous; intended only for fully trusted local environments.",
)
image: str | None = Field(
default=None,
description="Docker image to use for the sandbox container",
)
port: int | None = Field(
default=None,
description="Base port for sandbox containers",
)
replicas: int | None = Field(
default=None,
description="Maximum number of concurrent sandbox containers (default: 3). When the limit is reached the least-recently-used sandbox is evicted to make room.",
)
container_prefix: str | None = Field(
default=None,
description="Prefix for container names",
)
idle_timeout: int | None = Field(
default=None,
description="Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.",
)
mounts: list[VolumeMountConfig] = Field(
default_factory=list,
description="List of volume mounts to share directories between host and container",
)
environment: dict[str, str] = Field(
default_factory=dict,
description="Environment variables to inject into the sandbox container. Values starting with $ will be resolved from host environment variables.",
)
bash_output_max_chars: int = Field(
default=20000,
ge=0,
description="Maximum characters to keep from bash tool output. Output exceeding this limit is middle-truncated (head + tail), preserving the first and last half. Set to 0 to disable truncation.",
)
read_file_output_max_chars: int = Field(
default=50000,
ge=0,
description="Maximum characters to keep from read_file tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
)
ls_output_max_chars: int = Field(
default=20000,
ge=0,
description="Maximum characters to keep from ls tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
)
model_config = ConfigDict(extra="allow")

View File

@@ -0,0 +1,14 @@
from pydantic import BaseModel, Field
class SkillEvolutionConfig(BaseModel):
"""Configuration for agent-managed skill evolution."""
enabled: bool = Field(
default=False,
description="Whether the agent can create and modify skills under skills/custom.",
)
moderation_model_name: str | None = Field(
default=None,
description="Optional model name for skill security moderation. Defaults to the primary chat model.",
)

View File

@@ -0,0 +1,54 @@
from pathlib import Path
from pydantic import BaseModel, Field
def _default_repo_root() -> Path:
"""Resolve the repo root without relying on the current working directory."""
return Path(__file__).resolve().parents[5]
class SkillsConfig(BaseModel):
"""Configuration for skills system"""
path: str | None = Field(
default=None,
description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory",
)
container_path: str = Field(
default="/mnt/skills",
description="Path where skills are mounted in the sandbox container",
)
def get_skills_path(self) -> Path:
"""
Get the resolved skills directory path.
Returns:
Path to the skills directory
"""
if self.path:
# Use configured path (can be absolute or relative)
path = Path(self.path)
if not path.is_absolute():
# If relative, resolve from the repo root for deterministic behavior.
path = _default_repo_root() / path
return path.resolve()
else:
# Default: ../skills relative to backend directory
from deerflow.skills.loader import get_skills_root_path
return get_skills_root_path()
def get_skill_container_path(self, skill_name: str, category: str = "public") -> str:
"""
Get the full container path for a specific skill.
Args:
skill_name: Name of the skill (directory name)
category: Category of the skill (public or custom)
Returns:
Full path to the skill in the container
"""
return f"{self.container_path}/{category}/{skill_name}"

View File

@@ -0,0 +1,46 @@
"""Configuration for stream bridge."""
from typing import Literal
from pydantic import BaseModel, Field
StreamBridgeType = Literal["memory", "redis"]
class StreamBridgeConfig(BaseModel):
"""Configuration for the stream bridge that connects agent workers to SSE endpoints."""
type: StreamBridgeType = Field(
default="memory",
description="Stream bridge backend type. 'memory' uses in-process asyncio.Queue (single-process only). 'redis' uses Redis Streams (planned for Phase 2, not yet implemented).",
)
redis_url: str | None = Field(
default=None,
description="Redis URL for the redis stream bridge type. Example: 'redis://localhost:6379/0'.",
)
queue_maxsize: int = Field(
default=256,
description="Maximum number of events buffered per run in the memory bridge.",
)
# Global configuration instance — None means no stream bridge is configured
# (falls back to memory with defaults).
_stream_bridge_config: StreamBridgeConfig | None = None
def get_stream_bridge_config() -> StreamBridgeConfig | None:
"""Get the current stream bridge configuration, or None if not configured."""
return _stream_bridge_config
def set_stream_bridge_config(config: StreamBridgeConfig | None) -> None:
"""Set the stream bridge configuration."""
global _stream_bridge_config
_stream_bridge_config = config
def load_stream_bridge_config_from_dict(config_dict: dict) -> None:
"""Load stream bridge configuration from a dictionary."""
global _stream_bridge_config
_stream_bridge_config = StreamBridgeConfig(**config_dict)

View File

@@ -0,0 +1,102 @@
"""Configuration for the subagent system loaded from config.yaml."""
import logging
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class SubagentOverrideConfig(BaseModel):
"""Per-agent configuration overrides."""
timeout_seconds: int | None = Field(
default=None,
ge=1,
description="Timeout in seconds for this subagent (None = use global default)",
)
max_turns: int | None = Field(
default=None,
ge=1,
description="Maximum turns for this subagent (None = use global or builtin default)",
)
class SubagentsAppConfig(BaseModel):
"""Configuration for the subagent system."""
timeout_seconds: int = Field(
default=900,
ge=1,
description="Default timeout in seconds for all subagents (default: 900 = 15 minutes)",
)
max_turns: int | None = Field(
default=None,
ge=1,
description="Optional default max-turn override for all subagents (None = keep builtin defaults)",
)
agents: dict[str, SubagentOverrideConfig] = Field(
default_factory=dict,
description="Per-agent configuration overrides keyed by agent name",
)
def get_timeout_for(self, agent_name: str) -> int:
"""Get the effective timeout for a specific agent.
Args:
agent_name: The name of the subagent.
Returns:
The timeout in seconds, using per-agent override if set, otherwise global default.
"""
override = self.agents.get(agent_name)
if override is not None and override.timeout_seconds is not None:
return override.timeout_seconds
return self.timeout_seconds
def get_max_turns_for(self, agent_name: str, builtin_default: int) -> int:
"""Get the effective max_turns for a specific agent."""
override = self.agents.get(agent_name)
if override is not None and override.max_turns is not None:
return override.max_turns
if self.max_turns is not None:
return self.max_turns
return builtin_default
_subagents_config: SubagentsAppConfig = SubagentsAppConfig()
def get_subagents_app_config() -> SubagentsAppConfig:
"""Get the current subagents configuration."""
return _subagents_config
def load_subagents_config_from_dict(config_dict: dict) -> None:
"""Load subagents configuration from a dictionary."""
global _subagents_config
_subagents_config = SubagentsAppConfig(**config_dict)
overrides_summary = {}
for name, override in _subagents_config.agents.items():
parts = []
if override.timeout_seconds is not None:
parts.append(f"timeout={override.timeout_seconds}s")
if override.max_turns is not None:
parts.append(f"max_turns={override.max_turns}")
if parts:
overrides_summary[name] = ", ".join(parts)
if overrides_summary:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
overrides_summary,
)
else:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
)

View File

@@ -0,0 +1,74 @@
"""Configuration for conversation summarization."""
from typing import Literal
from pydantic import BaseModel, Field
ContextSizeType = Literal["fraction", "tokens", "messages"]
class ContextSize(BaseModel):
"""Context size specification for trigger or keep parameters."""
type: ContextSizeType = Field(description="Type of context size specification")
value: int | float = Field(description="Value for the context size specification")
def to_tuple(self) -> tuple[ContextSizeType, int | float]:
"""Convert to tuple format expected by SummarizationMiddleware."""
return (self.type, self.value)
class SummarizationConfig(BaseModel):
"""Configuration for automatic conversation summarization."""
enabled: bool = Field(
default=False,
description="Whether to enable automatic conversation summarization",
)
model_name: str | None = Field(
default=None,
description="Model name to use for summarization (None = use a lightweight model)",
)
trigger: ContextSize | list[ContextSize] | None = Field(
default=None,
description="One or more thresholds that trigger summarization. When any threshold is met, summarization runs. "
"Examples: {'type': 'messages', 'value': 50} triggers at 50 messages, "
"{'type': 'tokens', 'value': 4000} triggers at 4000 tokens, "
"{'type': 'fraction', 'value': 0.8} triggers at 80% of model's max input tokens",
)
keep: ContextSize = Field(
default_factory=lambda: ContextSize(type="messages", value=20),
description="Context retention policy after summarization. Specifies how much history to preserve. "
"Examples: {'type': 'messages', 'value': 20} keeps 20 messages, "
"{'type': 'tokens', 'value': 3000} keeps 3000 tokens, "
"{'type': 'fraction', 'value': 0.3} keeps 30% of model's max input tokens",
)
trim_tokens_to_summarize: int | None = Field(
default=4000,
description="Maximum tokens to keep when preparing messages for summarization. Pass null to skip trimming.",
)
summary_prompt: str | None = Field(
default=None,
description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.",
)
# Global configuration instance
_summarization_config: SummarizationConfig = SummarizationConfig()
def get_summarization_config() -> SummarizationConfig:
"""Get the current summarization configuration."""
return _summarization_config
def set_summarization_config(config: SummarizationConfig) -> None:
"""Set the summarization configuration."""
global _summarization_config
_summarization_config = config
def load_summarization_config_from_dict(config_dict: dict) -> None:
"""Load summarization configuration from a dictionary."""
global _summarization_config
_summarization_config = SummarizationConfig(**config_dict)

View File

@@ -0,0 +1,53 @@
"""Configuration for automatic thread title generation."""
from pydantic import BaseModel, Field
class TitleConfig(BaseModel):
"""Configuration for automatic thread title generation."""
enabled: bool = Field(
default=True,
description="Whether to enable automatic title generation",
)
max_words: int = Field(
default=6,
ge=1,
le=20,
description="Maximum number of words in the generated title",
)
max_chars: int = Field(
default=60,
ge=10,
le=200,
description="Maximum number of characters in the generated title",
)
model_name: str | None = Field(
default=None,
description="Model name to use for title generation (None = use default model)",
)
prompt_template: str = Field(
default=("Generate a concise title (max {max_words} words) for this conversation.\nUser: {user_msg}\nAssistant: {assistant_msg}\n\nReturn ONLY the title, no quotes, no explanation."),
description="Prompt template for title generation",
)
# Global configuration instance
_title_config: TitleConfig = TitleConfig()
def get_title_config() -> TitleConfig:
"""Get the current title configuration."""
return _title_config
def set_title_config(config: TitleConfig) -> None:
"""Set the title configuration."""
global _title_config
_title_config = config
def load_title_config_from_dict(config_dict: dict) -> None:
"""Load title configuration from a dictionary."""
global _title_config
_title_config = TitleConfig(**config_dict)

View File

@@ -0,0 +1,7 @@
from pydantic import BaseModel, Field
class TokenUsageConfig(BaseModel):
"""Configuration for token usage tracking."""
enabled: bool = Field(default=False, description="Enable token usage tracking middleware")

View File

@@ -0,0 +1,20 @@
from pydantic import BaseModel, ConfigDict, Field
class ToolGroupConfig(BaseModel):
"""Config section for a tool group"""
name: str = Field(..., description="Unique name for the tool group")
model_config = ConfigDict(extra="allow")
class ToolConfig(BaseModel):
"""Config section for a tool"""
name: str = Field(..., description="Unique name for the tool")
group: str = Field(..., description="Group name for the tool")
use: str = Field(
...,
description="Variable name of the tool provider(e.g. deerflow.sandbox.tools:bash_tool)",
)
model_config = ConfigDict(extra="allow")

View File

@@ -0,0 +1,35 @@
"""Configuration for deferred tool loading via tool_search."""
from pydantic import BaseModel, Field
class ToolSearchConfig(BaseModel):
"""Configuration for deferred tool loading via tool_search.
When enabled, MCP tools are not loaded into the agent's context directly.
Instead, they are listed by name in the system prompt and discoverable
via the tool_search tool at runtime.
"""
enabled: bool = Field(
default=False,
description="Defer tools and enable tool_search",
)
_tool_search_config: ToolSearchConfig | None = None
def get_tool_search_config() -> ToolSearchConfig:
"""Get the tool search config, loading from AppConfig if needed."""
global _tool_search_config
if _tool_search_config is None:
_tool_search_config = ToolSearchConfig()
return _tool_search_config
def load_tool_search_config_from_dict(data: dict) -> ToolSearchConfig:
"""Load tool search config from a dict (called during AppConfig loading)."""
global _tool_search_config
_tool_search_config = ToolSearchConfig.model_validate(data)
return _tool_search_config

View File

@@ -0,0 +1,149 @@
import os
import threading
from pydantic import BaseModel, Field
_config_lock = threading.Lock()
class LangSmithTracingConfig(BaseModel):
"""Configuration for LangSmith tracing."""
enabled: bool = Field(...)
api_key: str | None = Field(...)
project: str = Field(...)
endpoint: str = Field(...)
@property
def is_configured(self) -> bool:
return self.enabled and bool(self.api_key)
def validate(self) -> None:
if self.enabled and not self.api_key:
raise ValueError("LangSmith tracing is enabled but LANGSMITH_API_KEY (or LANGCHAIN_API_KEY) is not set.")
class LangfuseTracingConfig(BaseModel):
"""Configuration for Langfuse tracing."""
enabled: bool = Field(...)
public_key: str | None = Field(...)
secret_key: str | None = Field(...)
host: str = Field(...)
@property
def is_configured(self) -> bool:
return self.enabled and bool(self.public_key) and bool(self.secret_key)
def validate(self) -> None:
if not self.enabled:
return
missing: list[str] = []
if not self.public_key:
missing.append("LANGFUSE_PUBLIC_KEY")
if not self.secret_key:
missing.append("LANGFUSE_SECRET_KEY")
if missing:
raise ValueError(f"Langfuse tracing is enabled but required settings are missing: {', '.join(missing)}")
class TracingConfig(BaseModel):
"""Tracing configuration for supported providers."""
langsmith: LangSmithTracingConfig = Field(...)
langfuse: LangfuseTracingConfig = Field(...)
@property
def is_configured(self) -> bool:
return bool(self.enabled_providers)
@property
def explicitly_enabled_providers(self) -> list[str]:
enabled: list[str] = []
if self.langsmith.enabled:
enabled.append("langsmith")
if self.langfuse.enabled:
enabled.append("langfuse")
return enabled
@property
def enabled_providers(self) -> list[str]:
enabled: list[str] = []
if self.langsmith.is_configured:
enabled.append("langsmith")
if self.langfuse.is_configured:
enabled.append("langfuse")
return enabled
def validate_enabled(self) -> None:
self.langsmith.validate()
self.langfuse.validate()
_tracing_config: TracingConfig | None = None
_TRUTHY_VALUES = {"1", "true", "yes", "on"}
def _env_flag_preferred(*names: str) -> bool:
"""Return the boolean value of the first env var that is present and non-empty."""
for name in names:
value = os.environ.get(name)
if value is not None and value.strip():
return value.strip().lower() in _TRUTHY_VALUES
return False
def _first_env_value(*names: str) -> str | None:
"""Return the first non-empty environment value from candidate names."""
for name in names:
value = os.environ.get(name)
if value and value.strip():
return value.strip()
return None
def get_tracing_config() -> TracingConfig:
"""Get the current tracing configuration from environment variables."""
global _tracing_config
if _tracing_config is not None:
return _tracing_config
with _config_lock:
if _tracing_config is not None:
return _tracing_config
_tracing_config = TracingConfig(
langsmith=LangSmithTracingConfig(
enabled=_env_flag_preferred("LANGSMITH_TRACING", "LANGCHAIN_TRACING_V2", "LANGCHAIN_TRACING"),
api_key=_first_env_value("LANGSMITH_API_KEY", "LANGCHAIN_API_KEY"),
project=_first_env_value("LANGSMITH_PROJECT", "LANGCHAIN_PROJECT") or "deer-flow",
endpoint=_first_env_value("LANGSMITH_ENDPOINT", "LANGCHAIN_ENDPOINT") or "https://api.smith.langchain.com",
),
langfuse=LangfuseTracingConfig(
enabled=_env_flag_preferred("LANGFUSE_TRACING"),
public_key=_first_env_value("LANGFUSE_PUBLIC_KEY"),
secret_key=_first_env_value("LANGFUSE_SECRET_KEY"),
host=_first_env_value("LANGFUSE_BASE_URL") or "https://cloud.langfuse.com",
),
)
return _tracing_config
def get_enabled_tracing_providers() -> list[str]:
"""Return the configured tracing providers that are enabled and complete."""
return get_tracing_config().enabled_providers
def get_explicitly_enabled_tracing_providers() -> list[str]:
"""Return tracing providers explicitly enabled by config, even if incomplete."""
return get_tracing_config().explicitly_enabled_providers
def validate_enabled_tracing_providers() -> None:
"""Validate that any explicitly enabled providers are fully configured."""
get_tracing_config().validate_enabled()
def is_tracing_enabled() -> bool:
"""Check if any tracing provider is enabled and fully configured."""
return get_tracing_config().is_configured

View File

@@ -0,0 +1,14 @@
"""Pre-tool-call authorization middleware."""
from deerflow.guardrails.builtin import AllowlistProvider
from deerflow.guardrails.middleware import GuardrailMiddleware
from deerflow.guardrails.provider import GuardrailDecision, GuardrailProvider, GuardrailReason, GuardrailRequest
__all__ = [
"AllowlistProvider",
"GuardrailDecision",
"GuardrailMiddleware",
"GuardrailProvider",
"GuardrailReason",
"GuardrailRequest",
]

View File

@@ -0,0 +1,23 @@
"""Built-in guardrail providers that ship with DeerFlow."""
from deerflow.guardrails.provider import GuardrailDecision, GuardrailReason, GuardrailRequest
class AllowlistProvider:
"""Simple allowlist/denylist provider. No external dependencies."""
name = "allowlist"
def __init__(self, *, allowed_tools: list[str] | None = None, denied_tools: list[str] | None = None):
self._allowed = set(allowed_tools) if allowed_tools else None
self._denied = set(denied_tools) if denied_tools else set()
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
if self._allowed is not None and request.tool_name not in self._allowed:
return GuardrailDecision(allow=False, reasons=[GuardrailReason(code="oap.tool_not_allowed", message=f"tool '{request.tool_name}' not in allowlist")])
if request.tool_name in self._denied:
return GuardrailDecision(allow=False, reasons=[GuardrailReason(code="oap.tool_not_allowed", message=f"tool '{request.tool_name}' is denied")])
return GuardrailDecision(allow=True, reasons=[GuardrailReason(code="oap.allowed")])
async def aevaluate(self, request: GuardrailRequest) -> GuardrailDecision:
return self.evaluate(request)

View File

@@ -0,0 +1,98 @@
"""GuardrailMiddleware - evaluates tool calls against a GuardrailProvider before execution."""
import logging
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.errors import GraphBubbleUp
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
from deerflow.guardrails.provider import GuardrailDecision, GuardrailProvider, GuardrailReason, GuardrailRequest
logger = logging.getLogger(__name__)
class GuardrailMiddleware(AgentMiddleware[AgentState]):
"""Evaluate tool calls against a GuardrailProvider before execution.
Denied calls return an error ToolMessage so the agent can adapt.
If the provider raises, behavior depends on fail_closed:
- True (default): block the call
- False: allow it through with a warning
"""
def __init__(self, provider: GuardrailProvider, *, fail_closed: bool = True, passport: str | None = None):
self.provider = provider
self.fail_closed = fail_closed
self.passport = passport
def _build_request(self, request: ToolCallRequest) -> GuardrailRequest:
return GuardrailRequest(
tool_name=str(request.tool_call.get("name", "")),
tool_input=request.tool_call.get("args", {}),
agent_id=self.passport,
timestamp=datetime.now(UTC).isoformat(),
)
def _build_denied_message(self, request: ToolCallRequest, decision: GuardrailDecision) -> ToolMessage:
tool_name = str(request.tool_call.get("name", "unknown_tool"))
tool_call_id = str(request.tool_call.get("id", "missing_id"))
reason_text = decision.reasons[0].message if decision.reasons else "blocked by guardrail policy"
reason_code = decision.reasons[0].code if decision.reasons else "oap.denied"
return ToolMessage(
content=f"Guardrail denied: tool '{tool_name}' was blocked ({reason_code}). Reason: {reason_text}. Choose an alternative approach.",
tool_call_id=tool_call_id,
name=tool_name,
status="error",
)
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
gr = self._build_request(request)
try:
decision = self.provider.evaluate(gr)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception:
logger.exception("Guardrail provider error (sync)")
if self.fail_closed:
decision = GuardrailDecision(allow=False, reasons=[GuardrailReason(code="oap.evaluator_error", message="guardrail provider error (fail-closed)")])
else:
return handler(request)
if not decision.allow:
logger.warning("Guardrail denied: tool=%s policy=%s code=%s", gr.tool_name, decision.policy_id, decision.reasons[0].code if decision.reasons else "unknown")
return self._build_denied_message(request, decision)
return handler(request)
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
gr = self._build_request(request)
try:
decision = await self.provider.aevaluate(gr)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception:
logger.exception("Guardrail provider error (async)")
if self.fail_closed:
decision = GuardrailDecision(allow=False, reasons=[GuardrailReason(code="oap.evaluator_error", message="guardrail provider error (fail-closed)")])
else:
return await handler(request)
if not decision.allow:
logger.warning("Guardrail denied: tool=%s policy=%s code=%s", gr.tool_name, decision.policy_id, decision.reasons[0].code if decision.reasons else "unknown")
return self._build_denied_message(request, decision)
return await handler(request)

View File

@@ -0,0 +1,56 @@
"""GuardrailProvider protocol and data structures for pre-tool-call authorization."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Protocol, runtime_checkable
@dataclass
class GuardrailRequest:
"""Context passed to the provider for each tool call."""
tool_name: str
tool_input: dict[str, Any]
agent_id: str | None = None
thread_id: str | None = None
is_subagent: bool = False
timestamp: str = ""
@dataclass
class GuardrailReason:
"""Structured reason for an allow/deny decision (OAP reason object)."""
code: str
message: str = ""
@dataclass
class GuardrailDecision:
"""Provider's allow/deny verdict (aligned with OAP Decision object)."""
allow: bool
reasons: list[GuardrailReason] = field(default_factory=list)
policy_id: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
@runtime_checkable
class GuardrailProvider(Protocol):
"""Contract for pluggable tool-call authorization.
Any class with these methods works - no base class required.
Providers are loaded by class path via resolve_variable(),
the same mechanism DeerFlow uses for models, tools, and sandbox.
"""
name: str
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
"""Evaluate whether a tool call should proceed."""
...
async def aevaluate(self, request: GuardrailRequest) -> GuardrailDecision:
"""Async variant."""
...

View File

@@ -0,0 +1,14 @@
"""MCP (Model Context Protocol) integration using langchain-mcp-adapters."""
from .cache import get_cached_mcp_tools, initialize_mcp_tools, reset_mcp_tools_cache
from .client import build_server_params, build_servers_config
from .tools import get_mcp_tools
__all__ = [
"build_server_params",
"build_servers_config",
"get_mcp_tools",
"initialize_mcp_tools",
"get_cached_mcp_tools",
"reset_mcp_tools_cache",
]

View File

@@ -0,0 +1,138 @@
"""Cache for MCP tools to avoid repeated loading."""
import asyncio
import logging
import os
from langchain_core.tools import BaseTool
logger = logging.getLogger(__name__)
_mcp_tools_cache: list[BaseTool] | None = None
_cache_initialized = False
_initialization_lock = asyncio.Lock()
_config_mtime: float | None = None # Track config file modification time
def _get_config_mtime() -> float | None:
"""Get the modification time of the extensions config file.
Returns:
The modification time as a float, or None if the file doesn't exist.
"""
from deerflow.config.extensions_config import ExtensionsConfig
config_path = ExtensionsConfig.resolve_config_path()
if config_path and config_path.exists():
return os.path.getmtime(config_path)
return None
def _is_cache_stale() -> bool:
"""Check if the cache is stale due to config file changes.
Returns:
True if the cache should be invalidated, False otherwise.
"""
global _config_mtime
if not _cache_initialized:
return False # Not initialized yet, not stale
current_mtime = _get_config_mtime()
# If we couldn't get mtime before or now, assume not stale
if _config_mtime is None or current_mtime is None:
return False
# If the config file has been modified since we cached, it's stale
if current_mtime > _config_mtime:
logger.info(f"MCP config file has been modified (mtime: {_config_mtime} -> {current_mtime}), cache is stale")
return True
return False
async def initialize_mcp_tools() -> list[BaseTool]:
"""Initialize and cache MCP tools.
This should be called once at application startup.
Returns:
List of LangChain tools from all enabled MCP servers.
"""
global _mcp_tools_cache, _cache_initialized, _config_mtime
async with _initialization_lock:
if _cache_initialized:
logger.info("MCP tools already initialized")
return _mcp_tools_cache or []
from deerflow.mcp.tools import get_mcp_tools
logger.info("Initializing MCP tools...")
_mcp_tools_cache = await get_mcp_tools()
_cache_initialized = True
_config_mtime = _get_config_mtime() # Record config file mtime
logger.info(f"MCP tools initialized: {len(_mcp_tools_cache)} tool(s) loaded (config mtime: {_config_mtime})")
return _mcp_tools_cache
def get_cached_mcp_tools() -> list[BaseTool]:
"""Get cached MCP tools with lazy initialization.
If tools are not initialized, automatically initializes them.
This ensures MCP tools work in both FastAPI and LangGraph Studio contexts.
Also checks if the config file has been modified since last initialization,
and re-initializes if needed. This ensures that changes made through the
Gateway API (which runs in a separate process) are reflected in the
LangGraph Server.
Returns:
List of cached MCP tools.
"""
global _cache_initialized
# Check if cache is stale due to config file changes
if _is_cache_stale():
logger.info("MCP cache is stale, resetting for re-initialization...")
reset_mcp_tools_cache()
if not _cache_initialized:
logger.info("MCP tools not initialized, performing lazy initialization...")
try:
# Try to initialize in the current event loop
loop = asyncio.get_event_loop()
if loop.is_running():
# If loop is already running (e.g., in LangGraph Studio),
# we need to create a new loop in a thread
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, initialize_mcp_tools())
future.result()
else:
# If no loop is running, we can use the current loop
loop.run_until_complete(initialize_mcp_tools())
except RuntimeError:
# No event loop exists, create one
asyncio.run(initialize_mcp_tools())
except Exception as e:
logger.error(f"Failed to lazy-initialize MCP tools: {e}")
return []
return _mcp_tools_cache or []
def reset_mcp_tools_cache() -> None:
"""Reset the MCP tools cache.
This is useful for testing or when you want to reload MCP tools.
"""
global _mcp_tools_cache, _cache_initialized, _config_mtime
_mcp_tools_cache = None
_cache_initialized = False
_config_mtime = None
logger.info("MCP tools cache reset")

View File

@@ -0,0 +1,68 @@
"""MCP client using langchain-mcp-adapters."""
import logging
from typing import Any
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
logger = logging.getLogger(__name__)
def build_server_params(server_name: str, config: McpServerConfig) -> dict[str, Any]:
"""Build server parameters for MultiServerMCPClient.
Args:
server_name: Name of the MCP server.
config: Configuration for the MCP server.
Returns:
Dictionary of server parameters for langchain-mcp-adapters.
"""
transport_type = config.type or "stdio"
params: dict[str, Any] = {"transport": transport_type}
if transport_type == "stdio":
if not config.command:
raise ValueError(f"MCP server '{server_name}' with stdio transport requires 'command' field")
params["command"] = config.command
params["args"] = config.args
# Add environment variables if present
if config.env:
params["env"] = config.env
elif transport_type in ("sse", "http"):
if not config.url:
raise ValueError(f"MCP server '{server_name}' with {transport_type} transport requires 'url' field")
params["url"] = config.url
# Add headers if present
if config.headers:
params["headers"] = config.headers
else:
raise ValueError(f"MCP server '{server_name}' has unsupported transport type: {transport_type}")
return params
def build_servers_config(extensions_config: ExtensionsConfig) -> dict[str, dict[str, Any]]:
"""Build servers configuration for MultiServerMCPClient.
Args:
extensions_config: Extensions configuration containing all MCP servers.
Returns:
Dictionary mapping server names to their parameters.
"""
enabled_servers = extensions_config.get_enabled_mcp_servers()
if not enabled_servers:
logger.info("No enabled MCP servers found")
return {}
servers_config = {}
for server_name, server_config in enabled_servers.items():
try:
servers_config[server_name] = build_server_params(server_name, server_config)
logger.info(f"Configured MCP server: {server_name}")
except Exception as e:
logger.error(f"Failed to configure MCP server '{server_name}': {e}")
return servers_config

View File

@@ -0,0 +1,150 @@
"""OAuth token support for MCP HTTP/SSE servers."""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
from deerflow.config.extensions_config import ExtensionsConfig, McpOAuthConfig
logger = logging.getLogger(__name__)
@dataclass
class _OAuthToken:
"""Cached OAuth token."""
access_token: str
token_type: str
expires_at: datetime
class OAuthTokenManager:
"""Acquire/cache/refresh OAuth tokens for MCP servers."""
def __init__(self, oauth_by_server: dict[str, McpOAuthConfig]):
self._oauth_by_server = oauth_by_server
self._tokens: dict[str, _OAuthToken] = {}
self._locks: dict[str, asyncio.Lock] = {name: asyncio.Lock() for name in oauth_by_server}
@classmethod
def from_extensions_config(cls, extensions_config: ExtensionsConfig) -> OAuthTokenManager:
oauth_by_server: dict[str, McpOAuthConfig] = {}
for server_name, server_config in extensions_config.get_enabled_mcp_servers().items():
if server_config.oauth and server_config.oauth.enabled:
oauth_by_server[server_name] = server_config.oauth
return cls(oauth_by_server)
def has_oauth_servers(self) -> bool:
return bool(self._oauth_by_server)
def oauth_server_names(self) -> list[str]:
return list(self._oauth_by_server.keys())
async def get_authorization_header(self, server_name: str) -> str | None:
oauth = self._oauth_by_server.get(server_name)
if not oauth:
return None
token = self._tokens.get(server_name)
if token and not self._is_expiring(token, oauth):
return f"{token.token_type} {token.access_token}"
lock = self._locks[server_name]
async with lock:
token = self._tokens.get(server_name)
if token and not self._is_expiring(token, oauth):
return f"{token.token_type} {token.access_token}"
fresh = await self._fetch_token(oauth)
self._tokens[server_name] = fresh
logger.info(f"Refreshed OAuth access token for MCP server: {server_name}")
return f"{fresh.token_type} {fresh.access_token}"
@staticmethod
def _is_expiring(token: _OAuthToken, oauth: McpOAuthConfig) -> bool:
now = datetime.now(UTC)
return token.expires_at <= now + timedelta(seconds=max(oauth.refresh_skew_seconds, 0))
async def _fetch_token(self, oauth: McpOAuthConfig) -> _OAuthToken:
import httpx # pyright: ignore[reportMissingImports]
data: dict[str, str] = {
"grant_type": oauth.grant_type,
**oauth.extra_token_params,
}
if oauth.scope:
data["scope"] = oauth.scope
if oauth.audience:
data["audience"] = oauth.audience
if oauth.grant_type == "client_credentials":
if not oauth.client_id or not oauth.client_secret:
raise ValueError("OAuth client_credentials requires client_id and client_secret")
data["client_id"] = oauth.client_id
data["client_secret"] = oauth.client_secret
elif oauth.grant_type == "refresh_token":
if not oauth.refresh_token:
raise ValueError("OAuth refresh_token grant requires refresh_token")
data["refresh_token"] = oauth.refresh_token
if oauth.client_id:
data["client_id"] = oauth.client_id
if oauth.client_secret:
data["client_secret"] = oauth.client_secret
else:
raise ValueError(f"Unsupported OAuth grant type: {oauth.grant_type}")
async with httpx.AsyncClient(timeout=15.0) as client:
response = await client.post(oauth.token_url, data=data)
response.raise_for_status()
payload = response.json()
access_token = payload.get(oauth.token_field)
if not access_token:
raise ValueError(f"OAuth token response missing '{oauth.token_field}'")
token_type = str(payload.get(oauth.token_type_field, oauth.default_token_type) or oauth.default_token_type)
expires_in_raw = payload.get(oauth.expires_in_field, 3600)
try:
expires_in = int(expires_in_raw)
except (TypeError, ValueError):
expires_in = 3600
expires_at = datetime.now(UTC) + timedelta(seconds=max(expires_in, 1))
return _OAuthToken(access_token=access_token, token_type=token_type, expires_at=expires_at)
def build_oauth_tool_interceptor(extensions_config: ExtensionsConfig) -> Any | None:
"""Build a tool interceptor that injects OAuth Authorization headers."""
token_manager = OAuthTokenManager.from_extensions_config(extensions_config)
if not token_manager.has_oauth_servers():
return None
async def oauth_interceptor(request: Any, handler: Any) -> Any:
header = await token_manager.get_authorization_header(request.server_name)
if not header:
return await handler(request)
updated_headers = dict(request.headers or {})
updated_headers["Authorization"] = header
return await handler(request.override(headers=updated_headers))
return oauth_interceptor
async def get_initial_oauth_headers(extensions_config: ExtensionsConfig) -> dict[str, str]:
"""Get initial OAuth Authorization headers for MCP server connections."""
token_manager = OAuthTokenManager.from_extensions_config(extensions_config)
if not token_manager.has_oauth_servers():
return {}
headers: dict[str, str] = {}
for server_name in token_manager.oauth_server_names():
headers[server_name] = await token_manager.get_authorization_header(server_name) or ""
return {name: value for name, value in headers.items() if value}

View File

@@ -0,0 +1,113 @@
"""Load MCP tools using langchain-mcp-adapters."""
import asyncio
import atexit
import concurrent.futures
import logging
from collections.abc import Callable
from typing import Any
from langchain_core.tools import BaseTool
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.mcp.client import build_servers_config
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
logger = logging.getLogger(__name__)
# Global thread pool for sync tool invocation in async environments
_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="mcp-sync-tool")
# Register shutdown hook for the global executor
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
def _make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
"""Build a synchronous wrapper for an asynchronous tool coroutine.
Args:
coro: The tool's asynchronous coroutine.
tool_name: Name of the tool (for logging).
Returns:
A synchronous function that correctly handles nested event loops.
"""
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
try:
if loop is not None and loop.is_running():
# Use global executor to avoid nested loop issues and improve performance
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
return future.result()
else:
return asyncio.run(coro(*args, **kwargs))
except Exception as e:
logger.error(f"Error invoking MCP tool '{tool_name}' via sync wrapper: {e}", exc_info=True)
raise
return sync_wrapper
async def get_mcp_tools() -> list[BaseTool]:
"""Get all tools from enabled MCP servers.
Returns:
List of LangChain tools from all enabled MCP servers.
"""
try:
from langchain_mcp_adapters.client import MultiServerMCPClient
except ImportError:
logger.warning("langchain-mcp-adapters not installed. Install it to enable MCP tools: pip install langchain-mcp-adapters")
return []
# NOTE: We use ExtensionsConfig.from_file() instead of get_extensions_config()
# to always read the latest configuration from disk. This ensures that changes
# made through the Gateway API (which runs in a separate process) are immediately
# reflected when initializing MCP tools.
extensions_config = ExtensionsConfig.from_file()
servers_config = build_servers_config(extensions_config)
if not servers_config:
logger.info("No enabled MCP servers configured")
return []
try:
# Create the multi-server MCP client
logger.info(f"Initializing MCP client with {len(servers_config)} server(s)")
# Inject initial OAuth headers for server connections (tool discovery/session init)
initial_oauth_headers = await get_initial_oauth_headers(extensions_config)
for server_name, auth_header in initial_oauth_headers.items():
if server_name not in servers_config:
continue
if servers_config[server_name].get("transport") in ("sse", "http"):
existing_headers = dict(servers_config[server_name].get("headers", {}))
existing_headers["Authorization"] = auth_header
servers_config[server_name]["headers"] = existing_headers
tool_interceptors = []
oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
if oauth_interceptor is not None:
tool_interceptors.append(oauth_interceptor)
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True)
# Get all tools from all servers
tools = await client.get_tools()
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
# Patch tools to support sync invocation, as deerflow client streams synchronously
for tool in tools:
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
tool.func = _make_sync_tool_wrapper(tool.coroutine, tool.name)
return tools
except Exception as e:
logger.error(f"Failed to load MCP tools: {e}", exc_info=True)
return []

View File

@@ -0,0 +1,3 @@
from .factory import create_chat_model
__all__ = ["create_chat_model"]

View File

@@ -0,0 +1,348 @@
"""Custom Claude provider with OAuth Bearer auth, prompt caching, and smart thinking.
Supports two authentication modes:
1. Standard API key (x-api-key header) — default ChatAnthropic behavior
2. Claude Code OAuth token (Authorization: Bearer header)
- Detected by sk-ant-oat prefix
- Requires anthropic-beta: oauth-2025-04-20,claude-code-20250219
- Requires billing header in system prompt for all OAuth requests
Auto-loads credentials from explicit runtime handoff:
- $ANTHROPIC_API_KEY environment variable
- $CLAUDE_CODE_OAUTH_TOKEN or $ANTHROPIC_AUTH_TOKEN
- $CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR
- $CLAUDE_CODE_CREDENTIALS_PATH
- ~/.claude/.credentials.json
"""
import hashlib
import json
import logging
import os
import socket
import time
import uuid
from typing import Any
import anthropic
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage
from pydantic import PrivateAttr
logger = logging.getLogger(__name__)
MAX_RETRIES = 3
THINKING_BUDGET_RATIO = 0.8
# Billing header required by Anthropic API for OAuth token access.
# Must be the first system prompt block. Format mirrors Claude Code CLI.
# Override with ANTHROPIC_BILLING_HEADER env var if the hardcoded version drifts.
_DEFAULT_BILLING_HEADER = "x-anthropic-billing-header: cc_version=2.1.85.351; cc_entrypoint=cli; cch=6c6d5;"
OAUTH_BILLING_HEADER = os.environ.get("ANTHROPIC_BILLING_HEADER", _DEFAULT_BILLING_HEADER)
class ClaudeChatModel(ChatAnthropic):
"""ChatAnthropic with OAuth Bearer auth, prompt caching, and smart thinking.
Config example:
- name: claude-sonnet-4.6
use: deerflow.models.claude_provider:ClaudeChatModel
model: claude-sonnet-4-6
max_tokens: 16384
enable_prompt_caching: true
"""
# Custom fields
enable_prompt_caching: bool = True
prompt_cache_size: int = 3
auto_thinking_budget: bool = True
retry_max_attempts: int = MAX_RETRIES
_is_oauth: bool = PrivateAttr(default=False)
_oauth_access_token: str = PrivateAttr(default="")
model_config = {"arbitrary_types_allowed": True}
def _validate_retry_config(self) -> None:
if self.retry_max_attempts < 1:
raise ValueError("retry_max_attempts must be >= 1")
def model_post_init(self, __context: Any) -> None:
"""Auto-load credentials and configure OAuth if needed."""
from pydantic import SecretStr
from deerflow.models.credential_loader import (
OAUTH_ANTHROPIC_BETAS,
is_oauth_token,
load_claude_code_credential,
)
self._validate_retry_config()
# Extract actual key value (SecretStr.str() returns '**********')
current_key = ""
if self.anthropic_api_key:
if hasattr(self.anthropic_api_key, "get_secret_value"):
current_key = self.anthropic_api_key.get_secret_value()
else:
current_key = str(self.anthropic_api_key)
# Try the explicit Claude Code OAuth handoff sources if no valid key.
if not current_key or current_key in ("your-anthropic-api-key",):
cred = load_claude_code_credential()
if cred:
current_key = cred.access_token
logger.info(f"Using Claude Code CLI credential (source: {cred.source})")
else:
logger.warning("No Anthropic API key or explicit Claude Code OAuth credential found.")
# Detect OAuth token and configure Bearer auth
if is_oauth_token(current_key):
self._is_oauth = True
self._oauth_access_token = current_key
# Set the token as api_key temporarily (will be swapped to auth_token on client)
self.anthropic_api_key = SecretStr(current_key)
# Add required beta headers for OAuth
self.default_headers = {
**(self.default_headers or {}),
"anthropic-beta": OAUTH_ANTHROPIC_BETAS,
}
# OAuth tokens have a limit of 4 cache_control blocks — disable prompt caching
self.enable_prompt_caching = False
logger.info("OAuth token detected — will use Authorization: Bearer header")
else:
if current_key:
self.anthropic_api_key = SecretStr(current_key)
# Ensure api_key is SecretStr
if isinstance(self.anthropic_api_key, str):
self.anthropic_api_key = SecretStr(self.anthropic_api_key)
super().model_post_init(__context)
# Patch clients immediately after creation for OAuth Bearer auth.
# This must happen after super() because clients are lazily created.
if self._is_oauth:
self._patch_client_oauth(self._client)
self._patch_client_oauth(self._async_client)
def _patch_client_oauth(self, client: Any) -> None:
"""Swap api_key → auth_token on an Anthropic SDK client for OAuth Bearer auth."""
if hasattr(client, "api_key") and hasattr(client, "auth_token"):
client.api_key = None
client.auth_token = self._oauth_access_token
def _get_request_payload(
self,
input_: Any,
*,
stop: list[str] | None = None,
**kwargs: Any,
) -> dict:
"""Override to inject prompt caching, thinking budget, and OAuth billing."""
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
if self._is_oauth:
self._apply_oauth_billing(payload)
if self.enable_prompt_caching:
self._apply_prompt_caching(payload)
if self.auto_thinking_budget:
self._apply_thinking_budget(payload)
return payload
def _apply_oauth_billing(self, payload: dict) -> None:
"""Inject the billing header block required for all OAuth requests.
The billing block is always placed first in the system list, removing any
existing occurrence to avoid duplication or out-of-order positioning.
"""
billing_block = {"type": "text", "text": OAUTH_BILLING_HEADER}
system = payload.get("system")
if isinstance(system, list):
# Remove any existing billing blocks, then insert a single one at index 0.
filtered = [b for b in system if not (isinstance(b, dict) and OAUTH_BILLING_HEADER in b.get("text", ""))]
payload["system"] = [billing_block] + filtered
elif isinstance(system, str):
if OAUTH_BILLING_HEADER in system:
payload["system"] = [billing_block]
else:
payload["system"] = [billing_block, {"type": "text", "text": system}]
else:
payload["system"] = [billing_block]
# Add metadata.user_id required by the API for OAuth billing validation
if not isinstance(payload.get("metadata"), dict):
payload["metadata"] = {}
if "user_id" not in payload["metadata"]:
# Generate a stable device_id from the machine's hostname
hostname = socket.gethostname()
device_id = hashlib.sha256(f"deerflow-{hostname}".encode()).hexdigest()
session_id = str(uuid.uuid4())
payload["metadata"]["user_id"] = json.dumps(
{
"device_id": device_id,
"account_uuid": "deerflow",
"session_id": session_id,
}
)
def _apply_prompt_caching(self, payload: dict) -> None:
"""Apply ephemeral cache_control to system and recent messages."""
# Cache system messages
system = payload.get("system")
if system and isinstance(system, list):
for block in system:
if isinstance(block, dict) and block.get("type") == "text":
block["cache_control"] = {"type": "ephemeral"}
elif system and isinstance(system, str):
payload["system"] = [
{
"type": "text",
"text": system,
"cache_control": {"type": "ephemeral"},
}
]
# Cache recent messages
messages = payload.get("messages", [])
cache_start = max(0, len(messages) - self.prompt_cache_size)
for i in range(cache_start, len(messages)):
msg = messages[i]
if not isinstance(msg, dict):
continue
content = msg.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict):
block["cache_control"] = {"type": "ephemeral"}
elif isinstance(content, str) and content:
msg["content"] = [
{
"type": "text",
"text": content,
"cache_control": {"type": "ephemeral"},
}
]
# Cache the last tool definition
tools = payload.get("tools", [])
if tools and isinstance(tools[-1], dict):
tools[-1]["cache_control"] = {"type": "ephemeral"}
def _apply_thinking_budget(self, payload: dict) -> None:
"""Auto-allocate thinking budget (80% of max_tokens)."""
thinking = payload.get("thinking")
if not thinking or not isinstance(thinking, dict):
return
if thinking.get("type") != "enabled":
return
if thinking.get("budget_tokens"):
return
max_tokens = payload.get("max_tokens", 8192)
thinking["budget_tokens"] = int(max_tokens * THINKING_BUDGET_RATIO)
@staticmethod
def _strip_cache_control(payload: dict) -> None:
"""Remove cache_control markers before OAuth requests reach Anthropic."""
for section in ("system", "messages"):
items = payload.get(section)
if not isinstance(items, list):
continue
for item in items:
if not isinstance(item, dict):
continue
item.pop("cache_control", None)
content = item.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict):
block.pop("cache_control", None)
tools = payload.get("tools")
if isinstance(tools, list):
for tool in tools:
if isinstance(tool, dict):
tool.pop("cache_control", None)
def _create(self, payload: dict) -> Any:
if self._is_oauth:
self._strip_cache_control(payload)
return super()._create(payload)
async def _acreate(self, payload: dict) -> Any:
if self._is_oauth:
self._strip_cache_control(payload)
return await super()._acreate(payload)
def _generate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
"""Override with OAuth patching and retry logic."""
if self._is_oauth:
self._patch_client_oauth(self._client)
last_error = None
for attempt in range(1, self.retry_max_attempts + 1):
try:
return super()._generate(messages, stop=stop, **kwargs)
except anthropic.RateLimitError as e:
last_error = e
if attempt >= self.retry_max_attempts:
raise
wait_ms = self._calc_backoff_ms(attempt, e)
logger.warning(f"Rate limited, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
time.sleep(wait_ms / 1000)
except anthropic.InternalServerError as e:
last_error = e
if attempt >= self.retry_max_attempts:
raise
wait_ms = self._calc_backoff_ms(attempt, e)
logger.warning(f"Server error, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
time.sleep(wait_ms / 1000)
raise last_error
async def _agenerate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
"""Async override with OAuth patching and retry logic."""
import asyncio
if self._is_oauth:
self._patch_client_oauth(self._async_client)
last_error = None
for attempt in range(1, self.retry_max_attempts + 1):
try:
return await super()._agenerate(messages, stop=stop, **kwargs)
except anthropic.RateLimitError as e:
last_error = e
if attempt >= self.retry_max_attempts:
raise
wait_ms = self._calc_backoff_ms(attempt, e)
logger.warning(f"Rate limited, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
await asyncio.sleep(wait_ms / 1000)
except anthropic.InternalServerError as e:
last_error = e
if attempt >= self.retry_max_attempts:
raise
wait_ms = self._calc_backoff_ms(attempt, e)
logger.warning(f"Server error, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
await asyncio.sleep(wait_ms / 1000)
raise last_error
@staticmethod
def _calc_backoff_ms(attempt: int, error: Exception) -> int:
"""Exponential backoff with a fixed 20% buffer."""
backoff_ms = 2000 * (1 << (attempt - 1))
jitter_ms = int(backoff_ms * 0.2)
total_ms = backoff_ms + jitter_ms
if hasattr(error, "response") and error.response is not None:
retry_after = error.response.headers.get("Retry-After")
if retry_after:
try:
total_ms = int(retry_after) * 1000
except (ValueError, TypeError):
pass
return total_ms

View File

@@ -0,0 +1,219 @@
"""Auto-load credentials from Claude Code CLI and Codex CLI.
Implements two credential strategies:
1. Claude Code OAuth token from explicit env vars or an exported credentials file
- Uses Authorization: Bearer header (NOT x-api-key)
- Requires anthropic-beta: oauth-2025-04-20,claude-code-20250219
- Supports $CLAUDE_CODE_OAUTH_TOKEN, $CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR, and $ANTHROPIC_AUTH_TOKEN
- Override path with $CLAUDE_CODE_CREDENTIALS_PATH
2. Codex CLI token from ~/.codex/auth.json
- Uses chatgpt.com/backend-api/codex/responses endpoint
- Supports both legacy top-level tokens and current nested tokens shape
- Override path with $CODEX_AUTH_PATH
"""
import json
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
# Required beta headers for Claude Code OAuth tokens
OAUTH_ANTHROPIC_BETAS = "oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14"
def is_oauth_token(token: str) -> bool:
"""Check if a token is a Claude Code OAuth token (not a standard API key)."""
return isinstance(token, str) and "sk-ant-oat" in token
@dataclass
class ClaudeCodeCredential:
"""Claude Code CLI OAuth credential."""
access_token: str
refresh_token: str = ""
expires_at: int = 0
source: str = ""
@property
def is_expired(self) -> bool:
if self.expires_at <= 0:
return False
return time.time() * 1000 > self.expires_at - 60_000 # 1 min buffer
@dataclass
class CodexCliCredential:
"""Codex CLI credential."""
access_token: str
account_id: str = ""
source: str = ""
def _resolve_credential_path(env_var: str, default_relative_path: str) -> Path:
configured_path = os.getenv(env_var)
if configured_path:
return Path(configured_path).expanduser()
return _home_dir() / default_relative_path
def _home_dir() -> Path:
home = os.getenv("HOME")
if home:
return Path(home).expanduser()
return Path.home()
def _load_json_file(path: Path, label: str) -> dict[str, Any] | None:
if not path.exists():
logger.debug(f"{label} not found: {path}")
return None
if path.is_dir():
logger.warning(f"{label} path is a directory, expected a file: {path}")
return None
try:
return json.loads(path.read_text())
except (json.JSONDecodeError, OSError) as e:
logger.warning(f"Failed to read {label}: {e}")
return None
def _read_secret_from_file_descriptor(env_var: str) -> str | None:
fd_value = os.getenv(env_var)
if not fd_value:
return None
try:
fd = int(fd_value)
except ValueError:
logger.warning(f"{env_var} must be an integer file descriptor, got: {fd_value}")
return None
try:
secret = os.read(fd, 1024 * 1024).decode().strip()
except OSError as e:
logger.warning(f"Failed to read {env_var}: {e}")
return None
return secret or None
def _credential_from_direct_token(access_token: str, source: str) -> ClaudeCodeCredential | None:
token = access_token.strip()
if not token:
return None
return ClaudeCodeCredential(access_token=token, source=source)
def _iter_claude_code_credential_paths() -> list[Path]:
paths: list[Path] = []
override_path = os.getenv("CLAUDE_CODE_CREDENTIALS_PATH")
if override_path:
paths.append(Path(override_path).expanduser())
default_path = _home_dir() / ".claude/.credentials.json"
if not paths or paths[-1] != default_path:
paths.append(default_path)
return paths
def _extract_claude_code_credential(data: dict[str, Any], source: str) -> ClaudeCodeCredential | None:
oauth = data.get("claudeAiOauth", {})
access_token = oauth.get("accessToken", "")
if not access_token:
logger.debug("Claude Code credentials container exists but no accessToken found")
return None
cred = ClaudeCodeCredential(
access_token=access_token,
refresh_token=oauth.get("refreshToken", ""),
expires_at=oauth.get("expiresAt", 0),
source=source,
)
if cred.is_expired:
logger.warning("Claude Code OAuth token is expired. Run 'claude' to refresh.")
return None
return cred
def load_claude_code_credential() -> ClaudeCodeCredential | None:
"""Load OAuth credential from explicit Claude Code handoff sources.
Lookup order:
1. $CLAUDE_CODE_OAUTH_TOKEN or $ANTHROPIC_AUTH_TOKEN
2. $CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR
3. $CLAUDE_CODE_CREDENTIALS_PATH
4. ~/.claude/.credentials.json
Exported credentials files contain:
{
"claudeAiOauth": {
"accessToken": "sk-ant-oat01-...",
"refreshToken": "sk-ant-ort01-...",
"expiresAt": 1773430695128,
"scopes": ["user:inference", ...],
...
}
}
"""
direct_token = os.getenv("CLAUDE_CODE_OAUTH_TOKEN") or os.getenv("ANTHROPIC_AUTH_TOKEN")
if direct_token:
cred = _credential_from_direct_token(direct_token, "claude-cli-env")
if cred:
logger.info("Loaded Claude Code OAuth credential from environment")
return cred
fd_token = _read_secret_from_file_descriptor("CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR")
if fd_token:
cred = _credential_from_direct_token(fd_token, "claude-cli-fd")
if cred:
logger.info("Loaded Claude Code OAuth credential from file descriptor")
return cred
override_path = os.getenv("CLAUDE_CODE_CREDENTIALS_PATH")
override_path_obj = Path(override_path).expanduser() if override_path else None
for cred_path in _iter_claude_code_credential_paths():
data = _load_json_file(cred_path, "Claude Code credentials")
if data is None:
continue
cred = _extract_claude_code_credential(data, "claude-cli-file")
if cred:
source_label = "override path" if override_path_obj is not None and cred_path == override_path_obj else "plaintext file"
logger.info(f"Loaded Claude Code OAuth credential from {source_label} (expires_at={cred.expires_at})")
return cred
return None
def load_codex_cli_credential() -> CodexCliCredential | None:
"""Load credential from Codex CLI (~/.codex/auth.json)."""
cred_path = _resolve_credential_path("CODEX_AUTH_PATH", ".codex/auth.json")
data = _load_json_file(cred_path, "Codex CLI credentials")
if data is None:
return None
tokens = data.get("tokens", {})
if not isinstance(tokens, dict):
tokens = {}
access_token = data.get("access_token") or data.get("token") or tokens.get("access_token", "")
account_id = data.get("account_id") or tokens.get("account_id", "")
if not access_token:
logger.debug("Codex CLI credentials file exists but no token found")
return None
logger.info("Loaded Codex CLI credential")
return CodexCliCredential(
access_token=access_token,
account_id=account_id,
source="codex-cli",
)

View File

@@ -0,0 +1,123 @@
import logging
from langchain.chat_models import BaseChatModel
from deerflow.config import get_app_config
from deerflow.reflection import resolve_class
from deerflow.tracing import build_tracing_callbacks
logger = logging.getLogger(__name__)
def _deep_merge_dicts(base: dict | None, override: dict) -> dict:
"""Recursively merge two dictionaries without mutating the inputs."""
merged = dict(base or {})
for key, value in override.items():
if isinstance(value, dict) and isinstance(merged.get(key), dict):
merged[key] = _deep_merge_dicts(merged[key], value)
else:
merged[key] = value
return merged
def _vllm_disable_chat_template_kwargs(chat_template_kwargs: dict) -> dict:
"""Build the disable payload for vLLM/Qwen chat template kwargs."""
disable_kwargs: dict[str, bool] = {}
if "thinking" in chat_template_kwargs:
disable_kwargs["thinking"] = False
if "enable_thinking" in chat_template_kwargs:
disable_kwargs["enable_thinking"] = False
return disable_kwargs
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel:
"""Create a chat model instance from the config.
Args:
name: The name of the model to create. If None, the first model in the config will be used.
Returns:
A chat model instance.
"""
config = get_app_config()
if name is None:
name = config.models[0].name
model_config = config.get_model_config(name)
if model_config is None:
raise ValueError(f"Model {name} not found in config") from None
model_class = resolve_class(model_config.use, BaseChatModel)
model_settings_from_config = model_config.model_dump(
exclude_none=True,
exclude={
"use",
"name",
"display_name",
"description",
"supports_thinking",
"supports_reasoning_effort",
"when_thinking_enabled",
"when_thinking_disabled",
"thinking",
"supports_vision",
},
)
# Compute effective when_thinking_enabled by merging in the `thinking` shortcut field.
# The `thinking` shortcut is equivalent to setting when_thinking_enabled["thinking"].
has_thinking_settings = (model_config.when_thinking_enabled is not None) or (model_config.thinking is not None)
effective_wte: dict = dict(model_config.when_thinking_enabled) if model_config.when_thinking_enabled else {}
if model_config.thinking is not None:
merged_thinking = {**(effective_wte.get("thinking") or {}), **model_config.thinking}
effective_wte = {**effective_wte, "thinking": merged_thinking}
if thinking_enabled and has_thinking_settings:
if not model_config.supports_thinking:
raise ValueError(f"Model {name} does not support thinking. Set `supports_thinking` to true in the `config.yaml` to enable thinking.") from None
if effective_wte:
model_settings_from_config.update(effective_wte)
if not thinking_enabled:
if model_config.when_thinking_disabled is not None:
# User-provided disable settings take full precedence
model_settings_from_config.update(model_config.when_thinking_disabled)
elif has_thinking_settings and effective_wte.get("extra_body", {}).get("thinking", {}).get("type"):
# OpenAI-compatible gateway: thinking is nested under extra_body
model_settings_from_config["extra_body"] = _deep_merge_dicts(
model_settings_from_config.get("extra_body"),
{"thinking": {"type": "disabled"}},
)
model_settings_from_config["reasoning_effort"] = "minimal"
elif has_thinking_settings and (disable_chat_template_kwargs := _vllm_disable_chat_template_kwargs(effective_wte.get("extra_body", {}).get("chat_template_kwargs") or {})):
# vLLM uses chat template kwargs to switch thinking on/off.
model_settings_from_config["extra_body"] = _deep_merge_dicts(
model_settings_from_config.get("extra_body"),
{"chat_template_kwargs": disable_chat_template_kwargs},
)
elif has_thinking_settings and effective_wte.get("thinking", {}).get("type"):
# Native langchain_anthropic: thinking is a direct constructor parameter
model_settings_from_config["thinking"] = {"type": "disabled"}
if not model_config.supports_reasoning_effort:
kwargs.pop("reasoning_effort", None)
model_settings_from_config.pop("reasoning_effort", None)
# For Codex Responses API models: map thinking mode to reasoning_effort
from deerflow.models.openai_codex_provider import CodexChatModel
if issubclass(model_class, CodexChatModel):
# The ChatGPT Codex endpoint currently rejects max_tokens/max_output_tokens.
model_settings_from_config.pop("max_tokens", None)
# Use explicit reasoning_effort from frontend if provided (low/medium/high)
explicit_effort = kwargs.pop("reasoning_effort", None)
if not thinking_enabled:
model_settings_from_config["reasoning_effort"] = "none"
elif explicit_effort and explicit_effort in ("low", "medium", "high", "xhigh"):
model_settings_from_config["reasoning_effort"] = explicit_effort
elif "reasoning_effort" not in model_settings_from_config:
model_settings_from_config["reasoning_effort"] = "medium"
model_instance = model_class(**{**model_settings_from_config, **kwargs})
callbacks = build_tracing_callbacks()
if callbacks:
existing_callbacks = model_instance.callbacks or []
model_instance.callbacks = [*existing_callbacks, *callbacks]
logger.debug(f"Tracing attached to model '{name}' with providers={len(callbacks)}")
return model_instance

View File

@@ -0,0 +1,430 @@
"""Custom OpenAI Codex provider using ChatGPT Codex Responses API.
Uses Codex CLI OAuth tokens with chatgpt.com/backend-api/codex/responses endpoint.
This is the same endpoint that the Codex CLI uses internally.
Supports:
- Auto-load credentials from ~/.codex/auth.json
- Responses API format (not Chat Completions)
- Tool calling
- Streaming (required by the endpoint)
- Retry with exponential backoff
"""
import json
import logging
import time
from typing import Any
import httpx
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from deerflow.models.credential_loader import CodexCliCredential, load_codex_cli_credential
logger = logging.getLogger(__name__)
CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
MAX_RETRIES = 3
class CodexChatModel(BaseChatModel):
"""LangChain chat model using ChatGPT Codex Responses API.
Config example:
- name: gpt-5.4
use: deerflow.models.openai_codex_provider:CodexChatModel
model: gpt-5.4
reasoning_effort: medium
"""
model: str = "gpt-5.4"
reasoning_effort: str = "medium"
retry_max_attempts: int = MAX_RETRIES
_access_token: str = ""
_account_id: str = ""
model_config = {"arbitrary_types_allowed": True}
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@property
def _llm_type(self) -> str:
return "codex-responses"
def _validate_retry_config(self) -> None:
if self.retry_max_attempts < 1:
raise ValueError("retry_max_attempts must be >= 1")
def model_post_init(self, __context: Any) -> None:
"""Auto-load Codex CLI credentials."""
self._validate_retry_config()
cred = self._load_codex_auth()
if cred:
self._access_token = cred.access_token
self._account_id = cred.account_id
logger.info(f"Using Codex CLI credential (account: {self._account_id[:8]}...)")
else:
raise ValueError("Codex CLI credential not found. Expected ~/.codex/auth.json or CODEX_AUTH_PATH.")
super().model_post_init(__context)
def _load_codex_auth(self) -> CodexCliCredential | None:
"""Load access_token and account_id from Codex CLI auth."""
return load_codex_cli_credential()
@classmethod
def _normalize_content(cls, content: Any) -> str:
"""Flatten LangChain content blocks into plain text for Codex."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [cls._normalize_content(item) for item in content]
return "\n".join(part for part in parts if part)
if isinstance(content, dict):
for key in ("text", "output"):
value = content.get(key)
if isinstance(value, str):
return value
nested_content = content.get("content")
if nested_content is not None:
return cls._normalize_content(nested_content)
try:
return json.dumps(content, ensure_ascii=False)
except TypeError:
return str(content)
try:
return json.dumps(content, ensure_ascii=False)
except TypeError:
return str(content)
def _convert_messages(self, messages: list[BaseMessage]) -> tuple[str, list[dict]]:
"""Convert LangChain messages to Responses API format.
Returns (instructions, input_items).
"""
instructions_parts: list[str] = []
input_items = []
for msg in messages:
if isinstance(msg, SystemMessage):
content = self._normalize_content(msg.content)
if content:
instructions_parts.append(content)
elif isinstance(msg, HumanMessage):
content = self._normalize_content(msg.content)
input_items.append({"role": "user", "content": content})
elif isinstance(msg, AIMessage):
if msg.content:
content = self._normalize_content(msg.content)
input_items.append({"role": "assistant", "content": content})
if msg.tool_calls:
for tc in msg.tool_calls:
input_items.append(
{
"type": "function_call",
"name": tc["name"],
"arguments": json.dumps(tc["args"]) if isinstance(tc["args"], dict) else tc["args"],
"call_id": tc["id"],
}
)
elif isinstance(msg, ToolMessage):
input_items.append(
{
"type": "function_call_output",
"call_id": msg.tool_call_id,
"output": self._normalize_content(msg.content),
}
)
instructions = "\n\n".join(instructions_parts) or "You are a helpful assistant."
return instructions, input_items
def _convert_tools(self, tools: list[dict]) -> list[dict]:
"""Convert LangChain tool format to Responses API format."""
responses_tools = []
for tool in tools:
if tool.get("type") == "function" and "function" in tool:
fn = tool["function"]
responses_tools.append(
{
"type": "function",
"name": fn["name"],
"description": fn.get("description", ""),
"parameters": fn.get("parameters", {}),
}
)
elif "name" in tool:
responses_tools.append(
{
"type": "function",
"name": tool["name"],
"description": tool.get("description", ""),
"parameters": tool.get("parameters", {}),
}
)
return responses_tools
def _call_codex_api(self, messages: list[BaseMessage], tools: list[dict] | None = None) -> dict:
"""Call the Codex Responses API and return the completed response."""
instructions, input_items = self._convert_messages(messages)
payload = {
"model": self.model,
"instructions": instructions,
"input": input_items,
"store": False,
"stream": True,
"reasoning": {"effort": self.reasoning_effort, "summary": "detailed"} if self.reasoning_effort != "none" else {"effort": "none"},
}
if tools:
payload["tools"] = self._convert_tools(tools)
headers = {
"Authorization": f"Bearer {self._access_token}",
"ChatGPT-Account-ID": self._account_id,
"Content-Type": "application/json",
"Accept": "text/event-stream",
"originator": "codex_cli_rs",
}
last_error = None
for attempt in range(1, self.retry_max_attempts + 1):
try:
return self._stream_response(headers, payload)
except httpx.HTTPStatusError as e:
last_error = e
if e.response.status_code in (429, 500, 529):
if attempt >= self.retry_max_attempts:
raise
wait_ms = 2000 * (1 << (attempt - 1))
logger.warning(f"Codex API error {e.response.status_code}, retrying {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
time.sleep(wait_ms / 1000)
else:
raise
except Exception:
raise
raise last_error
def _stream_response(self, headers: dict, payload: dict) -> dict:
"""Stream SSE from Codex API and collect the final response."""
completed_response = None
streamed_output_items: dict[int, dict[str, Any]] = {}
with httpx.Client(timeout=300) as client:
with client.stream("POST", f"{CODEX_BASE_URL}/responses", headers=headers, json=payload) as resp:
resp.raise_for_status()
for line in resp.iter_lines():
data = self._parse_sse_data_line(line)
if not data:
continue
event_type = data.get("type")
if event_type == "response.output_item.done":
output_index = data.get("output_index")
output_item = data.get("item")
if isinstance(output_index, int) and isinstance(output_item, dict):
streamed_output_items[output_index] = output_item
elif event_type == "response.completed":
completed_response = data["response"]
if not completed_response:
raise RuntimeError("Codex API stream ended without response.completed event")
# ChatGPT Codex can emit the final assistant content only in stream events.
# When response.completed arrives, response.output may still be empty.
if streamed_output_items:
merged_output = []
response_output = completed_response.get("output")
if isinstance(response_output, list):
merged_output = list(response_output)
max_index = max(max(streamed_output_items), len(merged_output) - 1)
if max_index >= 0 and len(merged_output) <= max_index:
merged_output.extend([None] * (max_index + 1 - len(merged_output)))
for output_index, output_item in streamed_output_items.items():
existing_item = merged_output[output_index]
if not isinstance(existing_item, dict):
merged_output[output_index] = output_item
completed_response = dict(completed_response)
completed_response["output"] = [item for item in merged_output if isinstance(item, dict)]
return completed_response
@staticmethod
def _parse_sse_data_line(line: str) -> dict[str, Any] | None:
"""Parse a data line from the SSE stream, skipping terminal markers."""
if not line.startswith("data:"):
return None
raw_data = line[5:].strip()
if not raw_data or raw_data == "[DONE]":
return None
try:
data = json.loads(raw_data)
except json.JSONDecodeError:
logger.debug(f"Skipping non-JSON Codex SSE frame: {raw_data}")
return None
return data if isinstance(data, dict) else None
def _parse_tool_call_arguments(self, output_item: dict[str, Any]) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
"""Parse function-call arguments, surfacing malformed payloads safely."""
raw_arguments = output_item.get("arguments", "{}")
if isinstance(raw_arguments, dict):
return raw_arguments, None
normalized_arguments = raw_arguments or "{}"
try:
parsed_arguments = json.loads(normalized_arguments)
except (TypeError, json.JSONDecodeError) as exc:
return None, {
"type": "invalid_tool_call",
"name": output_item.get("name"),
"args": str(raw_arguments),
"id": output_item.get("call_id"),
"error": f"Failed to parse tool arguments: {exc}",
}
if not isinstance(parsed_arguments, dict):
return None, {
"type": "invalid_tool_call",
"name": output_item.get("name"),
"args": str(raw_arguments),
"id": output_item.get("call_id"),
"error": "Tool arguments must decode to a JSON object.",
}
return parsed_arguments, None
def _parse_response(self, response: dict) -> ChatResult:
"""Parse Codex Responses API response into LangChain ChatResult."""
content = ""
tool_calls = []
invalid_tool_calls = []
reasoning_content = ""
for output_item in response.get("output", []):
if output_item.get("type") == "reasoning":
# Extract reasoning summary text
for summary_item in output_item.get("summary", []):
if isinstance(summary_item, dict) and summary_item.get("type") == "summary_text":
reasoning_content += summary_item.get("text", "")
elif isinstance(summary_item, str):
reasoning_content += summary_item
elif output_item.get("type") == "message":
for part in output_item.get("content", []):
if part.get("type") == "output_text":
content += part.get("text", "")
elif output_item.get("type") == "function_call":
parsed_arguments, invalid_tool_call = self._parse_tool_call_arguments(output_item)
if invalid_tool_call:
invalid_tool_calls.append(invalid_tool_call)
continue
tool_calls.append(
{
"name": output_item["name"],
"args": parsed_arguments or {},
"id": output_item.get("call_id", ""),
"type": "tool_call",
}
)
usage = response.get("usage", {})
additional_kwargs = {}
if reasoning_content:
additional_kwargs["reasoning_content"] = reasoning_content
message = AIMessage(
content=content,
tool_calls=tool_calls if tool_calls else [],
invalid_tool_calls=invalid_tool_calls,
additional_kwargs=additional_kwargs,
response_metadata={
"model": response.get("model", self.model),
"usage": usage,
},
)
return ChatResult(
generations=[ChatGeneration(message=message)],
llm_output={
"token_usage": {
"prompt_tokens": usage.get("input_tokens", 0),
"completion_tokens": usage.get("output_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
},
"model_name": response.get("model", self.model),
},
)
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""Generate a response using Codex Responses API."""
tools = kwargs.get("tools", None)
response = self._call_codex_api(messages, tools=tools)
return self._parse_response(response)
def bind_tools(self, tools: list, **kwargs: Any) -> Any:
"""Bind tools for function calling."""
from langchain_core.runnables import RunnableBinding
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_function
formatted_tools = []
for tool in tools:
if isinstance(tool, BaseTool):
try:
fn = convert_to_openai_function(tool)
formatted_tools.append(
{
"type": "function",
"name": fn["name"],
"description": fn.get("description", ""),
"parameters": fn.get("parameters", {}),
}
)
except Exception:
formatted_tools.append(
{
"type": "function",
"name": tool.name,
"description": tool.description,
"parameters": {"type": "object", "properties": {}},
}
)
elif isinstance(tool, dict):
if "function" in tool:
fn = tool["function"]
formatted_tools.append(
{
"type": "function",
"name": fn["name"],
"description": fn.get("description", ""),
"parameters": fn.get("parameters", {}),
}
)
else:
formatted_tools.append(tool)
return RunnableBinding(bound=self, kwargs={"tools": formatted_tools}, **kwargs)

View File

@@ -0,0 +1,73 @@
"""Patched ChatDeepSeek that preserves reasoning_content in multi-turn conversations.
This module provides a patched version of ChatDeepSeek that properly handles
reasoning_content when sending messages back to the API. The original implementation
stores reasoning_content in additional_kwargs but doesn't include it when making
subsequent API calls, which causes errors with APIs that require reasoning_content
on all assistant messages when thinking mode is enabled.
"""
from typing import Any
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_deepseek import ChatDeepSeek
class PatchedChatDeepSeek(ChatDeepSeek):
"""ChatDeepSeek with proper reasoning_content preservation.
When using thinking/reasoning enabled models, the API expects reasoning_content
to be present on ALL assistant messages in multi-turn conversations. This patched
version ensures reasoning_content from additional_kwargs is included in the
request payload.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@property
def lc_secrets(self) -> dict[str, str]:
return {"api_key": "DEEPSEEK_API_KEY", "openai_api_key": "DEEPSEEK_API_KEY"}
def _get_request_payload(
self,
input_: LanguageModelInput,
*,
stop: list[str] | None = None,
**kwargs: Any,
) -> dict:
"""Get request payload with reasoning_content preserved.
Overrides the parent method to inject reasoning_content from
additional_kwargs into assistant messages in the payload.
"""
# Get the original messages before conversion
original_messages = self._convert_input(input_).to_messages()
# Call parent to get the base payload
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
# Match payload messages with original messages to restore reasoning_content
payload_messages = payload.get("messages", [])
# The payload messages and original messages should be in the same order
# Iterate through both and match by position
if len(payload_messages) == len(original_messages):
for payload_msg, orig_msg in zip(payload_messages, original_messages):
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
reasoning_content = orig_msg.additional_kwargs.get("reasoning_content")
if reasoning_content is not None:
payload_msg["reasoning_content"] = reasoning_content
else:
# Fallback: match by counting assistant messages
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
for (idx, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
reasoning_content = ai_msg.additional_kwargs.get("reasoning_content")
if reasoning_content is not None:
payload_messages[idx]["reasoning_content"] = reasoning_content
return payload

View File

@@ -0,0 +1,220 @@
"""Patched ChatOpenAI adapter for MiniMax reasoning output.
MiniMax's OpenAI-compatible chat completions API can return structured
``reasoning_details`` when ``extra_body.reasoning_split=true`` is enabled.
``langchain_openai.ChatOpenAI`` currently ignores that field, so DeerFlow's
frontend never receives reasoning content in the shape it expects.
This adapter preserves ``reasoning_split`` in the request payload and maps the
provider-specific reasoning field into ``additional_kwargs.reasoning_content``,
which DeerFlow already understands.
"""
from __future__ import annotations
import re
from collections.abc import Mapping
from typing import Any
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import (
_convert_delta_to_message_chunk,
_create_usage_metadata,
)
_THINK_TAG_RE = re.compile(r"<think>\s*(.*?)\s*</think>", re.DOTALL)
def _extract_reasoning_text(
reasoning_details: Any,
*,
strip_parts: bool = True,
) -> str | None:
if not isinstance(reasoning_details, list):
return None
parts: list[str] = []
for item in reasoning_details:
if not isinstance(item, Mapping):
continue
text = item.get("text")
if isinstance(text, str):
normalized = text.strip() if strip_parts else text
if normalized.strip():
parts.append(normalized)
return "\n\n".join(parts) if parts else None
def _strip_inline_think_tags(content: str) -> tuple[str, str | None]:
reasoning_parts: list[str] = []
def _replace(match: re.Match[str]) -> str:
reasoning = match.group(1).strip()
if reasoning:
reasoning_parts.append(reasoning)
return ""
cleaned = _THINK_TAG_RE.sub(_replace, content).strip()
reasoning = "\n\n".join(reasoning_parts) if reasoning_parts else None
return cleaned, reasoning
def _merge_reasoning(*values: str | None) -> str | None:
merged: list[str] = []
for value in values:
if not value:
continue
normalized = value.strip()
if normalized and normalized not in merged:
merged.append(normalized)
return "\n\n".join(merged) if merged else None
def _with_reasoning_content(
message: AIMessage | AIMessageChunk,
reasoning: str | None,
*,
preserve_whitespace: bool = False,
):
if not reasoning:
return message
additional_kwargs = dict(message.additional_kwargs)
if preserve_whitespace:
existing = additional_kwargs.get("reasoning_content")
additional_kwargs["reasoning_content"] = f"{existing}{reasoning}" if isinstance(existing, str) else reasoning
else:
additional_kwargs["reasoning_content"] = _merge_reasoning(
additional_kwargs.get("reasoning_content"),
reasoning,
)
return message.model_copy(update={"additional_kwargs": additional_kwargs})
class PatchedChatMiniMax(ChatOpenAI):
"""ChatOpenAI adapter that preserves MiniMax reasoning output."""
def _get_request_payload(
self,
input_: LanguageModelInput,
*,
stop: list[str] | None = None,
**kwargs: Any,
) -> dict:
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
extra_body = payload.get("extra_body")
if isinstance(extra_body, dict):
payload["extra_body"] = {
**extra_body,
"reasoning_split": True,
}
else:
payload["extra_body"] = {"reasoning_split": True}
return payload
def _convert_chunk_to_generation_chunk(
self,
chunk: dict,
default_chunk_class: type,
base_generation_info: dict | None,
) -> ChatGenerationChunk | None:
if chunk.get("type") == "content.delta":
return None
token_usage = chunk.get("usage")
choices = chunk.get("choices", []) or chunk.get("chunk", {}).get("choices", [])
usage_metadata = _create_usage_metadata(token_usage, chunk.get("service_tier")) if token_usage else None
if len(choices) == 0:
generation_chunk = ChatGenerationChunk(
message=default_chunk_class(content="", usage_metadata=usage_metadata),
generation_info=base_generation_info,
)
if self.output_version == "v1":
generation_chunk.message.content = []
generation_chunk.message.response_metadata["output_version"] = "v1"
return generation_chunk
choice = choices[0]
delta = choice.get("delta")
if delta is None:
return None
message_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
generation_info = {**base_generation_info} if base_generation_info else {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
if model_name := chunk.get("model"):
generation_info["model_name"] = model_name
if system_fingerprint := chunk.get("system_fingerprint"):
generation_info["system_fingerprint"] = system_fingerprint
if service_tier := chunk.get("service_tier"):
generation_info["service_tier"] = service_tier
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
reasoning = _extract_reasoning_text(
delta.get("reasoning_details"),
strip_parts=False,
)
if isinstance(message_chunk, AIMessageChunk):
if usage_metadata:
message_chunk.usage_metadata = usage_metadata
if reasoning:
message_chunk = _with_reasoning_content(
message_chunk,
reasoning,
preserve_whitespace=True,
)
message_chunk.response_metadata["model_provider"] = "openai"
return ChatGenerationChunk(
message=message_chunk,
generation_info=generation_info or None,
)
def _create_chat_result(
self,
response: dict | Any,
generation_info: dict | None = None,
) -> ChatResult:
result = super()._create_chat_result(response, generation_info)
response_dict = response if isinstance(response, dict) else response.model_dump()
choices = response_dict.get("choices", [])
generations: list[ChatGeneration] = []
for index, generation in enumerate(result.generations):
choice = choices[index] if index < len(choices) else {}
message = generation.message
if isinstance(message, AIMessage):
content = message.content if isinstance(message.content, str) else None
cleaned_content = content
inline_reasoning = None
if isinstance(content, str):
cleaned_content, inline_reasoning = _strip_inline_think_tags(content)
choice_message = choice.get("message", {}) if isinstance(choice, Mapping) else {}
split_reasoning = _extract_reasoning_text(choice_message.get("reasoning_details"))
merged_reasoning = _merge_reasoning(split_reasoning, inline_reasoning)
updated_message = message
if cleaned_content is not None and cleaned_content != message.content:
updated_message = updated_message.model_copy(update={"content": cleaned_content})
if merged_reasoning:
updated_message = _with_reasoning_content(updated_message, merged_reasoning)
generation = ChatGeneration(
message=updated_message,
generation_info=generation.generation_info,
)
generations.append(generation)
return ChatResult(generations=generations, llm_output=result.llm_output)

View File

@@ -0,0 +1,132 @@
"""Patched ChatOpenAI that preserves thought_signature for Gemini thinking models.
When using Gemini with thinking enabled via an OpenAI-compatible gateway (e.g.
Vertex AI, Google AI Studio, or any proxy), the API requires that the
``thought_signature`` field on tool-call objects is echoed back verbatim in
every subsequent request.
The OpenAI-compatible gateway stores the raw tool-call dicts (including
``thought_signature``) in ``additional_kwargs["tool_calls"]``, but standard
``langchain_openai.ChatOpenAI`` only serialises the standard fields (``id``,
``type``, ``function``) into the outgoing payload, silently dropping the
signature. That causes an HTTP 400 ``INVALID_ARGUMENT`` error:
Unable to submit request because function call `<tool>` in the N. content
block is missing a `thought_signature`.
This module fixes the problem by overriding ``_get_request_payload`` to
re-inject tool-call signatures back into the outgoing payload for any assistant
message that originally carried them.
"""
from __future__ import annotations
from typing import Any
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
class PatchedChatOpenAI(ChatOpenAI):
"""ChatOpenAI with ``thought_signature`` preservation for Gemini thinking via OpenAI gateway.
When using Gemini with thinking enabled via an OpenAI-compatible gateway,
the API expects ``thought_signature`` to be present on tool-call objects in
multi-turn conversations. This patched version restores those signatures
from ``AIMessage.additional_kwargs["tool_calls"]`` into the serialised
request payload before it is sent to the API.
Usage in ``config.yaml``::
- name: gemini-2.5-pro-thinking
display_name: Gemini 2.5 Pro (Thinking)
use: deerflow.models.patched_openai:PatchedChatOpenAI
model: google/gemini-2.5-pro-preview
api_key: $GEMINI_API_KEY
base_url: https://<your-openai-compat-gateway>/v1
max_tokens: 16384
supports_thinking: true
supports_vision: true
when_thinking_enabled:
extra_body:
thinking:
type: enabled
"""
def _get_request_payload(
self,
input_: LanguageModelInput,
*,
stop: list[str] | None = None,
**kwargs: Any,
) -> dict:
"""Get request payload with ``thought_signature`` preserved on tool-call objects.
Overrides the parent method to re-inject ``thought_signature`` fields
on tool-call objects that were stored in
``additional_kwargs["tool_calls"]`` by LangChain but dropped during
serialisation.
"""
# Capture the original LangChain messages *before* conversion so we can
# access fields that the serialiser might drop.
original_messages = self._convert_input(input_).to_messages()
# Obtain the base payload from the parent implementation.
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
payload_messages = payload.get("messages", [])
if len(payload_messages) == len(original_messages):
for payload_msg, orig_msg in zip(payload_messages, original_messages):
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
_restore_tool_call_signatures(payload_msg, orig_msg)
else:
# Fallback: match assistant-role entries positionally against AIMessages.
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
for (_, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
_restore_tool_call_signatures(payload_msg, ai_msg)
return payload
def _restore_tool_call_signatures(payload_msg: dict, orig_msg: AIMessage) -> None:
"""Re-inject ``thought_signature`` onto tool-call objects in *payload_msg*.
When the Gemini OpenAI-compatible gateway returns a response with function
calls, each tool-call object may carry a ``thought_signature``. LangChain
stores the raw tool-call dicts in ``additional_kwargs["tool_calls"]`` but
only serialises the standard fields (``id``, ``type``, ``function``) into
the outgoing payload, silently dropping the signature.
This function matches raw tool-call entries (by ``id``, falling back to
positional order) and copies the signature back onto the serialised
payload entries.
"""
raw_tool_calls: list[dict] = orig_msg.additional_kwargs.get("tool_calls") or []
payload_tool_calls: list[dict] = payload_msg.get("tool_calls") or []
if not raw_tool_calls or not payload_tool_calls:
return
# Build an id → raw_tc lookup for efficient matching.
raw_by_id: dict[str, dict] = {}
for raw_tc in raw_tool_calls:
tc_id = raw_tc.get("id")
if tc_id:
raw_by_id[tc_id] = raw_tc
for idx, payload_tc in enumerate(payload_tool_calls):
# Try matching by id first, then fall back to positional.
raw_tc = raw_by_id.get(payload_tc.get("id", ""))
if raw_tc is None and idx < len(raw_tool_calls):
raw_tc = raw_tool_calls[idx]
if raw_tc is None:
continue
# The gateway may use either snake_case or camelCase.
sig = raw_tc.get("thought_signature") or raw_tc.get("thoughtSignature")
if sig:
payload_tc["thought_signature"] = sig

View File

@@ -0,0 +1,258 @@
"""Custom vLLM provider built on top of LangChain ChatOpenAI.
vLLM 0.19.0 exposes reasoning models through an OpenAI-compatible API, but
LangChain's default OpenAI adapter drops the non-standard ``reasoning`` field
from assistant messages and streaming deltas. That breaks interleaved
thinking/tool-call flows because vLLM expects the assistant's prior reasoning to
be echoed back on subsequent turns.
This provider preserves ``reasoning`` on:
- non-streaming responses
- streaming deltas
- multi-turn request payloads
"""
from __future__ import annotations
import json
from collections.abc import Mapping
from typing import Any, cast
import openai
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
ToolMessageChunk,
)
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import _create_usage_metadata
def _normalize_vllm_chat_template_kwargs(payload: dict[str, Any]) -> None:
"""Map DeerFlow's legacy ``thinking`` toggle to vLLM/Qwen's ``enable_thinking``.
DeerFlow originally documented ``extra_body.chat_template_kwargs.thinking``
for vLLM, but vLLM 0.19.0's Qwen reasoning parser reads
``chat_template_kwargs.enable_thinking``. Normalize the payload just before
it is sent so existing configs keep working and flash mode can truly
disable reasoning.
"""
extra_body = payload.get("extra_body")
if not isinstance(extra_body, dict):
return
chat_template_kwargs = extra_body.get("chat_template_kwargs")
if not isinstance(chat_template_kwargs, dict):
return
if "thinking" not in chat_template_kwargs:
return
normalized_chat_template_kwargs = dict(chat_template_kwargs)
normalized_chat_template_kwargs.setdefault("enable_thinking", normalized_chat_template_kwargs["thinking"])
normalized_chat_template_kwargs.pop("thinking", None)
extra_body["chat_template_kwargs"] = normalized_chat_template_kwargs
def _reasoning_to_text(reasoning: Any) -> str:
"""Best-effort extraction of readable reasoning text from vLLM payloads."""
if isinstance(reasoning, str):
return reasoning
if isinstance(reasoning, list):
parts = [_reasoning_to_text(item) for item in reasoning]
return "".join(part for part in parts if part)
if isinstance(reasoning, dict):
for key in ("text", "content", "reasoning"):
value = reasoning.get(key)
if isinstance(value, str):
return value
if value is not None:
text = _reasoning_to_text(value)
if text:
return text
try:
return json.dumps(reasoning, ensure_ascii=False)
except TypeError:
return str(reasoning)
try:
return json.dumps(reasoning, ensure_ascii=False)
except TypeError:
return str(reasoning)
def _convert_delta_to_message_chunk_with_reasoning(_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]) -> BaseMessageChunk:
"""Convert a streaming delta to a LangChain message chunk while preserving reasoning."""
id_ = _dict.get("id")
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: dict[str, Any] = {}
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
reasoning = _dict.get("reasoning")
if reasoning is not None:
additional_kwargs["reasoning"] = reasoning
reasoning_text = _reasoning_to_text(reasoning)
if reasoning_text:
additional_kwargs["reasoning_content"] = reasoning_text
tool_call_chunks = []
if raw_tool_calls := _dict.get("tool_calls"):
try:
tool_call_chunks = [
tool_call_chunk(
name=rtc["function"].get("name"),
args=rtc["function"].get("arguments"),
id=rtc.get("id"),
index=rtc["index"],
)
for rtc in raw_tool_calls
]
except KeyError:
pass
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content, id=id_)
if role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
id=id_,
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
)
if role in ("system", "developer") or default_class == SystemMessageChunk:
role_kwargs = {"__openai_role__": "developer"} if role == "developer" else {}
return SystemMessageChunk(content=content, id=id_, additional_kwargs=role_kwargs)
if role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
if role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"], id=id_)
if role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role, id=id_) # type: ignore[arg-type]
return default_class(content=content, id=id_) # type: ignore[call-arg]
def _restore_reasoning_field(payload_msg: dict[str, Any], orig_msg: AIMessage) -> None:
"""Re-inject vLLM reasoning onto outgoing assistant messages."""
reasoning = orig_msg.additional_kwargs.get("reasoning")
if reasoning is None:
reasoning = orig_msg.additional_kwargs.get("reasoning_content")
if reasoning is not None:
payload_msg["reasoning"] = reasoning
class VllmChatModel(ChatOpenAI):
"""ChatOpenAI variant that preserves vLLM reasoning fields across turns."""
model_config = {"arbitrary_types_allowed": True}
@property
def _llm_type(self) -> str:
return "vllm-openai-compatible"
def _get_request_payload(
self,
input_: LanguageModelInput,
*,
stop: list[str] | None = None,
**kwargs: Any,
) -> dict[str, Any]:
"""Restore assistant reasoning in request payloads for interleaved thinking."""
original_messages = self._convert_input(input_).to_messages()
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
_normalize_vllm_chat_template_kwargs(payload)
payload_messages = payload.get("messages", [])
if len(payload_messages) == len(original_messages):
for payload_msg, orig_msg in zip(payload_messages, original_messages):
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
_restore_reasoning_field(payload_msg, orig_msg)
else:
ai_messages = [message for message in original_messages if isinstance(message, AIMessage)]
assistant_payloads = [message for message in payload_messages if message.get("role") == "assistant"]
for payload_msg, ai_msg in zip(assistant_payloads, ai_messages):
_restore_reasoning_field(payload_msg, ai_msg)
return payload
def _create_chat_result(self, response: dict | openai.BaseModel, generation_info: dict | None = None) -> ChatResult:
"""Preserve vLLM reasoning on non-streaming responses."""
result = super()._create_chat_result(response, generation_info=generation_info)
response_dict = response if isinstance(response, dict) else response.model_dump()
for generation, choice in zip(result.generations, response_dict.get("choices", [])):
if not isinstance(generation, ChatGeneration):
continue
message = generation.message
if not isinstance(message, AIMessage):
continue
reasoning = choice.get("message", {}).get("reasoning")
if reasoning is None:
continue
message.additional_kwargs["reasoning"] = reasoning
reasoning_text = _reasoning_to_text(reasoning)
if reasoning_text:
message.additional_kwargs["reasoning_content"] = reasoning_text
return result
def _convert_chunk_to_generation_chunk(
self,
chunk: dict,
default_chunk_class: type,
base_generation_info: dict | None,
) -> ChatGenerationChunk | None:
"""Preserve vLLM reasoning on streaming deltas."""
if chunk.get("type") == "content.delta":
return None
token_usage = chunk.get("usage")
choices = chunk.get("choices", []) or chunk.get("chunk", {}).get("choices", [])
usage_metadata = _create_usage_metadata(token_usage, chunk.get("service_tier")) if token_usage else None
if len(choices) == 0:
generation_chunk = ChatGenerationChunk(message=default_chunk_class(content="", usage_metadata=usage_metadata), generation_info=base_generation_info)
if self.output_version == "v1":
generation_chunk.message.content = []
generation_chunk.message.response_metadata["output_version"] = "v1"
return generation_chunk
choice = choices[0]
if choice["delta"] is None:
return None
message_chunk = _convert_delta_to_message_chunk_with_reasoning(choice["delta"], default_chunk_class)
generation_info = {**base_generation_info} if base_generation_info else {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
if model_name := chunk.get("model"):
generation_info["model_name"] = model_name
if system_fingerprint := chunk.get("system_fingerprint"):
generation_info["system_fingerprint"] = system_fingerprint
if service_tier := chunk.get("service_tier"):
generation_info["service_tier"] = service_tier
if logprobs := choice.get("logprobs"):
generation_info["logprobs"] = logprobs
if usage_metadata and isinstance(message_chunk, AIMessageChunk):
message_chunk.usage_metadata = usage_metadata
message_chunk.response_metadata["model_provider"] = "openai"
return ChatGenerationChunk(message=message_chunk, generation_info=generation_info or None)

View File

@@ -0,0 +1,3 @@
from .resolvers import resolve_class, resolve_variable
__all__ = ["resolve_class", "resolve_variable"]

View File

@@ -0,0 +1,95 @@
from importlib import import_module
MODULE_TO_PACKAGE_HINTS = {
"langchain_google_genai": "langchain-google-genai",
"langchain_anthropic": "langchain-anthropic",
"langchain_openai": "langchain-openai",
"langchain_deepseek": "langchain-deepseek",
}
def _build_missing_dependency_hint(module_path: str, err: ImportError) -> str:
"""Build an actionable hint when module import fails."""
module_root = module_path.split(".", 1)[0]
missing_module = getattr(err, "name", None) or module_root
# Prefer provider package hints for known integrations, even when the import
# error is triggered by a transitive dependency (e.g. `google`).
package_name = MODULE_TO_PACKAGE_HINTS.get(module_root)
if package_name is None:
package_name = MODULE_TO_PACKAGE_HINTS.get(missing_module, missing_module.replace("_", "-"))
return f"Missing dependency '{missing_module}'. Install it with `uv add {package_name}` (or `pip install {package_name}`), then restart DeerFlow."
def resolve_variable[T](
variable_path: str,
expected_type: type[T] | tuple[type, ...] | None = None,
) -> T:
"""Resolve a variable from a path.
Args:
variable_path: The path to the variable (e.g. "parent_package_name.sub_package_name.module_name:variable_name").
expected_type: Optional type or tuple of types to validate the resolved variable against.
If provided, uses isinstance() to check if the variable is an instance of the expected type(s).
Returns:
The resolved variable.
Raises:
ImportError: If the module path is invalid or the attribute doesn't exist.
ValueError: If the resolved variable doesn't pass the validation checks.
"""
try:
module_path, variable_name = variable_path.rsplit(":", 1)
except ValueError as err:
raise ImportError(f"{variable_path} doesn't look like a variable path. Example: parent_package_name.sub_package_name.module_name:variable_name") from err
try:
module = import_module(module_path)
except ImportError as err:
module_root = module_path.split(".", 1)[0]
err_name = getattr(err, "name", None)
if isinstance(err, ModuleNotFoundError) or err_name == module_root:
hint = _build_missing_dependency_hint(module_path, err)
raise ImportError(f"Could not import module {module_path}. {hint}") from err
# Preserve the original ImportError message for non-missing-module failures.
raise ImportError(f"Error importing module {module_path}: {err}") from err
try:
variable = getattr(module, variable_name)
except AttributeError as err:
raise ImportError(f"Module {module_path} does not define a {variable_name} attribute/class") from err
# Type validation
if expected_type is not None:
if not isinstance(variable, expected_type):
type_name = expected_type.__name__ if isinstance(expected_type, type) else " or ".join(t.__name__ for t in expected_type)
raise ValueError(f"{variable_path} is not an instance of {type_name}, got {type(variable).__name__}")
return variable
def resolve_class[T](class_path: str, base_class: type[T] | None = None) -> type[T]:
"""Resolve a class from a module path and class name.
Args:
class_path: The path to the class (e.g. "langchain_openai:ChatOpenAI").
base_class: The base class to check if the resolved class is a subclass of.
Returns:
The resolved class.
Raises:
ImportError: If the module path is invalid or the attribute doesn't exist.
ValueError: If the resolved object is not a class or not a subclass of base_class.
"""
model_class = resolve_variable(class_path, expected_type=type)
if not isinstance(model_class, type):
raise ValueError(f"{class_path} is not a valid class")
if base_class is not None and not issubclass(model_class, base_class):
raise ValueError(f"{class_path} is not a subclass of {base_class.__name__}")
return model_class

View File

@@ -0,0 +1,39 @@
"""LangGraph-compatible runtime — runs, streaming, and lifecycle management.
Re-exports the public API of :mod:`~deerflow.runtime.runs` and
:mod:`~deerflow.runtime.stream_bridge` so that consumers can import
directly from ``deerflow.runtime``.
"""
from .runs import ConflictError, DisconnectMode, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
from .store import get_store, make_store, reset_store, store_context
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
__all__ = [
# runs
"ConflictError",
"DisconnectMode",
"RunManager",
"RunRecord",
"RunStatus",
"UnsupportedStrategyError",
"run_agent",
# serialization
"serialize",
"serialize_channel_values",
"serialize_lc_object",
"serialize_messages_tuple",
# store
"get_store",
"make_store",
"reset_store",
"store_context",
# stream_bridge
"END_SENTINEL",
"HEARTBEAT_SENTINEL",
"MemoryStreamBridge",
"StreamBridge",
"StreamEvent",
"make_stream_bridge",
]

View File

@@ -0,0 +1,15 @@
"""Run lifecycle management for LangGraph Platform API compatibility."""
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
from .schemas import DisconnectMode, RunStatus
from .worker import run_agent
__all__ = [
"ConflictError",
"DisconnectMode",
"RunManager",
"RunRecord",
"RunStatus",
"UnsupportedStrategyError",
"run_agent",
]

View File

@@ -0,0 +1,210 @@
"""In-memory run registry."""
from __future__ import annotations
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime
from .schemas import DisconnectMode, RunStatus
logger = logging.getLogger(__name__)
def _now_iso() -> str:
return datetime.now(UTC).isoformat()
@dataclass
class RunRecord:
"""Mutable record for a single run."""
run_id: str
thread_id: str
assistant_id: str | None
status: RunStatus
on_disconnect: DisconnectMode
multitask_strategy: str = "reject"
metadata: dict = field(default_factory=dict)
kwargs: dict = field(default_factory=dict)
created_at: str = ""
updated_at: str = ""
task: asyncio.Task | None = field(default=None, repr=False)
abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
abort_action: str = "interrupt"
error: str | None = None
class RunManager:
"""In-memory run registry. All mutations are protected by an asyncio lock."""
def __init__(self) -> None:
self._runs: dict[str, RunRecord] = {}
self._lock = asyncio.Lock()
async def create(
self,
thread_id: str,
assistant_id: str | None = None,
*,
on_disconnect: DisconnectMode = DisconnectMode.cancel,
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
) -> RunRecord:
"""Create a new pending run and register it."""
run_id = str(uuid.uuid4())
now = _now_iso()
record = RunRecord(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
status=RunStatus.pending,
on_disconnect=on_disconnect,
multitask_strategy=multitask_strategy,
metadata=metadata or {},
kwargs=kwargs or {},
created_at=now,
updated_at=now,
)
async with self._lock:
self._runs[run_id] = record
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
def get(self, run_id: str) -> RunRecord | None:
"""Return a run record by ID, or ``None``."""
return self._runs.get(run_id)
async def list_by_thread(self, thread_id: str) -> list[RunRecord]:
"""Return all runs for a given thread, newest first."""
async with self._lock:
# Dict insertion order matches creation order, so reversing it gives
# us deterministic newest-first results even when timestamps tie.
return [r for r in reversed(self._runs.values()) if r.thread_id == thread_id]
async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
"""Transition a run to a new status."""
async with self._lock:
record = self._runs.get(run_id)
if record is None:
logger.warning("set_status called for unknown run %s", run_id)
return
record.status = status
record.updated_at = _now_iso()
if error is not None:
record.error = error
logger.info("Run %s -> %s", run_id, status.value)
async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
"""Request cancellation of a run.
Args:
run_id: The run ID to cancel.
action: "interrupt" keeps checkpoint, "rollback" reverts to pre-run state.
Sets the abort event with the action reason and cancels the asyncio task.
Returns ``True`` if the run was in-flight and cancellation was initiated.
"""
async with self._lock:
record = self._runs.get(run_id)
if record is None:
return False
if record.status not in (RunStatus.pending, RunStatus.running):
return False
record.abort_action = action
record.abort_event.set()
if record.task is not None and not record.task.done():
record.task.cancel()
record.status = RunStatus.interrupted
record.updated_at = _now_iso()
logger.info("Run %s cancelled (action=%s)", run_id, action)
return True
async def create_or_reject(
self,
thread_id: str,
assistant_id: str | None = None,
*,
on_disconnect: DisconnectMode = DisconnectMode.cancel,
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
) -> RunRecord:
"""Atomically check for inflight runs and create a new one.
For ``reject`` strategy, raises ``ConflictError`` if thread
already has a pending/running run. For ``interrupt``/``rollback``,
cancels inflight runs before creating.
This method holds the lock across both the check and the insert,
eliminating the TOCTOU race in separate ``has_inflight`` + ``create``.
"""
run_id = str(uuid.uuid4())
now = _now_iso()
_supported_strategies = ("reject", "interrupt", "rollback")
async with self._lock:
if multitask_strategy not in _supported_strategies:
raise UnsupportedStrategyError(f"Multitask strategy '{multitask_strategy}' is not yet supported. Supported strategies: {', '.join(_supported_strategies)}")
inflight = [r for r in self._runs.values() if r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running)]
if multitask_strategy == "reject" and inflight:
raise ConflictError(f"Thread {thread_id} already has an active run")
if multitask_strategy in ("interrupt", "rollback") and inflight:
for r in inflight:
r.abort_action = multitask_strategy
r.abort_event.set()
if r.task is not None and not r.task.done():
r.task.cancel()
r.status = RunStatus.interrupted
r.updated_at = now
logger.info(
"Cancelled %d inflight run(s) on thread %s (strategy=%s)",
len(inflight),
thread_id,
multitask_strategy,
)
record = RunRecord(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
status=RunStatus.pending,
on_disconnect=on_disconnect,
multitask_strategy=multitask_strategy,
metadata=metadata or {},
kwargs=kwargs or {},
created_at=now,
updated_at=now,
)
self._runs[run_id] = record
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
async def has_inflight(self, thread_id: str) -> bool:
"""Return ``True`` if *thread_id* has a pending or running run."""
async with self._lock:
return any(r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running) for r in self._runs.values())
async def cleanup(self, run_id: str, *, delay: float = 300) -> None:
"""Remove a run record after an optional delay."""
if delay > 0:
await asyncio.sleep(delay)
async with self._lock:
self._runs.pop(run_id, None)
logger.debug("Run record %s cleaned up", run_id)
class ConflictError(Exception):
"""Raised when multitask_strategy=reject and thread has inflight runs."""
class UnsupportedStrategyError(Exception):
"""Raised when a multitask_strategy value is not yet implemented."""

View File

@@ -0,0 +1,21 @@
"""Run status and disconnect mode enums."""
from enum import StrEnum
class RunStatus(StrEnum):
"""Lifecycle status of a single run."""
pending = "pending"
running = "running"
success = "success"
error = "error"
timeout = "timeout"
interrupted = "interrupted"
class DisconnectMode(StrEnum):
"""Behaviour when the SSE consumer disconnects."""
cancel = "cancel"
continue_ = "continue"

View File

@@ -0,0 +1,381 @@
"""Background agent execution.
Runs an agent graph inside an ``asyncio.Task``, publishing events to
a :class:`StreamBridge` as they are produced.
Uses ``graph.astream(stream_mode=[...])`` which gives correct full-state
snapshots for ``values`` mode, proper ``{node: writes}`` for ``updates``,
and ``(chunk, metadata)`` tuples for ``messages`` mode.
Note: ``events`` mode is not supported through the gateway — it requires
``graph.astream_events()`` which cannot simultaneously produce ``values``
snapshots. The JS open-source LangGraph API server works around this via
internal checkpoint callbacks that are not exposed in the Python public API.
"""
from __future__ import annotations
import asyncio
import copy
import inspect
import logging
from typing import Any, Literal
from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge
from .manager import RunManager, RunRecord
from .schemas import RunStatus
logger = logging.getLogger(__name__)
# Valid stream_mode values for LangGraph's graph.astream()
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
async def run_agent(
bridge: StreamBridge,
run_manager: RunManager,
record: RunRecord,
*,
checkpointer: Any,
store: Any | None = None,
agent_factory: Any,
graph_input: dict,
config: dict,
stream_modes: list[str] | None = None,
stream_subgraphs: bool = False,
interrupt_before: list[str] | Literal["*"] | None = None,
interrupt_after: list[str] | Literal["*"] | None = None,
) -> None:
"""Execute an agent in the background, publishing events to *bridge*."""
run_id = record.run_id
thread_id = record.thread_id
requested_modes: set[str] = set(stream_modes or ["values"])
pre_run_checkpoint_id: str | None = None
pre_run_snapshot: dict[str, Any] | None = None
snapshot_capture_failed = False
# Track whether "events" was requested but skipped
if "events" in requested_modes:
logger.info(
"Run %s: 'events' stream_mode not supported in gateway (requires astream_events + checkpoint callbacks). Skipping.",
run_id,
)
try:
# 1. Mark running
await run_manager.set_status(run_id, RunStatus.running)
# Snapshot the latest pre-run checkpoint so rollback can restore it.
if checkpointer is not None:
try:
config_for_check = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
ckpt_tuple = await checkpointer.aget_tuple(config_for_check)
if ckpt_tuple is not None:
ckpt_config = getattr(ckpt_tuple, "config", {}).get("configurable", {})
pre_run_checkpoint_id = ckpt_config.get("checkpoint_id")
pre_run_snapshot = {
"checkpoint_ns": ckpt_config.get("checkpoint_ns", ""),
"checkpoint": copy.deepcopy(getattr(ckpt_tuple, "checkpoint", {})),
"metadata": copy.deepcopy(getattr(ckpt_tuple, "metadata", {})),
"pending_writes": copy.deepcopy(getattr(ckpt_tuple, "pending_writes", []) or []),
}
except Exception:
snapshot_capture_failed = True
logger.warning("Could not capture pre-run checkpoint snapshot for run %s", run_id, exc_info=True)
# 2. Publish metadata — useStream needs both run_id AND thread_id
await bridge.publish(
run_id,
"metadata",
{
"run_id": run_id,
"thread_id": thread_id,
},
)
# 3. Build the agent
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
# Inject runtime context so middlewares can access thread_id
# (langgraph-cli does this automatically; we must do it manually)
runtime = Runtime(context={"thread_id": thread_id}, store=store)
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
# prefers it over ``configurable`` for thread-level data), make
# sure ``thread_id`` is available there too.
if "context" in config and isinstance(config["context"], dict):
config["context"].setdefault("thread_id", thread_id)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
runnable_config = RunnableConfig(**config)
agent = agent_factory(config=runnable_config)
# 4. Attach checkpointer and store
if checkpointer is not None:
agent.checkpointer = checkpointer
if store is not None:
agent.store = store
# 5. Set interrupt nodes
if interrupt_before:
agent.interrupt_before_nodes = interrupt_before
if interrupt_after:
agent.interrupt_after_nodes = interrupt_after
# 6. Build LangGraph stream_mode list
# "events" is NOT a valid astream mode — skip it
# "messages-tuple" maps to LangGraph's "messages" mode
lg_modes: list[str] = []
for m in requested_modes:
if m == "messages-tuple":
lg_modes.append("messages")
elif m == "events":
# Skipped — see log above
continue
elif m in _VALID_LG_MODES:
lg_modes.append(m)
if not lg_modes:
lg_modes = ["values"]
# Deduplicate while preserving order
seen: set[str] = set()
deduped: list[str] = []
for m in lg_modes:
if m not in seen:
seen.add(m)
deduped.append(m)
lg_modes = deduped
logger.info("Run %s: streaming with modes %s (requested: %s)", run_id, lg_modes, requested_modes)
# 7. Stream using graph.astream
if len(lg_modes) == 1 and not stream_subgraphs:
# Single mode, no subgraphs: astream yields raw chunks
single_mode = lg_modes[0]
async for chunk in agent.astream(graph_input, config=runnable_config, stream_mode=single_mode):
if record.abort_event.is_set():
logger.info("Run %s abort requested — stopping", run_id)
break
sse_event = _lg_mode_to_sse_event(single_mode)
await bridge.publish(run_id, sse_event, serialize(chunk, mode=single_mode))
else:
# Multiple modes or subgraphs: astream yields tuples
async for item in agent.astream(
graph_input,
config=runnable_config,
stream_mode=lg_modes,
subgraphs=stream_subgraphs,
):
if record.abort_event.is_set():
logger.info("Run %s abort requested — stopping", run_id)
break
mode, chunk = _unpack_stream_item(item, lg_modes, stream_subgraphs)
if mode is None:
continue
sse_event = _lg_mode_to_sse_event(mode)
await bridge.publish(run_id, sse_event, serialize(chunk, mode=mode))
# 8. Final status
if record.abort_event.is_set():
action = record.abort_action
if action == "rollback":
await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user")
try:
await _rollback_to_pre_run_checkpoint(
checkpointer=checkpointer,
thread_id=thread_id,
run_id=run_id,
pre_run_checkpoint_id=pre_run_checkpoint_id,
pre_run_snapshot=pre_run_snapshot,
snapshot_capture_failed=snapshot_capture_failed,
)
logger.info("Run %s rolled back to pre-run checkpoint %s", run_id, pre_run_checkpoint_id)
except Exception:
logger.warning("Failed to rollback checkpoint for run %s", run_id, exc_info=True)
else:
await run_manager.set_status(run_id, RunStatus.interrupted)
else:
await run_manager.set_status(run_id, RunStatus.success)
except asyncio.CancelledError:
action = record.abort_action
if action == "rollback":
await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user")
try:
await _rollback_to_pre_run_checkpoint(
checkpointer=checkpointer,
thread_id=thread_id,
run_id=run_id,
pre_run_checkpoint_id=pre_run_checkpoint_id,
pre_run_snapshot=pre_run_snapshot,
snapshot_capture_failed=snapshot_capture_failed,
)
logger.info("Run %s was cancelled and rolled back", run_id)
except Exception:
logger.warning("Run %s cancellation rollback failed", run_id, exc_info=True)
else:
await run_manager.set_status(run_id, RunStatus.interrupted)
logger.info("Run %s was cancelled", run_id)
except Exception as exc:
error_msg = f"{exc}"
logger.exception("Run %s failed: %s", run_id, error_msg)
await run_manager.set_status(run_id, RunStatus.error, error=error_msg)
await bridge.publish(
run_id,
"error",
{
"message": error_msg,
"name": type(exc).__name__,
},
)
finally:
await bridge.publish_end(run_id)
asyncio.create_task(bridge.cleanup(run_id, delay=60))
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _call_checkpointer_method(checkpointer: Any, async_name: str, sync_name: str, *args: Any, **kwargs: Any) -> Any:
"""Call a checkpointer method, supporting async and sync variants."""
method = getattr(checkpointer, async_name, None) or getattr(checkpointer, sync_name, None)
if method is None:
raise AttributeError(f"Missing checkpointer method: {async_name}/{sync_name}")
result = method(*args, **kwargs)
if inspect.isawaitable(result):
return await result
return result
async def _rollback_to_pre_run_checkpoint(
*,
checkpointer: Any,
thread_id: str,
run_id: str,
pre_run_checkpoint_id: str | None,
pre_run_snapshot: dict[str, Any] | None,
snapshot_capture_failed: bool,
) -> None:
"""Restore thread state to the checkpoint snapshot captured before run start."""
if checkpointer is None:
logger.info("Run %s rollback requested but no checkpointer is configured", run_id)
return
if snapshot_capture_failed:
logger.warning("Run %s rollback skipped: pre-run checkpoint snapshot capture failed", run_id)
return
if pre_run_snapshot is None:
await _call_checkpointer_method(checkpointer, "adelete_thread", "delete_thread", thread_id)
logger.info("Run %s rollback reset thread %s to empty state", run_id, thread_id)
return
checkpoint_to_restore = None
metadata_to_restore: dict[str, Any] = {}
checkpoint_ns = ""
checkpoint = pre_run_snapshot.get("checkpoint")
if not isinstance(checkpoint, dict):
logger.warning("Run %s rollback skipped: invalid pre-run checkpoint snapshot", run_id)
return
checkpoint_to_restore = checkpoint
if checkpoint_to_restore.get("id") is None and pre_run_checkpoint_id is not None:
checkpoint_to_restore = {**checkpoint_to_restore, "id": pre_run_checkpoint_id}
if checkpoint_to_restore.get("id") is None:
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
return
metadata = pre_run_snapshot.get("metadata", {})
metadata_to_restore = metadata if isinstance(metadata, dict) else {}
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
checkpoint_ns = raw_checkpoint_ns if isinstance(raw_checkpoint_ns, str) else ""
channel_versions = checkpoint_to_restore.get("channel_versions")
new_versions = dict(channel_versions) if isinstance(channel_versions, dict) else {}
restore_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}}
restored_config = await _call_checkpointer_method(
checkpointer,
"aput",
"put",
restore_config,
checkpoint_to_restore,
metadata_to_restore if isinstance(metadata_to_restore, dict) else {},
new_versions,
)
if not isinstance(restored_config, dict):
raise RuntimeError(f"Run {run_id} rollback restore returned invalid config: expected dict")
restored_configurable = restored_config.get("configurable", {})
if not isinstance(restored_configurable, dict):
raise RuntimeError(f"Run {run_id} rollback restore returned invalid config payload")
restored_checkpoint_id = restored_configurable.get("checkpoint_id")
if not restored_checkpoint_id:
raise RuntimeError(f"Run {run_id} rollback restore did not return checkpoint_id")
pending_writes = pre_run_snapshot.get("pending_writes", [])
if not pending_writes:
return
writes_by_task: dict[str, list[tuple[str, Any]]] = {}
for item in pending_writes:
if not isinstance(item, (tuple, list)) or len(item) != 3:
raise RuntimeError(f"Run {run_id} rollback failed: pending_write is not a 3-tuple: {item!r}")
task_id, channel, value = item
if not isinstance(channel, str):
raise RuntimeError(f"Run {run_id} rollback failed: pending_write has non-string channel: task_id={task_id!r}, channel={channel!r}")
writes_by_task.setdefault(str(task_id), []).append((channel, value))
for task_id, writes in writes_by_task.items():
await _call_checkpointer_method(
checkpointer,
"aput_writes",
"put_writes",
restored_config,
writes,
task_id=task_id,
)
def _lg_mode_to_sse_event(mode: str) -> str:
"""Map LangGraph internal stream_mode name to SSE event name.
LangGraph's ``astream(stream_mode="messages")`` produces message
tuples. The SSE protocol calls this ``messages-tuple`` when the
client explicitly requests it, but the default SSE event name used
by LangGraph Platform is simply ``"messages"``.
"""
# All LG modes map 1:1 to SSE event names — "messages" stays "messages"
return mode
def _unpack_stream_item(
item: Any,
lg_modes: list[str],
stream_subgraphs: bool,
) -> tuple[str | None, Any]:
"""Unpack a multi-mode or subgraph stream item into (mode, chunk).
Returns ``(None, None)`` if the item cannot be parsed.
"""
if stream_subgraphs:
if isinstance(item, tuple) and len(item) == 3:
_ns, mode, chunk = item
return str(mode), chunk
if isinstance(item, tuple) and len(item) == 2:
mode, chunk = item
return str(mode), chunk
return None, None
if isinstance(item, tuple) and len(item) == 2:
mode, chunk = item
return str(mode), chunk
# Fallback: single-element output from first mode
return lg_modes[0] if lg_modes else None, item

Some files were not shown because too many files have changed in this diff Show More