Initial commit: hardened DeerFlow factory
Vendored deer-flow upstream (bytedance/deer-flow) plus prompt-injection hardening: - New deerflow.security package: content_delimiter, html_cleaner, sanitizer (8 layers — invisible chars, control chars, symbols, NFC, PUA, tag chars, horizontal whitespace collapse with newline/tab preservation, length cap) - New deerflow.community.searx package: web_search, web_fetch, image_search backed by a private SearX instance, every external string sanitized and wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>> delimiters - All native community web providers (ddg_search, tavily, exa, firecrawl, jina_ai, infoquest, image_search) replaced with hard-fail stubs that raise NativeWebToolDisabledError at import time, so a misconfigured tool.use path fails loud rather than silently falling back to unsanitized output - Native client back-doors (jina_client.py, infoquest_client.py) stubbed too - Native-tool tests quarantined under tests/_disabled_native/ (collect_ignore_glob via local conftest.py) - Sanitizer Layer 7 fix: only collapse horizontal whitespace, preserve newlines and tabs so list/table structure survives - Hardened runtime config.yaml references only the searx-backed tools - Factory overlay (backend/) kept in sync with deer-flow tree as a reference / source See HARDENING.md for the full audit trail and verification steps.
This commit is contained in:
@@ -0,0 +1,24 @@
|
||||
from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointer
|
||||
from .factory import create_deerflow_agent
|
||||
from .features import Next, Prev, RuntimeFeatures
|
||||
from .lead_agent import make_lead_agent
|
||||
from .lead_agent.prompt import prime_enabled_skills_cache
|
||||
from .thread_state import SandboxState, ThreadState
|
||||
|
||||
# LangGraph imports deerflow.agents when registering the graph. Prime the
|
||||
# enabled-skills cache here so the request path can usually read a warm cache
|
||||
# without forcing synchronous filesystem work during prompt module import.
|
||||
prime_enabled_skills_cache()
|
||||
|
||||
__all__ = [
|
||||
"create_deerflow_agent",
|
||||
"RuntimeFeatures",
|
||||
"Next",
|
||||
"Prev",
|
||||
"make_lead_agent",
|
||||
"SandboxState",
|
||||
"ThreadState",
|
||||
"get_checkpointer",
|
||||
"reset_checkpointer",
|
||||
"make_checkpointer",
|
||||
]
|
||||
@@ -0,0 +1,9 @@
|
||||
from .async_provider import make_checkpointer
|
||||
from .provider import checkpointer_context, get_checkpointer, reset_checkpointer
|
||||
|
||||
__all__ = [
|
||||
"get_checkpointer",
|
||||
"reset_checkpointer",
|
||||
"checkpointer_context",
|
||||
"make_checkpointer",
|
||||
]
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Async checkpointer factory.
|
||||
|
||||
Provides an **async context manager** for long-running async servers that need
|
||||
proper resource cleanup.
|
||||
|
||||
Supported backends: memory, sqlite, postgres.
|
||||
|
||||
Usage (e.g. FastAPI lifespan)::
|
||||
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
async with make_checkpointer() as checkpointer:
|
||||
app.state.checkpointer = checkpointer # InMemorySaver if not configured
|
||||
|
||||
For sync usage see :mod:`deerflow.agents.checkpointer.provider`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.agents.checkpointer.provider import (
|
||||
POSTGRES_CONN_REQUIRED,
|
||||
POSTGRES_INSTALL,
|
||||
SQLITE_INSTALL,
|
||||
)
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
|
||||
"""Async context manager that constructs and tears down a checkpointer."""
|
||||
if config.type == "memory":
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
if config.type == "sqlite":
|
||||
try:
|
||||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_INSTALL) from exc
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
await asyncio.to_thread(ensure_sqlite_parent_dir, conn_str)
|
||||
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
|
||||
await saver.setup()
|
||||
yield saver
|
||||
return
|
||||
|
||||
if config.type == "postgres":
|
||||
try:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
if not config.connection_string:
|
||||
raise ValueError(POSTGRES_CONN_REQUIRED)
|
||||
|
||||
async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver:
|
||||
await saver.setup()
|
||||
yield saver
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public async context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
|
||||
"""Async context manager that yields a checkpointer for the caller's lifetime.
|
||||
Resources are opened on enter and closed on exit — no global state::
|
||||
|
||||
async with make_checkpointer() as checkpointer:
|
||||
app.state.checkpointer = checkpointer
|
||||
|
||||
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
"""
|
||||
|
||||
config = get_app_config()
|
||||
|
||||
if config.checkpointer is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
async with _async_checkpointer(config.checkpointer) as saver:
|
||||
yield saver
|
||||
@@ -0,0 +1,191 @@
|
||||
"""Sync checkpointer factory.
|
||||
|
||||
Provides a **sync singleton** and a **sync context manager** for LangGraph
|
||||
graph compilation and CLI tools.
|
||||
|
||||
Supported backends: memory, sqlite, postgres.
|
||||
|
||||
Usage::
|
||||
|
||||
from deerflow.agents.checkpointer.provider import get_checkpointer, checkpointer_context
|
||||
|
||||
# Singleton — reused across calls, closed on process exit
|
||||
cp = get_checkpointer()
|
||||
|
||||
# One-shot — fresh connection, closed on block exit
|
||||
with checkpointer_context() as cp:
|
||||
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig
|
||||
from deerflow.runtime.store._sqlite_utils import resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error message constants — imported by aio.provider too
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SQLITE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite checkpointer. Install it with: uv add langgraph-checkpoint-sqlite"
|
||||
POSTGRES_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
|
||||
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
|
||||
"""Context manager that creates and tears down a sync checkpointer.
|
||||
|
||||
Returns a configured ``Checkpointer`` instance. Resource cleanup for any
|
||||
underlying connections or pools is handled by higher-level helpers in
|
||||
this module (such as the singleton factory or context manager); this
|
||||
function does not return a separate cleanup callback.
|
||||
"""
|
||||
if config.type == "memory":
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
if config.type == "sqlite":
|
||||
try:
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_INSTALL) from exc
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
with SqliteSaver.from_conn_string(conn_str) as saver:
|
||||
saver.setup()
|
||||
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
|
||||
yield saver
|
||||
return
|
||||
|
||||
if config.type == "postgres":
|
||||
try:
|
||||
from langgraph.checkpoint.postgres import PostgresSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
if not config.connection_string:
|
||||
raise ValueError(POSTGRES_CONN_REQUIRED)
|
||||
|
||||
with PostgresSaver.from_conn_string(config.connection_string) as saver:
|
||||
saver.setup()
|
||||
logger.info("Checkpointer: using PostgresSaver")
|
||||
yield saver
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_checkpointer: Checkpointer | None = None
|
||||
_checkpointer_ctx = None # open context manager keeping the connection alive
|
||||
|
||||
|
||||
def get_checkpointer() -> Checkpointer:
|
||||
"""Return the global sync checkpointer singleton, creating it on first call.
|
||||
|
||||
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
|
||||
Raises:
|
||||
ImportError: If the required package for the configured backend is not installed.
|
||||
ValueError: If ``connection_string`` is missing for a backend that requires it.
|
||||
"""
|
||||
global _checkpointer, _checkpointer_ctx
|
||||
|
||||
if _checkpointer is not None:
|
||||
return _checkpointer
|
||||
|
||||
# Ensure app config is loaded before checking checkpointer config
|
||||
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
|
||||
# but hasn't been loaded yet
|
||||
from deerflow.config.app_config import _app_config
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
|
||||
config = get_checkpointer_config()
|
||||
|
||||
if config is None and _app_config is None:
|
||||
# Only load app config lazily when neither the app config nor an explicit
|
||||
# checkpointer config has been initialized yet. This keeps tests that
|
||||
# intentionally set the global checkpointer config isolated from any
|
||||
# ambient config.yaml on disk.
|
||||
try:
|
||||
get_app_config()
|
||||
except FileNotFoundError:
|
||||
# In test environments without config.yaml, this is expected.
|
||||
pass
|
||||
config = get_checkpointer_config()
|
||||
if config is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
|
||||
_checkpointer = InMemorySaver()
|
||||
return _checkpointer
|
||||
|
||||
_checkpointer_ctx = _sync_checkpointer_cm(config)
|
||||
_checkpointer = _checkpointer_ctx.__enter__()
|
||||
|
||||
return _checkpointer
|
||||
|
||||
|
||||
def reset_checkpointer() -> None:
|
||||
"""Reset the sync singleton, forcing recreation on the next call.
|
||||
|
||||
Closes any open backend connections and clears the cached instance.
|
||||
Useful in tests or after a configuration change.
|
||||
"""
|
||||
global _checkpointer, _checkpointer_ctx
|
||||
if _checkpointer_ctx is not None:
|
||||
try:
|
||||
_checkpointer_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error during checkpointer cleanup", exc_info=True)
|
||||
_checkpointer_ctx = None
|
||||
_checkpointer = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def checkpointer_context() -> Iterator[Checkpointer]:
|
||||
"""Sync context manager that yields a checkpointer and cleans up on exit.
|
||||
|
||||
Unlike :func:`get_checkpointer`, this does **not** cache the instance —
|
||||
each ``with`` block creates and destroys its own connection. Use it in
|
||||
CLI scripts or tests where you want deterministic cleanup::
|
||||
|
||||
with checkpointer_context() as cp:
|
||||
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
|
||||
|
||||
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
"""
|
||||
|
||||
config = get_app_config()
|
||||
if config.checkpointer is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
with _sync_checkpointer_cm(config.checkpointer) as saver:
|
||||
yield saver
|
||||
372
deer-flow/backend/packages/harness/deerflow/agents/factory.py
Normal file
372
deer-flow/backend/packages/harness/deerflow/agents/factory.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""Pure-argument factory for DeerFlow agents.
|
||||
|
||||
``create_deerflow_agent`` accepts plain Python arguments — no YAML files, no
|
||||
global singletons. It is the SDK-level entry point sitting between the raw
|
||||
``langchain.agents.create_agent`` primitive and the config-driven
|
||||
``make_lead_agent`` application factory.
|
||||
|
||||
Note: the factory assembly itself is config-free, but some injected runtime
|
||||
components (e.g. ``task_tool`` for subagent) may still read global config at
|
||||
invocation time. Full config-free runtime is a Phase 2 goal.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
|
||||
from deerflow.agents.features import RuntimeFeatures
|
||||
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
||||
from deerflow.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
|
||||
from deerflow.agents.middlewares.tool_error_handling_middleware import ToolErrorHandlingMiddleware
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.tools.builtins import ask_clarification_tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TodoMiddleware prompts (minimal SDK version)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_TODO_SYSTEM_PROMPT = """
|
||||
<todo_list_system>
|
||||
You have access to the `write_todos` tool to help you manage and track complex multi-step objectives.
|
||||
|
||||
**CRITICAL RULES:**
|
||||
- Mark todos as completed IMMEDIATELY after finishing each step - do NOT batch completions
|
||||
- Keep EXACTLY ONE task as `in_progress` at any time (unless tasks can run in parallel)
|
||||
- Update the todo list in REAL-TIME as you work - this gives users visibility into your progress
|
||||
- DO NOT use this tool for simple tasks (< 3 steps) - just complete them directly
|
||||
</todo_list_system>
|
||||
"""
|
||||
|
||||
_TODO_TOOL_DESCRIPTION = "Use this tool to create and manage a structured task list for complex work sessions. Only use for complex tasks (3+ steps)."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_deerflow_agent(
|
||||
model: BaseChatModel,
|
||||
tools: list[BaseTool] | None = None,
|
||||
*,
|
||||
system_prompt: str | None = None,
|
||||
middleware: list[AgentMiddleware] | None = None,
|
||||
features: RuntimeFeatures | None = None,
|
||||
extra_middleware: list[AgentMiddleware] | None = None,
|
||||
plan_mode: bool = False,
|
||||
state_schema: type | None = None,
|
||||
checkpointer: BaseCheckpointSaver | None = None,
|
||||
name: str = "default",
|
||||
) -> CompiledStateGraph:
|
||||
"""Create a DeerFlow agent from plain Python arguments.
|
||||
|
||||
The factory assembly itself reads no config files. Some injected runtime
|
||||
components (e.g. ``task_tool``) may still depend on global config at
|
||||
invocation time — see Phase 2 roadmap for full config-free runtime.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model:
|
||||
Chat model instance.
|
||||
tools:
|
||||
User-provided tools. Feature-injected tools are appended automatically.
|
||||
system_prompt:
|
||||
System message. ``None`` uses a minimal default.
|
||||
middleware:
|
||||
**Full takeover** — if provided, this exact list is used.
|
||||
Cannot be combined with *features* or *extra_middleware*.
|
||||
features:
|
||||
Declarative feature flags. Cannot be combined with *middleware*.
|
||||
extra_middleware:
|
||||
Additional middlewares inserted into the auto-assembled chain via
|
||||
``@Next``/``@Prev`` positioning. Cannot be used with *middleware*.
|
||||
plan_mode:
|
||||
Enable TodoMiddleware for task tracking.
|
||||
state_schema:
|
||||
LangGraph state type. Defaults to ``ThreadState``.
|
||||
checkpointer:
|
||||
Optional persistence backend.
|
||||
name:
|
||||
Agent name (passed to middleware that cares, e.g. ``MemoryMiddleware``).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If both *middleware* and *features*/*extra_middleware* are provided.
|
||||
"""
|
||||
if middleware is not None and features is not None:
|
||||
raise ValueError("Cannot specify both 'middleware' and 'features'. Use one or the other.")
|
||||
if middleware is not None and extra_middleware:
|
||||
raise ValueError("Cannot use 'extra_middleware' with 'middleware' (full takeover).")
|
||||
if extra_middleware:
|
||||
for mw in extra_middleware:
|
||||
if not isinstance(mw, AgentMiddleware):
|
||||
raise TypeError(f"extra_middleware items must be AgentMiddleware instances, got {type(mw).__name__}")
|
||||
|
||||
effective_tools: list[BaseTool] = list(tools or [])
|
||||
effective_state = state_schema or ThreadState
|
||||
|
||||
if middleware is not None:
|
||||
effective_middleware = list(middleware)
|
||||
else:
|
||||
feat = features or RuntimeFeatures()
|
||||
effective_middleware, extra_tools = _assemble_from_features(
|
||||
feat,
|
||||
name=name,
|
||||
plan_mode=plan_mode,
|
||||
extra_middleware=extra_middleware or [],
|
||||
)
|
||||
# Deduplicate by tool name — user-provided tools take priority.
|
||||
existing_names = {t.name for t in effective_tools}
|
||||
for t in extra_tools:
|
||||
if t.name not in existing_names:
|
||||
effective_tools.append(t)
|
||||
existing_names.add(t.name)
|
||||
|
||||
return create_agent(
|
||||
model=model,
|
||||
tools=effective_tools or None,
|
||||
middleware=effective_middleware,
|
||||
system_prompt=system_prompt,
|
||||
state_schema=effective_state,
|
||||
checkpointer=checkpointer,
|
||||
name=name,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal: feature-driven middleware assembly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _assemble_from_features(
|
||||
feat: RuntimeFeatures,
|
||||
*,
|
||||
name: str = "default",
|
||||
plan_mode: bool = False,
|
||||
extra_middleware: list[AgentMiddleware] | None = None,
|
||||
) -> tuple[list[AgentMiddleware], list[BaseTool]]:
|
||||
"""Build an ordered middleware chain + extra tools from *feat*.
|
||||
|
||||
Middleware order matches ``make_lead_agent`` (14 middlewares):
|
||||
|
||||
0-2. Sandbox infrastructure (ThreadData → Uploads → Sandbox)
|
||||
3. DanglingToolCallMiddleware (always)
|
||||
4. GuardrailMiddleware (guardrail feature)
|
||||
5. ToolErrorHandlingMiddleware (always)
|
||||
6. SummarizationMiddleware (summarization feature)
|
||||
7. TodoMiddleware (plan_mode parameter)
|
||||
8. TitleMiddleware (auto_title feature)
|
||||
9. MemoryMiddleware (memory feature)
|
||||
10. ViewImageMiddleware (vision feature)
|
||||
11. SubagentLimitMiddleware (subagent feature)
|
||||
12. LoopDetectionMiddleware (always)
|
||||
13. ClarificationMiddleware (always last)
|
||||
|
||||
Two-phase ordering:
|
||||
1. Built-in chain — fixed sequential append.
|
||||
2. Extra middleware — inserted via @Next/@Prev.
|
||||
|
||||
Each feature value is handled as:
|
||||
- ``False``: skip
|
||||
- ``True``: create the built-in default middleware (not available for
|
||||
``summarization`` and ``guardrail`` — these require a custom instance)
|
||||
- ``AgentMiddleware`` instance: use directly (custom replacement)
|
||||
"""
|
||||
chain: list[AgentMiddleware] = []
|
||||
extra_tools: list[BaseTool] = []
|
||||
|
||||
# --- [0-2] Sandbox infrastructure ---
|
||||
if feat.sandbox is not False:
|
||||
if isinstance(feat.sandbox, AgentMiddleware):
|
||||
chain.append(feat.sandbox)
|
||||
else:
|
||||
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
|
||||
from deerflow.sandbox.middleware import SandboxMiddleware
|
||||
|
||||
chain.append(ThreadDataMiddleware(lazy_init=True))
|
||||
chain.append(UploadsMiddleware())
|
||||
chain.append(SandboxMiddleware(lazy_init=True))
|
||||
|
||||
# --- [3] DanglingToolCall (always) ---
|
||||
chain.append(DanglingToolCallMiddleware())
|
||||
|
||||
# --- [4] Guardrail ---
|
||||
if feat.guardrail is not False:
|
||||
if isinstance(feat.guardrail, AgentMiddleware):
|
||||
chain.append(feat.guardrail)
|
||||
else:
|
||||
raise ValueError("guardrail=True requires a custom AgentMiddleware instance (no built-in GuardrailMiddleware yet)")
|
||||
|
||||
# --- [5] ToolErrorHandling (always) ---
|
||||
chain.append(ToolErrorHandlingMiddleware())
|
||||
|
||||
# --- [6] Summarization ---
|
||||
if feat.summarization is not False:
|
||||
if isinstance(feat.summarization, AgentMiddleware):
|
||||
chain.append(feat.summarization)
|
||||
else:
|
||||
raise ValueError("summarization=True requires a custom AgentMiddleware instance (SummarizationMiddleware needs a model argument)")
|
||||
|
||||
# --- [7] TodoMiddleware (plan_mode) ---
|
||||
if plan_mode:
|
||||
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
|
||||
|
||||
chain.append(TodoMiddleware(system_prompt=_TODO_SYSTEM_PROMPT, tool_description=_TODO_TOOL_DESCRIPTION))
|
||||
|
||||
# --- [8] Auto Title ---
|
||||
if feat.auto_title is not False:
|
||||
if isinstance(feat.auto_title, AgentMiddleware):
|
||||
chain.append(feat.auto_title)
|
||||
else:
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
|
||||
chain.append(TitleMiddleware())
|
||||
|
||||
# --- [9] Memory ---
|
||||
if feat.memory is not False:
|
||||
if isinstance(feat.memory, AgentMiddleware):
|
||||
chain.append(feat.memory)
|
||||
else:
|
||||
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||
|
||||
chain.append(MemoryMiddleware(agent_name=name))
|
||||
|
||||
# --- [10] Vision ---
|
||||
if feat.vision is not False:
|
||||
if isinstance(feat.vision, AgentMiddleware):
|
||||
chain.append(feat.vision)
|
||||
else:
|
||||
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||
|
||||
chain.append(ViewImageMiddleware())
|
||||
from deerflow.tools.builtins import view_image_tool
|
||||
|
||||
extra_tools.append(view_image_tool)
|
||||
|
||||
# --- [11] Subagent ---
|
||||
if feat.subagent is not False:
|
||||
if isinstance(feat.subagent, AgentMiddleware):
|
||||
chain.append(feat.subagent)
|
||||
else:
|
||||
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||
|
||||
chain.append(SubagentLimitMiddleware())
|
||||
from deerflow.tools.builtins import task_tool
|
||||
|
||||
extra_tools.append(task_tool)
|
||||
|
||||
# --- [12] LoopDetection (always) ---
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
|
||||
chain.append(LoopDetectionMiddleware())
|
||||
|
||||
# --- [13] Clarification (always last among built-ins) ---
|
||||
chain.append(ClarificationMiddleware())
|
||||
extra_tools.append(ask_clarification_tool)
|
||||
|
||||
# --- Insert extra_middleware via @Next/@Prev ---
|
||||
if extra_middleware:
|
||||
_insert_extra(chain, extra_middleware)
|
||||
# Invariant: ClarificationMiddleware must always be last.
|
||||
# @Next(ClarificationMiddleware) could push it off the tail.
|
||||
clar_idx = next(i for i, m in enumerate(chain) if isinstance(m, ClarificationMiddleware))
|
||||
if clar_idx != len(chain) - 1:
|
||||
chain.append(chain.pop(clar_idx))
|
||||
|
||||
return chain, extra_tools
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal: extra middleware insertion with @Next/@Prev
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _insert_extra(chain: list[AgentMiddleware], extras: list[AgentMiddleware]) -> None:
|
||||
"""Insert extra middlewares into *chain* using ``@Next``/``@Prev`` anchors.
|
||||
|
||||
Algorithm:
|
||||
1. Validate: no middleware has both @Next and @Prev.
|
||||
2. Conflict detection: two extras targeting same anchor (same or opposite direction) → error.
|
||||
3. Insert unanchored extras before ClarificationMiddleware.
|
||||
4. Insert anchored extras iteratively (supports cross-external anchoring).
|
||||
5. If an anchor cannot be resolved after all rounds → error.
|
||||
"""
|
||||
next_targets: dict[type, type] = {}
|
||||
prev_targets: dict[type, type] = {}
|
||||
|
||||
anchored: list[tuple[AgentMiddleware, str, type]] = []
|
||||
unanchored: list[AgentMiddleware] = []
|
||||
|
||||
for mw in extras:
|
||||
next_anchor = getattr(type(mw), "_next_anchor", None)
|
||||
prev_anchor = getattr(type(mw), "_prev_anchor", None)
|
||||
|
||||
if next_anchor and prev_anchor:
|
||||
raise ValueError(f"{type(mw).__name__} cannot have both @Next and @Prev")
|
||||
|
||||
if next_anchor:
|
||||
if next_anchor in next_targets:
|
||||
raise ValueError(f"Conflict: {type(mw).__name__} and {next_targets[next_anchor].__name__} both @Next({next_anchor.__name__})")
|
||||
if next_anchor in prev_targets:
|
||||
raise ValueError(f"Conflict: {type(mw).__name__} @Next({next_anchor.__name__}) and {prev_targets[next_anchor].__name__} @Prev({next_anchor.__name__}) — use cross-anchoring between extras instead")
|
||||
next_targets[next_anchor] = type(mw)
|
||||
anchored.append((mw, "next", next_anchor))
|
||||
elif prev_anchor:
|
||||
if prev_anchor in prev_targets:
|
||||
raise ValueError(f"Conflict: {type(mw).__name__} and {prev_targets[prev_anchor].__name__} both @Prev({prev_anchor.__name__})")
|
||||
if prev_anchor in next_targets:
|
||||
raise ValueError(f"Conflict: {type(mw).__name__} @Prev({prev_anchor.__name__}) and {next_targets[prev_anchor].__name__} @Next({prev_anchor.__name__}) — use cross-anchoring between extras instead")
|
||||
prev_targets[prev_anchor] = type(mw)
|
||||
anchored.append((mw, "prev", prev_anchor))
|
||||
else:
|
||||
unanchored.append(mw)
|
||||
|
||||
# Unanchored → before ClarificationMiddleware
|
||||
clarification_idx = next(i for i, m in enumerate(chain) if isinstance(m, ClarificationMiddleware))
|
||||
for mw in unanchored:
|
||||
chain.insert(clarification_idx, mw)
|
||||
clarification_idx += 1
|
||||
|
||||
# Anchored → iterative insertion (supports external-to-external anchoring)
|
||||
pending = list(anchored)
|
||||
max_rounds = len(pending) + 1
|
||||
for _ in range(max_rounds):
|
||||
if not pending:
|
||||
break
|
||||
remaining = []
|
||||
for mw, direction, anchor in pending:
|
||||
idx = next(
|
||||
(i for i, m in enumerate(chain) if isinstance(m, anchor)),
|
||||
None,
|
||||
)
|
||||
if idx is None:
|
||||
remaining.append((mw, direction, anchor))
|
||||
continue
|
||||
if direction == "next":
|
||||
chain.insert(idx + 1, mw)
|
||||
else:
|
||||
chain.insert(idx, mw)
|
||||
if len(remaining) == len(pending):
|
||||
names = [type(m).__name__ for m, _, _ in remaining]
|
||||
anchor_types = {a for _, _, a in remaining}
|
||||
remaining_types = {type(m) for m, _, _ in remaining}
|
||||
circular = anchor_types & remaining_types
|
||||
if circular:
|
||||
raise ValueError(f"Circular dependency among extra middlewares: {', '.join(t.__name__ for t in circular)}")
|
||||
raise ValueError(f"Cannot resolve positions for {', '.join(names)} — anchors {', '.join(a.__name__ for _, _, a in remaining)} not found in chain")
|
||||
pending = remaining
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Declarative feature flags and middleware positioning for create_deerflow_agent.
|
||||
|
||||
Pure data classes and decorators — no I/O, no side effects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeFeatures:
|
||||
"""Declarative feature flags for ``create_deerflow_agent``.
|
||||
|
||||
Most features accept:
|
||||
- ``True``: use the built-in default middleware
|
||||
- ``False``: disable
|
||||
- An ``AgentMiddleware`` instance: use this custom implementation instead
|
||||
|
||||
``summarization`` and ``guardrail`` have no built-in default — they only
|
||||
accept ``False`` (disable) or an ``AgentMiddleware`` instance (custom).
|
||||
"""
|
||||
|
||||
sandbox: bool | AgentMiddleware = True
|
||||
memory: bool | AgentMiddleware = False
|
||||
summarization: Literal[False] | AgentMiddleware = False
|
||||
subagent: bool | AgentMiddleware = False
|
||||
vision: bool | AgentMiddleware = False
|
||||
auto_title: bool | AgentMiddleware = False
|
||||
guardrail: Literal[False] | AgentMiddleware = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Middleware positioning decorators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def Next(anchor: type[AgentMiddleware]):
|
||||
"""Declare this middleware should be placed after *anchor* in the chain."""
|
||||
if not (isinstance(anchor, type) and issubclass(anchor, AgentMiddleware)):
|
||||
raise TypeError(f"@Next expects an AgentMiddleware subclass, got {anchor!r}")
|
||||
|
||||
def decorator(cls: type[AgentMiddleware]) -> type[AgentMiddleware]:
|
||||
cls._next_anchor = anchor # type: ignore[attr-defined]
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def Prev(anchor: type[AgentMiddleware]):
|
||||
"""Declare this middleware should be placed before *anchor* in the chain."""
|
||||
if not (isinstance(anchor, type) and issubclass(anchor, AgentMiddleware)):
|
||||
raise TypeError(f"@Prev expects an AgentMiddleware subclass, got {anchor!r}")
|
||||
|
||||
def decorator(cls: type[AgentMiddleware]) -> type[AgentMiddleware]:
|
||||
cls._prev_anchor = anchor # type: ignore[attr-defined]
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
@@ -0,0 +1,3 @@
|
||||
from .agent import make_lead_agent
|
||||
|
||||
__all__ = ["make_lead_agent"]
|
||||
@@ -0,0 +1,350 @@
|
||||
import logging
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware, SummarizationMiddleware
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
||||
from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
||||
from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||
from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from deerflow.agents.middlewares.todo_middleware import TodoMiddleware
|
||||
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
|
||||
from deerflow.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
|
||||
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.agents_config import load_agent_config
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.summarization_config import get_summarization_config
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_model_name(requested_model_name: str | None = None) -> str:
|
||||
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
|
||||
app_config = get_app_config()
|
||||
default_model_name = app_config.models[0].name if app_config.models else None
|
||||
if default_model_name is None:
|
||||
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
|
||||
|
||||
if requested_model_name and app_config.get_model_config(requested_model_name):
|
||||
return requested_model_name
|
||||
|
||||
if requested_model_name and requested_model_name != default_model_name:
|
||||
logger.warning(f"Model '{requested_model_name}' not found in config; fallback to default model '{default_model_name}'.")
|
||||
return default_model_name
|
||||
|
||||
|
||||
def _create_summarization_middleware() -> SummarizationMiddleware | None:
|
||||
"""Create and configure the summarization middleware from config."""
|
||||
config = get_summarization_config()
|
||||
|
||||
if not config.enabled:
|
||||
return None
|
||||
|
||||
# Prepare trigger parameter
|
||||
trigger = None
|
||||
if config.trigger is not None:
|
||||
if isinstance(config.trigger, list):
|
||||
trigger = [t.to_tuple() for t in config.trigger]
|
||||
else:
|
||||
trigger = config.trigger.to_tuple()
|
||||
|
||||
# Prepare keep parameter
|
||||
keep = config.keep.to_tuple()
|
||||
|
||||
# Prepare model parameter
|
||||
if config.model_name:
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
else:
|
||||
# Use a lightweight model for summarization to save costs
|
||||
# Falls back to default model if not explicitly specified
|
||||
model = create_chat_model(thinking_enabled=False)
|
||||
|
||||
# Prepare kwargs
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"trigger": trigger,
|
||||
"keep": keep,
|
||||
}
|
||||
|
||||
if config.trim_tokens_to_summarize is not None:
|
||||
kwargs["trim_tokens_to_summarize"] = config.trim_tokens_to_summarize
|
||||
|
||||
if config.summary_prompt is not None:
|
||||
kwargs["summary_prompt"] = config.summary_prompt
|
||||
|
||||
return SummarizationMiddleware(**kwargs)
|
||||
|
||||
|
||||
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
|
||||
"""Create and configure the TodoList middleware.
|
||||
|
||||
Args:
|
||||
is_plan_mode: Whether to enable plan mode with TodoList middleware.
|
||||
|
||||
Returns:
|
||||
TodoMiddleware instance if plan mode is enabled, None otherwise.
|
||||
"""
|
||||
if not is_plan_mode:
|
||||
return None
|
||||
|
||||
# Custom prompts matching DeerFlow's style
|
||||
system_prompt = """
|
||||
<todo_list_system>
|
||||
You have access to the `write_todos` tool to help you manage and track complex multi-step objectives.
|
||||
|
||||
**CRITICAL RULES:**
|
||||
- Mark todos as completed IMMEDIATELY after finishing each step - do NOT batch completions
|
||||
- Keep EXACTLY ONE task as `in_progress` at any time (unless tasks can run in parallel)
|
||||
- Update the todo list in REAL-TIME as you work - this gives users visibility into your progress
|
||||
- DO NOT use this tool for simple tasks (< 3 steps) - just complete them directly
|
||||
|
||||
**When to Use:**
|
||||
This tool is designed for complex objectives that require systematic tracking:
|
||||
- Complex multi-step tasks requiring 3+ distinct steps
|
||||
- Non-trivial tasks needing careful planning and execution
|
||||
- User explicitly requests a todo list
|
||||
- User provides multiple tasks (numbered or comma-separated list)
|
||||
- The plan may need revisions based on intermediate results
|
||||
|
||||
**When NOT to Use:**
|
||||
- Single, straightforward tasks
|
||||
- Trivial tasks (< 3 steps)
|
||||
- Purely conversational or informational requests
|
||||
- Simple tool calls where the approach is obvious
|
||||
|
||||
**Best Practices:**
|
||||
- Break down complex tasks into smaller, actionable steps
|
||||
- Use clear, descriptive task names
|
||||
- Remove tasks that become irrelevant
|
||||
- Add new tasks discovered during implementation
|
||||
- Don't be afraid to revise the todo list as you learn more
|
||||
|
||||
**Task Management:**
|
||||
Writing todos takes time and tokens - use it when helpful for managing complex problems, not for simple requests.
|
||||
</todo_list_system>
|
||||
"""
|
||||
|
||||
tool_description = """Use this tool to create and manage a structured task list for complex work sessions.
|
||||
|
||||
**IMPORTANT: Only use this tool for complex tasks (3+ steps). For simple requests, just do the work directly.**
|
||||
|
||||
## When to Use
|
||||
|
||||
Use this tool in these scenarios:
|
||||
1. **Complex multi-step tasks**: When a task requires 3 or more distinct steps or actions
|
||||
2. **Non-trivial tasks**: Tasks requiring careful planning or multiple operations
|
||||
3. **User explicitly requests todo list**: When the user directly asks you to track tasks
|
||||
4. **Multiple tasks**: When users provide a list of things to be done
|
||||
5. **Dynamic planning**: When the plan may need updates based on intermediate results
|
||||
|
||||
## When NOT to Use
|
||||
|
||||
Skip this tool when:
|
||||
1. The task is straightforward and takes less than 3 steps
|
||||
2. The task is trivial and tracking provides no benefit
|
||||
3. The task is purely conversational or informational
|
||||
4. It's clear what needs to be done and you can just do it
|
||||
|
||||
## How to Use
|
||||
|
||||
1. **Starting a task**: Mark it as `in_progress` BEFORE beginning work
|
||||
2. **Completing a task**: Mark it as `completed` IMMEDIATELY after finishing
|
||||
3. **Updating the list**: Add new tasks, remove irrelevant ones, or update descriptions as needed
|
||||
4. **Multiple updates**: You can make several updates at once (e.g., complete one task and start the next)
|
||||
|
||||
## Task States
|
||||
|
||||
- `pending`: Task not yet started
|
||||
- `in_progress`: Currently working on (can have multiple if tasks run in parallel)
|
||||
- `completed`: Task finished successfully
|
||||
|
||||
## Task Completion Requirements
|
||||
|
||||
**CRITICAL: Only mark a task as completed when you have FULLY accomplished it.**
|
||||
|
||||
Never mark a task as completed if:
|
||||
- There are unresolved issues or errors
|
||||
- Work is partial or incomplete
|
||||
- You encountered blockers preventing completion
|
||||
- You couldn't find necessary resources or dependencies
|
||||
- Quality standards haven't been met
|
||||
|
||||
If blocked, keep the task as `in_progress` and create a new task describing what needs to be resolved.
|
||||
|
||||
## Best Practices
|
||||
|
||||
- Create specific, actionable items
|
||||
- Break complex tasks into smaller, manageable steps
|
||||
- Use clear, descriptive task names
|
||||
- Update task status in real-time as you work
|
||||
- Mark tasks complete IMMEDIATELY after finishing (don't batch completions)
|
||||
- Remove tasks that are no longer relevant
|
||||
- **IMPORTANT**: When you write the todo list, mark your first task(s) as `in_progress` immediately
|
||||
- **IMPORTANT**: Unless all tasks are completed, always have at least one task `in_progress` to show progress
|
||||
|
||||
Being proactive with task management demonstrates thoroughness and ensures all requirements are completed successfully.
|
||||
|
||||
**Remember**: If you only need a few tool calls to complete a task and it's clear what to do, it's better to just do the task directly and NOT use this tool at all.
|
||||
"""
|
||||
|
||||
return TodoMiddleware(system_prompt=system_prompt, tool_description=tool_description)
|
||||
|
||||
|
||||
# ThreadDataMiddleware must be before SandboxMiddleware to ensure thread_id is available
|
||||
# UploadsMiddleware should be after ThreadDataMiddleware to access thread_id
|
||||
# DanglingToolCallMiddleware patches missing ToolMessages before model sees the history
|
||||
# SummarizationMiddleware should be early to reduce context before other processing
|
||||
# TodoListMiddleware should be before ClarificationMiddleware to allow todo management
|
||||
# TitleMiddleware generates title after first exchange
|
||||
# MemoryMiddleware queues conversation for memory update (after TitleMiddleware)
|
||||
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
||||
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
|
||||
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
||||
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None):
|
||||
"""Build middleware chain based on runtime configuration.
|
||||
|
||||
Args:
|
||||
config: Runtime configuration containing configurable options like is_plan_mode.
|
||||
agent_name: If provided, MemoryMiddleware will use per-agent memory storage.
|
||||
custom_middlewares: Optional list of custom middlewares to inject into the chain.
|
||||
|
||||
Returns:
|
||||
List of middleware instances.
|
||||
"""
|
||||
middlewares = build_lead_runtime_middlewares(lazy_init=True)
|
||||
|
||||
# Add summarization middleware if enabled
|
||||
summarization_middleware = _create_summarization_middleware()
|
||||
if summarization_middleware is not None:
|
||||
middlewares.append(summarization_middleware)
|
||||
|
||||
# Add TodoList middleware if plan mode is enabled
|
||||
is_plan_mode = config.get("configurable", {}).get("is_plan_mode", False)
|
||||
todo_list_middleware = _create_todo_list_middleware(is_plan_mode)
|
||||
if todo_list_middleware is not None:
|
||||
middlewares.append(todo_list_middleware)
|
||||
|
||||
# Add TokenUsageMiddleware when token_usage tracking is enabled
|
||||
if get_app_config().token_usage.enabled:
|
||||
middlewares.append(TokenUsageMiddleware())
|
||||
|
||||
# Add TitleMiddleware
|
||||
middlewares.append(TitleMiddleware())
|
||||
|
||||
# Add MemoryMiddleware (after TitleMiddleware)
|
||||
middlewares.append(MemoryMiddleware(agent_name=agent_name))
|
||||
|
||||
# Add ViewImageMiddleware only if the current model supports vision.
|
||||
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
|
||||
app_config = get_app_config()
|
||||
model_config = app_config.get_model_config(model_name) if model_name else None
|
||||
if model_config is not None and model_config.supports_vision:
|
||||
middlewares.append(ViewImageMiddleware())
|
||||
|
||||
# Add DeferredToolFilterMiddleware to hide deferred tool schemas from model binding
|
||||
if app_config.tool_search.enabled:
|
||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
||||
|
||||
middlewares.append(DeferredToolFilterMiddleware())
|
||||
|
||||
# Add SubagentLimitMiddleware to truncate excess parallel task calls
|
||||
subagent_enabled = config.get("configurable", {}).get("subagent_enabled", False)
|
||||
if subagent_enabled:
|
||||
max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3)
|
||||
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
|
||||
|
||||
# LoopDetectionMiddleware — detect and break repetitive tool call loops
|
||||
middlewares.append(LoopDetectionMiddleware())
|
||||
|
||||
# Inject custom middlewares before ClarificationMiddleware
|
||||
if custom_middlewares:
|
||||
middlewares.extend(custom_middlewares)
|
||||
|
||||
# ClarificationMiddleware should always be last
|
||||
middlewares.append(ClarificationMiddleware())
|
||||
return middlewares
|
||||
|
||||
|
||||
def make_lead_agent(config: RunnableConfig):
|
||||
# Lazy import to avoid circular dependency
|
||||
from deerflow.tools import get_available_tools
|
||||
from deerflow.tools.builtins import setup_agent
|
||||
|
||||
cfg = config.get("configurable", {})
|
||||
|
||||
thinking_enabled = cfg.get("thinking_enabled", True)
|
||||
reasoning_effort = cfg.get("reasoning_effort", None)
|
||||
requested_model_name: str | None = cfg.get("model_name") or cfg.get("model")
|
||||
is_plan_mode = cfg.get("is_plan_mode", False)
|
||||
subagent_enabled = cfg.get("subagent_enabled", False)
|
||||
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
|
||||
is_bootstrap = cfg.get("is_bootstrap", False)
|
||||
agent_name = cfg.get("agent_name")
|
||||
|
||||
agent_config = load_agent_config(agent_name) if not is_bootstrap else None
|
||||
# Custom agent model from agent config (if any), or None to let _resolve_model_name pick the default
|
||||
agent_model_name = agent_config.model if agent_config and agent_config.model else None
|
||||
|
||||
# Final model name resolution: request → agent config → global default, with fallback for unknown names
|
||||
model_name = _resolve_model_name(requested_model_name or agent_model_name)
|
||||
|
||||
app_config = get_app_config()
|
||||
model_config = app_config.get_model_config(model_name)
|
||||
|
||||
if model_config is None:
|
||||
raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.")
|
||||
if thinking_enabled and not model_config.supports_thinking:
|
||||
logger.warning(f"Thinking mode is enabled but model '{model_name}' does not support it; fallback to non-thinking mode.")
|
||||
thinking_enabled = False
|
||||
|
||||
logger.info(
|
||||
"Create Agent(%s) -> thinking_enabled: %s, reasoning_effort: %s, model_name: %s, is_plan_mode: %s, subagent_enabled: %s, max_concurrent_subagents: %s",
|
||||
agent_name or "default",
|
||||
thinking_enabled,
|
||||
reasoning_effort,
|
||||
model_name,
|
||||
is_plan_mode,
|
||||
subagent_enabled,
|
||||
max_concurrent_subagents,
|
||||
)
|
||||
|
||||
# Inject run metadata for LangSmith trace tagging
|
||||
if "metadata" not in config:
|
||||
config["metadata"] = {}
|
||||
|
||||
config["metadata"].update(
|
||||
{
|
||||
"agent_name": agent_name or "default",
|
||||
"model_name": model_name or "default",
|
||||
"thinking_enabled": thinking_enabled,
|
||||
"reasoning_effort": reasoning_effort,
|
||||
"is_plan_mode": is_plan_mode,
|
||||
"subagent_enabled": subagent_enabled,
|
||||
}
|
||||
)
|
||||
|
||||
if is_bootstrap:
|
||||
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
|
||||
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent],
|
||||
middleware=_build_middlewares(config, model_name=model_name),
|
||||
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
# Default lead agent (unchanged behavior)
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
|
||||
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
|
||||
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
|
||||
system_prompt=apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
|
||||
),
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
@@ -0,0 +1,727 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
|
||||
from deerflow.config.agents_config import load_agent_soul
|
||||
from deerflow.skills import load_skills
|
||||
from deerflow.skills.types import Skill
|
||||
from deerflow.subagents import get_available_subagent_names
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
|
||||
_enabled_skills_lock = threading.Lock()
|
||||
_enabled_skills_cache: list[Skill] | None = None
|
||||
_enabled_skills_refresh_active = False
|
||||
_enabled_skills_refresh_version = 0
|
||||
_enabled_skills_refresh_event = threading.Event()
|
||||
|
||||
|
||||
def _load_enabled_skills_sync() -> list[Skill]:
|
||||
return list(load_skills(enabled_only=True))
|
||||
|
||||
|
||||
def _start_enabled_skills_refresh_thread() -> None:
|
||||
threading.Thread(
|
||||
target=_refresh_enabled_skills_cache_worker,
|
||||
name="deerflow-enabled-skills-loader",
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
|
||||
def _refresh_enabled_skills_cache_worker() -> None:
|
||||
global _enabled_skills_cache, _enabled_skills_refresh_active
|
||||
|
||||
while True:
|
||||
with _enabled_skills_lock:
|
||||
target_version = _enabled_skills_refresh_version
|
||||
|
||||
try:
|
||||
skills = _load_enabled_skills_sync()
|
||||
except Exception:
|
||||
logger.exception("Failed to load enabled skills for prompt injection")
|
||||
skills = []
|
||||
|
||||
with _enabled_skills_lock:
|
||||
if _enabled_skills_refresh_version == target_version:
|
||||
_enabled_skills_cache = skills
|
||||
_enabled_skills_refresh_active = False
|
||||
_enabled_skills_refresh_event.set()
|
||||
return
|
||||
|
||||
# A newer invalidation happened while loading. Keep the worker alive
|
||||
# and loop again so the cache always converges on the latest version.
|
||||
_enabled_skills_cache = None
|
||||
|
||||
|
||||
def _ensure_enabled_skills_cache() -> threading.Event:
|
||||
global _enabled_skills_refresh_active
|
||||
|
||||
with _enabled_skills_lock:
|
||||
if _enabled_skills_cache is not None:
|
||||
_enabled_skills_refresh_event.set()
|
||||
return _enabled_skills_refresh_event
|
||||
if _enabled_skills_refresh_active:
|
||||
return _enabled_skills_refresh_event
|
||||
_enabled_skills_refresh_active = True
|
||||
_enabled_skills_refresh_event.clear()
|
||||
|
||||
_start_enabled_skills_refresh_thread()
|
||||
return _enabled_skills_refresh_event
|
||||
|
||||
|
||||
def _invalidate_enabled_skills_cache() -> threading.Event:
|
||||
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
|
||||
|
||||
_get_cached_skills_prompt_section.cache_clear()
|
||||
with _enabled_skills_lock:
|
||||
_enabled_skills_cache = None
|
||||
_enabled_skills_refresh_version += 1
|
||||
_enabled_skills_refresh_event.clear()
|
||||
if _enabled_skills_refresh_active:
|
||||
return _enabled_skills_refresh_event
|
||||
_enabled_skills_refresh_active = True
|
||||
|
||||
_start_enabled_skills_refresh_thread()
|
||||
return _enabled_skills_refresh_event
|
||||
|
||||
|
||||
def prime_enabled_skills_cache() -> None:
|
||||
_ensure_enabled_skills_cache()
|
||||
|
||||
|
||||
def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS) -> bool:
|
||||
if _ensure_enabled_skills_cache().wait(timeout=timeout_seconds):
|
||||
return True
|
||||
|
||||
logger.warning("Timed out waiting %.1fs for enabled skills cache warm-up", timeout_seconds)
|
||||
return False
|
||||
|
||||
|
||||
def _get_enabled_skills():
|
||||
with _enabled_skills_lock:
|
||||
cached = _enabled_skills_cache
|
||||
|
||||
if cached is not None:
|
||||
return list(cached)
|
||||
|
||||
_ensure_enabled_skills_cache()
|
||||
return []
|
||||
|
||||
|
||||
def _skill_mutability_label(category: str) -> str:
|
||||
return "[custom, editable]" if category == "custom" else "[built-in]"
|
||||
|
||||
|
||||
def clear_skills_system_prompt_cache() -> None:
|
||||
_invalidate_enabled_skills_cache()
|
||||
|
||||
|
||||
async def refresh_skills_system_prompt_cache_async() -> None:
|
||||
await asyncio.to_thread(_invalidate_enabled_skills_cache().wait)
|
||||
|
||||
|
||||
def _reset_skills_system_prompt_cache_state() -> None:
|
||||
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
|
||||
|
||||
_get_cached_skills_prompt_section.cache_clear()
|
||||
with _enabled_skills_lock:
|
||||
_enabled_skills_cache = None
|
||||
_enabled_skills_refresh_active = False
|
||||
_enabled_skills_refresh_version = 0
|
||||
_enabled_skills_refresh_event.clear()
|
||||
|
||||
|
||||
def _refresh_enabled_skills_cache() -> None:
|
||||
"""Backward-compatible test helper for direct synchronous reload."""
|
||||
try:
|
||||
skills = _load_enabled_skills_sync()
|
||||
except Exception:
|
||||
logger.exception("Failed to load enabled skills for prompt injection")
|
||||
skills = []
|
||||
|
||||
with _enabled_skills_lock:
|
||||
_enabled_skills_cache = skills
|
||||
_enabled_skills_refresh_active = False
|
||||
_enabled_skills_refresh_event.set()
|
||||
|
||||
|
||||
def _build_skill_evolution_section(skill_evolution_enabled: bool) -> str:
|
||||
if not skill_evolution_enabled:
|
||||
return ""
|
||||
return """
|
||||
## Skill Self-Evolution
|
||||
After completing a task, consider creating or updating a skill when:
|
||||
- The task required 5+ tool calls to resolve
|
||||
- You overcame non-obvious errors or pitfalls
|
||||
- The user corrected your approach and the corrected version worked
|
||||
- You discovered a non-trivial, recurring workflow
|
||||
If you used a skill and encountered issues not covered by it, patch it immediately.
|
||||
Prefer patch over edit. Before creating a new skill, confirm with the user first.
|
||||
Skip simple one-off tasks.
|
||||
"""
|
||||
|
||||
|
||||
def _build_subagent_section(max_concurrent: int) -> str:
|
||||
"""Build the subagent system prompt section with dynamic concurrency limit.
|
||||
|
||||
Args:
|
||||
max_concurrent: Maximum number of concurrent subagent calls allowed per response.
|
||||
|
||||
Returns:
|
||||
Formatted subagent section string.
|
||||
"""
|
||||
n = max_concurrent
|
||||
bash_available = "bash" in get_available_subagent_names()
|
||||
available_subagents = (
|
||||
"- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.\n- **bash**: For command execution (git, build, test, deploy operations)"
|
||||
if bash_available
|
||||
else "- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.\n"
|
||||
"- **bash**: Not available in the current sandbox configuration. Use direct file/web tools or switch to AioSandboxProvider for isolated shell access."
|
||||
)
|
||||
direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc."
|
||||
direct_execution_example = (
|
||||
'# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()'
|
||||
if bash_available
|
||||
else '# User asks: "Read the README"\n# Thinking: Single straightforward file read\n# → Execute directly\n\nread_file("/mnt/user-data/workspace/README.md") # Direct execution, not task()'
|
||||
)
|
||||
return f"""<subagent_system>
|
||||
**🚀 SUBAGENT MODE ACTIVE - DECOMPOSE, DELEGATE, SYNTHESIZE**
|
||||
|
||||
You are running with subagent capabilities enabled. Your role is to be a **task orchestrator**:
|
||||
1. **DECOMPOSE**: Break complex tasks into parallel sub-tasks
|
||||
2. **DELEGATE**: Launch multiple subagents simultaneously using parallel `task` calls
|
||||
3. **SYNTHESIZE**: Collect and integrate results into a coherent answer
|
||||
|
||||
**CORE PRINCIPLE: Complex tasks should be decomposed and distributed across multiple subagents for parallel execution.**
|
||||
|
||||
**⛔ HARD CONCURRENCY LIMIT: MAXIMUM {n} `task` CALLS PER RESPONSE. THIS IS NOT OPTIONAL.**
|
||||
- Each response, you may include **at most {n}** `task` tool calls. Any excess calls are **silently discarded** by the system — you will lose that work.
|
||||
- **Before launching subagents, you MUST count your sub-tasks in your thinking:**
|
||||
- If count ≤ {n}: Launch all in this response.
|
||||
- If count > {n}: **Pick the {n} most important/foundational sub-tasks for this turn.** Save the rest for the next turn.
|
||||
- **Multi-batch execution** (for >{n} sub-tasks):
|
||||
- Turn 1: Launch sub-tasks 1-{n} in parallel → wait for results
|
||||
- Turn 2: Launch next batch in parallel → wait for results
|
||||
- ... continue until all sub-tasks are complete
|
||||
- Final turn: Synthesize ALL results into a coherent answer
|
||||
- **Example thinking pattern**: "I identified 6 sub-tasks. Since the limit is {n} per turn, I will launch the first {n} now, and the rest in the next turn."
|
||||
|
||||
**Available Subagents:**
|
||||
{available_subagents}
|
||||
|
||||
**Your Orchestration Strategy:**
|
||||
|
||||
✅ **DECOMPOSE + PARALLEL EXECUTION (Preferred Approach):**
|
||||
|
||||
For complex queries, break them down into focused sub-tasks and execute in parallel batches (max {n} per turn):
|
||||
|
||||
**Example 1: "Why is Tencent's stock price declining?" (3 sub-tasks → 1 batch)**
|
||||
→ Turn 1: Launch 3 subagents in parallel:
|
||||
- Subagent 1: Recent financial reports, earnings data, and revenue trends
|
||||
- Subagent 2: Negative news, controversies, and regulatory issues
|
||||
- Subagent 3: Industry trends, competitor performance, and market sentiment
|
||||
→ Turn 2: Synthesize results
|
||||
|
||||
**Example 2: "Compare 5 cloud providers" (5 sub-tasks → multi-batch)**
|
||||
→ Turn 1: Launch {n} subagents in parallel (first batch)
|
||||
→ Turn 2: Launch remaining subagents in parallel
|
||||
→ Final turn: Synthesize ALL results into comprehensive comparison
|
||||
|
||||
**Example 3: "Refactor the authentication system"**
|
||||
→ Turn 1: Launch 3 subagents in parallel:
|
||||
- Subagent 1: Analyze current auth implementation and technical debt
|
||||
- Subagent 2: Research best practices and security patterns
|
||||
- Subagent 3: Review related tests, documentation, and vulnerabilities
|
||||
→ Turn 2: Synthesize results
|
||||
|
||||
✅ **USE Parallel Subagents (max {n} per turn) when:**
|
||||
- **Complex research questions**: Requires multiple information sources or perspectives
|
||||
- **Multi-aspect analysis**: Task has several independent dimensions to explore
|
||||
- **Large codebases**: Need to analyze different parts simultaneously
|
||||
- **Comprehensive investigations**: Questions requiring thorough coverage from multiple angles
|
||||
|
||||
❌ **DO NOT use subagents (execute directly) when:**
|
||||
- **Task cannot be decomposed**: If you can't break it into 2+ meaningful parallel sub-tasks, execute directly
|
||||
- **Ultra-simple actions**: Read one file, quick edits, single commands
|
||||
- **Need immediate clarification**: Must ask user before proceeding
|
||||
- **Meta conversation**: Questions about conversation history
|
||||
- **Sequential dependencies**: Each step depends on previous results (do steps yourself sequentially)
|
||||
|
||||
**CRITICAL WORKFLOW** (STRICTLY follow this before EVERY action):
|
||||
1. **COUNT**: In your thinking, list all sub-tasks and count them explicitly: "I have N sub-tasks"
|
||||
2. **PLAN BATCHES**: If N > {n}, explicitly plan which sub-tasks go in which batch:
|
||||
- "Batch 1 (this turn): first {n} sub-tasks"
|
||||
- "Batch 2 (next turn): next batch of sub-tasks"
|
||||
3. **EXECUTE**: Launch ONLY the current batch (max {n} `task` calls). Do NOT launch sub-tasks from future batches.
|
||||
4. **REPEAT**: After results return, launch the next batch. Continue until all batches complete.
|
||||
5. **SYNTHESIZE**: After ALL batches are done, synthesize all results.
|
||||
6. **Cannot decompose** → Execute directly using available tools ({direct_tool_examples})
|
||||
|
||||
**⛔ VIOLATION: Launching more than {n} `task` calls in a single response is a HARD ERROR. The system WILL discard excess calls and you WILL lose work. Always batch.**
|
||||
|
||||
**Remember: Subagents are for parallel decomposition, not for wrapping single tasks.**
|
||||
|
||||
**How It Works:**
|
||||
- The task tool runs subagents asynchronously in the background
|
||||
- The backend automatically polls for completion (you don't need to poll)
|
||||
- The tool call will block until the subagent completes its work
|
||||
- Once complete, the result is returned to you directly
|
||||
|
||||
**Usage Example 1 - Single Batch (≤{n} sub-tasks):**
|
||||
|
||||
```python
|
||||
# User asks: "Why is Tencent's stock price declining?"
|
||||
# Thinking: 3 sub-tasks → fits in 1 batch
|
||||
|
||||
# Turn 1: Launch 3 subagents in parallel
|
||||
task(description="Tencent financial data", prompt="...", subagent_type="general-purpose")
|
||||
task(description="Tencent news & regulation", prompt="...", subagent_type="general-purpose")
|
||||
task(description="Industry & market trends", prompt="...", subagent_type="general-purpose")
|
||||
# All 3 run in parallel → synthesize results
|
||||
```
|
||||
|
||||
**Usage Example 2 - Multiple Batches (>{n} sub-tasks):**
|
||||
|
||||
```python
|
||||
# User asks: "Compare AWS, Azure, GCP, Alibaba Cloud, and Oracle Cloud"
|
||||
# Thinking: 5 sub-tasks → need multiple batches (max {n} per batch)
|
||||
|
||||
# Turn 1: Launch first batch of {n}
|
||||
task(description="AWS analysis", prompt="...", subagent_type="general-purpose")
|
||||
task(description="Azure analysis", prompt="...", subagent_type="general-purpose")
|
||||
task(description="GCP analysis", prompt="...", subagent_type="general-purpose")
|
||||
|
||||
# Turn 2: Launch remaining batch (after first batch completes)
|
||||
task(description="Alibaba Cloud analysis", prompt="...", subagent_type="general-purpose")
|
||||
task(description="Oracle Cloud analysis", prompt="...", subagent_type="general-purpose")
|
||||
|
||||
# Turn 3: Synthesize ALL results from both batches
|
||||
```
|
||||
|
||||
**Counter-Example - Direct Execution (NO subagents):**
|
||||
|
||||
```python
|
||||
{direct_execution_example}
|
||||
```
|
||||
|
||||
**CRITICAL**:
|
||||
- **Max {n} `task` calls per turn** - the system enforces this, excess calls are discarded
|
||||
- Only use `task` when you can launch 2+ subagents in parallel
|
||||
- Single task = No value from subagents = Execute directly
|
||||
- For >{n} sub-tasks, use sequential batches of {n} across multiple turns
|
||||
</subagent_system>"""
|
||||
|
||||
|
||||
SYSTEM_PROMPT_TEMPLATE = """
|
||||
<role>
|
||||
You are {agent_name}, an open-source super agent.
|
||||
</role>
|
||||
|
||||
{soul}
|
||||
{memory_context}
|
||||
|
||||
<thinking_style>
|
||||
- Think concisely and strategically about the user's request BEFORE taking action
|
||||
- Break down the task: What is clear? What is ambiguous? What is missing?
|
||||
- **PRIORITY CHECK: If anything is unclear, missing, or has multiple interpretations, you MUST ask for clarification FIRST - do NOT proceed with work**
|
||||
{subagent_thinking}- Never write down your full final answer or report in thinking process, but only outline
|
||||
- CRITICAL: After thinking, you MUST provide your actual response to the user. Thinking is for planning, the response is for delivery.
|
||||
- Your response must contain the actual answer, not just a reference to what you thought about
|
||||
</thinking_style>
|
||||
|
||||
<clarification_system>
|
||||
**WORKFLOW PRIORITY: CLARIFY → PLAN → ACT**
|
||||
1. **FIRST**: Analyze the request in your thinking - identify what's unclear, missing, or ambiguous
|
||||
2. **SECOND**: If clarification is needed, call `ask_clarification` tool IMMEDIATELY - do NOT start working
|
||||
3. **THIRD**: Only after all clarifications are resolved, proceed with planning and execution
|
||||
|
||||
**CRITICAL RULE: Clarification ALWAYS comes BEFORE action. Never start working and clarify mid-execution.**
|
||||
|
||||
**MANDATORY Clarification Scenarios - You MUST call ask_clarification BEFORE starting work when:**
|
||||
|
||||
1. **Missing Information** (`missing_info`): Required details not provided
|
||||
- Example: User says "create a web scraper" but doesn't specify the target website
|
||||
- Example: "Deploy the app" without specifying environment
|
||||
- **REQUIRED ACTION**: Call ask_clarification to get the missing information
|
||||
|
||||
2. **Ambiguous Requirements** (`ambiguous_requirement`): Multiple valid interpretations exist
|
||||
- Example: "Optimize the code" could mean performance, readability, or memory usage
|
||||
- Example: "Make it better" is unclear what aspect to improve
|
||||
- **REQUIRED ACTION**: Call ask_clarification to clarify the exact requirement
|
||||
|
||||
3. **Approach Choices** (`approach_choice`): Several valid approaches exist
|
||||
- Example: "Add authentication" could use JWT, OAuth, session-based, or API keys
|
||||
- Example: "Store data" could use database, files, cache, etc.
|
||||
- **REQUIRED ACTION**: Call ask_clarification to let user choose the approach
|
||||
|
||||
4. **Risky Operations** (`risk_confirmation`): Destructive actions need confirmation
|
||||
- Example: Deleting files, modifying production configs, database operations
|
||||
- Example: Overwriting existing code or data
|
||||
- **REQUIRED ACTION**: Call ask_clarification to get explicit confirmation
|
||||
|
||||
5. **Suggestions** (`suggestion`): You have a recommendation but want approval
|
||||
- Example: "I recommend refactoring this code. Should I proceed?"
|
||||
- **REQUIRED ACTION**: Call ask_clarification to get approval
|
||||
|
||||
**STRICT ENFORCEMENT:**
|
||||
- ❌ DO NOT start working and then ask for clarification mid-execution - clarify FIRST
|
||||
- ❌ DO NOT skip clarification for "efficiency" - accuracy matters more than speed
|
||||
- ❌ DO NOT make assumptions when information is missing - ALWAYS ask
|
||||
- ❌ DO NOT proceed with guesses - STOP and call ask_clarification first
|
||||
- ✅ Analyze the request in thinking → Identify unclear aspects → Ask BEFORE any action
|
||||
- ✅ If you identify the need for clarification in your thinking, you MUST call the tool IMMEDIATELY
|
||||
- ✅ After calling ask_clarification, execution will be interrupted automatically
|
||||
- ✅ Wait for user response - do NOT continue with assumptions
|
||||
|
||||
**How to Use:**
|
||||
```python
|
||||
ask_clarification(
|
||||
question="Your specific question here?",
|
||||
clarification_type="missing_info", # or other type
|
||||
context="Why you need this information", # optional but recommended
|
||||
options=["option1", "option2"] # optional, for choices
|
||||
)
|
||||
```
|
||||
|
||||
**Example:**
|
||||
User: "Deploy the application"
|
||||
You (thinking): Missing environment info - I MUST ask for clarification
|
||||
You (action): ask_clarification(
|
||||
question="Which environment should I deploy to?",
|
||||
clarification_type="approach_choice",
|
||||
context="I need to know the target environment for proper configuration",
|
||||
options=["development", "staging", "production"]
|
||||
)
|
||||
[Execution stops - wait for user response]
|
||||
|
||||
User: "staging"
|
||||
You: "Deploying to staging..." [proceed]
|
||||
</clarification_system>
|
||||
|
||||
{skills_section}
|
||||
|
||||
{deferred_tools_section}
|
||||
|
||||
{subagent_section}
|
||||
|
||||
<working_directory existed="true">
|
||||
- User uploads: `/mnt/user-data/uploads` - Files uploaded by the user (automatically listed in context)
|
||||
- User workspace: `/mnt/user-data/workspace` - Working directory for temporary files
|
||||
- Output files: `/mnt/user-data/outputs` - Final deliverables must be saved here
|
||||
|
||||
**File Management:**
|
||||
- Uploaded files are automatically listed in the <uploaded_files> section before each request
|
||||
- Use `read_file` tool to read uploaded files using their paths from the list
|
||||
- For PDF, PPT, Excel, and Word files, converted Markdown versions (*.md) are available alongside originals
|
||||
- All temporary work happens in `/mnt/user-data/workspace`
|
||||
- Treat `/mnt/user-data/workspace` as your default current working directory for coding and file-editing tasks
|
||||
- When writing scripts or commands that create/read files from the workspace, prefer relative paths such as `hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`
|
||||
- Avoid hardcoding `/mnt/user-data/...` inside generated scripts when a relative path from the workspace is enough
|
||||
- Final deliverables must be copied to `/mnt/user-data/outputs` and presented using `present_file` tool
|
||||
{acp_section}
|
||||
</working_directory>
|
||||
|
||||
<response_style>
|
||||
- Clear and Concise: Avoid over-formatting unless requested
|
||||
- Natural Tone: Use paragraphs and prose, not bullet points by default
|
||||
- Action-Oriented: Focus on delivering results, not explaining processes
|
||||
</response_style>
|
||||
|
||||
<citations>
|
||||
**CRITICAL: Always include citations when using web search results**
|
||||
|
||||
- **When to Use**: MANDATORY after web_search, web_fetch, or any external information source
|
||||
- **Format**: Use Markdown link format `[citation:TITLE](URL)` immediately after the claim
|
||||
- **Placement**: Inline citations should appear right after the sentence or claim they support
|
||||
- **Sources Section**: Also collect all citations in a "Sources" section at the end of reports
|
||||
|
||||
**Example - Inline Citations:**
|
||||
```markdown
|
||||
The key AI trends for 2026 include enhanced reasoning capabilities and multimodal integration
|
||||
[citation:AI Trends 2026](https://techcrunch.com/ai-trends).
|
||||
Recent breakthroughs in language models have also accelerated progress
|
||||
[citation:OpenAI Research](https://openai.com/research).
|
||||
```
|
||||
|
||||
**Example - Deep Research Report with Citations:**
|
||||
```markdown
|
||||
## Executive Summary
|
||||
|
||||
DeerFlow is an open-source AI agent framework that gained significant traction in early 2026
|
||||
[citation:GitHub Repository](https://github.com/bytedance/deer-flow). The project focuses on
|
||||
providing a production-ready agent system with sandbox execution and memory management
|
||||
[citation:DeerFlow Documentation](https://deer-flow.dev/docs).
|
||||
|
||||
## Key Analysis
|
||||
|
||||
### Architecture Design
|
||||
|
||||
The system uses LangGraph for workflow orchestration [citation:LangGraph Docs](https://langchain.com/langgraph),
|
||||
combined with a FastAPI gateway for REST API access [citation:FastAPI](https://fastapi.tiangolo.com).
|
||||
|
||||
## Sources
|
||||
|
||||
### Primary Sources
|
||||
- [GitHub Repository](https://github.com/bytedance/deer-flow) - Official source code and documentation
|
||||
- [DeerFlow Documentation](https://deer-flow.dev/docs) - Technical specifications
|
||||
|
||||
### Media Coverage
|
||||
- [AI Trends 2026](https://techcrunch.com/ai-trends) - Industry analysis
|
||||
```
|
||||
|
||||
**CRITICAL: Sources section format:**
|
||||
- Every item in the Sources section MUST be a clickable markdown link with URL
|
||||
- Use standard markdown link `[Title](URL) - Description` format (NOT `[citation:...]` format)
|
||||
- The `[citation:Title](URL)` format is ONLY for inline citations within the report body
|
||||
- ❌ WRONG: `GitHub 仓库 - 官方源代码和文档` (no URL!)
|
||||
- ❌ WRONG in Sources: `[citation:GitHub Repository](url)` (citation prefix is for inline only!)
|
||||
- ✅ RIGHT in Sources: `[GitHub Repository](https://github.com/bytedance/deer-flow) - 官方源代码和文档`
|
||||
|
||||
**WORKFLOW for Research Tasks:**
|
||||
1. Use web_search to find sources → Extract {{title, url, snippet}} from results
|
||||
2. Write content with inline citations: `claim [citation:Title](url)`
|
||||
3. Collect all citations in a "Sources" section at the end
|
||||
4. NEVER write claims without citations when sources are available
|
||||
|
||||
**CRITICAL RULES:**
|
||||
- ❌ DO NOT write research content without citations
|
||||
- ❌ DO NOT forget to extract URLs from search results
|
||||
- ✅ ALWAYS add `[citation:Title](URL)` after claims from external sources
|
||||
- ✅ ALWAYS include a "Sources" section listing all references
|
||||
</citations>
|
||||
|
||||
<critical_reminders>
|
||||
- **Clarification First**: ALWAYS clarify unclear/missing/ambiguous requirements BEFORE starting work - never assume or guess
|
||||
{subagent_reminder}- Skill First: Always load the relevant skill before starting **complex** tasks.
|
||||
- Progressive Loading: Load resources incrementally as referenced in skills
|
||||
- Output Files: Final deliverables must be in `/mnt/user-data/outputs`
|
||||
- Clarity: Be direct and helpful, avoid unnecessary meta-commentary
|
||||
- Including Images and Mermaid: Images and Mermaid diagrams are always welcomed in the Markdown format, and you're encouraged to use `\n\n` or "```mermaid" to display images in response or Markdown files
|
||||
- Multi-task: Better utilize parallel tool calling to call multiple tools at one time for better performance
|
||||
- Language Consistency: Keep using the same language as user's
|
||||
- Always Respond: Your thinking is internal. You MUST always provide a visible response to the user after thinking.
|
||||
</critical_reminders>
|
||||
"""
|
||||
|
||||
|
||||
def _get_memory_context(agent_name: str | None = None) -> str:
|
||||
"""Get memory context for injection into system prompt.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, loads per-agent memory. If None, loads global memory.
|
||||
|
||||
Returns:
|
||||
Formatted memory context string wrapped in XML tags, or empty string if disabled.
|
||||
"""
|
||||
try:
|
||||
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
|
||||
config = get_memory_config()
|
||||
if not config.enabled or not config.injection_enabled:
|
||||
return ""
|
||||
|
||||
memory_data = get_memory_data(agent_name)
|
||||
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
||||
|
||||
if not memory_content.strip():
|
||||
return ""
|
||||
|
||||
return f"""<memory>
|
||||
{memory_content}
|
||||
</memory>
|
||||
"""
|
||||
except Exception as e:
|
||||
logger.error("Failed to load memory context: %s", e)
|
||||
return ""
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _get_cached_skills_prompt_section(
|
||||
skill_signature: tuple[tuple[str, str, str, str], ...],
|
||||
available_skills_key: tuple[str, ...] | None,
|
||||
container_base_path: str,
|
||||
skill_evolution_section: str,
|
||||
) -> str:
|
||||
filtered = [(name, description, category, location) for name, description, category, location in skill_signature if available_skills_key is None or name in available_skills_key]
|
||||
skills_list = ""
|
||||
if filtered:
|
||||
skill_items = "\n".join(
|
||||
f" <skill>\n <name>{name}</name>\n <description>{description} {_skill_mutability_label(category)}</description>\n <location>{location}</location>\n </skill>"
|
||||
for name, description, category, location in filtered
|
||||
)
|
||||
skills_list = f"<available_skills>\n{skill_items}\n</available_skills>"
|
||||
return f"""<skill_system>
|
||||
You have access to skills that provide optimized workflows for specific tasks. Each skill contains best practices, frameworks, and references to additional resources.
|
||||
|
||||
**Progressive Loading Pattern:**
|
||||
1. When a user query matches a skill's use case, immediately call `read_file` on the skill's main file using the path attribute provided in the skill tag below
|
||||
2. Read and understand the skill's workflow and instructions
|
||||
3. The skill file contains references to external resources under the same folder
|
||||
4. Load referenced resources only when needed during execution
|
||||
5. Follow the skill's instructions precisely
|
||||
|
||||
**Skills are located at:** {container_base_path}
|
||||
{skill_evolution_section}
|
||||
{skills_list}
|
||||
|
||||
</skill_system>"""
|
||||
|
||||
|
||||
def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
|
||||
"""Generate the skills prompt section with available skills list."""
|
||||
skills = _get_enabled_skills()
|
||||
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
config = get_app_config()
|
||||
container_base_path = config.skills.container_path
|
||||
skill_evolution_enabled = config.skill_evolution.enabled
|
||||
except Exception:
|
||||
container_base_path = "/mnt/skills"
|
||||
skill_evolution_enabled = False
|
||||
|
||||
if not skills and not skill_evolution_enabled:
|
||||
return ""
|
||||
|
||||
if available_skills is not None and not any(skill.name in available_skills for skill in skills):
|
||||
return ""
|
||||
|
||||
skill_signature = tuple((skill.name, skill.description, skill.category, skill.get_container_file_path(container_base_path)) for skill in skills)
|
||||
available_key = tuple(sorted(available_skills)) if available_skills is not None else None
|
||||
if not skill_signature and available_key is not None:
|
||||
return ""
|
||||
skill_evolution_section = _build_skill_evolution_section(skill_evolution_enabled)
|
||||
return _get_cached_skills_prompt_section(skill_signature, available_key, container_base_path, skill_evolution_section)
|
||||
|
||||
|
||||
def get_agent_soul(agent_name: str | None) -> str:
|
||||
# Append SOUL.md (agent personality) if present
|
||||
soul = load_agent_soul(agent_name)
|
||||
if soul:
|
||||
return f"<soul>\n{soul}\n</soul>\n" if soul else ""
|
||||
return ""
|
||||
|
||||
|
||||
def get_deferred_tools_prompt_section() -> str:
|
||||
"""Generate <available-deferred-tools> block for the system prompt.
|
||||
|
||||
Lists only deferred tool names so the agent knows what exists
|
||||
and can use tool_search to load them.
|
||||
Returns empty string when tool_search is disabled or no tools are deferred.
|
||||
"""
|
||||
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
||||
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
if not get_app_config().tool_search.enabled:
|
||||
return ""
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
registry = get_deferred_registry()
|
||||
if not registry:
|
||||
return ""
|
||||
|
||||
names = "\n".join(e.name for e in registry.entries)
|
||||
return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
|
||||
|
||||
|
||||
def _build_acp_section() -> str:
|
||||
"""Build the ACP agent prompt section, only if ACP agents are configured."""
|
||||
try:
|
||||
from deerflow.config.acp_config import get_acp_agents
|
||||
|
||||
agents = get_acp_agents()
|
||||
if not agents:
|
||||
return ""
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
return (
|
||||
"\n**ACP Agent Tasks (invoke_acp_agent):**\n"
|
||||
"- ACP agents (e.g. codex, claude_code) run in their own independent workspace — NOT in `/mnt/user-data/`\n"
|
||||
"- When writing prompts for ACP agents, describe the task only — do NOT reference `/mnt/user-data` paths\n"
|
||||
"- ACP agent results are accessible at `/mnt/acp-workspace/` (read-only) — use `ls`, `read_file`, or `bash cp` to retrieve output files\n"
|
||||
"- To deliver ACP output to the user: copy from `/mnt/acp-workspace/<file>` to `/mnt/user-data/outputs/<file>`, then use `present_file`"
|
||||
)
|
||||
|
||||
|
||||
def _build_custom_mounts_section() -> str:
|
||||
"""Build a prompt section for explicitly configured sandbox mounts."""
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
mounts = get_app_config().sandbox.mounts or []
|
||||
except Exception:
|
||||
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
|
||||
return ""
|
||||
|
||||
if not mounts:
|
||||
return ""
|
||||
|
||||
lines = []
|
||||
for mount in mounts:
|
||||
access = "read-only" if mount.read_only else "read-write"
|
||||
lines.append(f"- Custom mount: `{mount.container_path}` - Host directory mapped into the sandbox ({access})")
|
||||
|
||||
mounts_list = "\n".join(lines)
|
||||
return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"
|
||||
|
||||
|
||||
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
|
||||
# Get memory context
|
||||
memory_context = _get_memory_context(agent_name)
|
||||
|
||||
# Include subagent section only if enabled (from runtime parameter)
|
||||
n = max_concurrent_subagents
|
||||
subagent_section = _build_subagent_section(n) if subagent_enabled else ""
|
||||
|
||||
# Add subagent reminder to critical_reminders if enabled
|
||||
subagent_reminder = (
|
||||
"- **Orchestrator Mode**: You are a task orchestrator - decompose complex tasks into parallel sub-tasks. "
|
||||
f"**HARD LIMIT: max {n} `task` calls per response.** "
|
||||
f"If >{n} sub-tasks, split into sequential batches of ≤{n}. Synthesize after ALL batches complete.\n"
|
||||
if subagent_enabled
|
||||
else ""
|
||||
)
|
||||
|
||||
# Add subagent thinking guidance if enabled
|
||||
subagent_thinking = (
|
||||
"- **DECOMPOSITION CHECK: Can this task be broken into 2+ parallel sub-tasks? If YES, COUNT them. "
|
||||
f"If count > {n}, you MUST plan batches of ≤{n} and only launch the FIRST batch now. "
|
||||
f"NEVER launch more than {n} `task` calls in one response.**\n"
|
||||
if subagent_enabled
|
||||
else ""
|
||||
)
|
||||
|
||||
# Get skills section
|
||||
skills_section = get_skills_prompt_section(available_skills)
|
||||
|
||||
# Get deferred tools section (tool_search)
|
||||
deferred_tools_section = get_deferred_tools_prompt_section()
|
||||
|
||||
# Build ACP agent section only if ACP agents are configured
|
||||
acp_section = _build_acp_section()
|
||||
custom_mounts_section = _build_custom_mounts_section()
|
||||
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
|
||||
|
||||
# Format the prompt with dynamic skills and memory
|
||||
prompt = SYSTEM_PROMPT_TEMPLATE.format(
|
||||
agent_name=agent_name or "DeerFlow 2.0",
|
||||
soul=get_agent_soul(agent_name),
|
||||
skills_section=skills_section,
|
||||
deferred_tools_section=deferred_tools_section,
|
||||
memory_context=memory_context,
|
||||
subagent_section=subagent_section,
|
||||
subagent_reminder=subagent_reminder,
|
||||
subagent_thinking=subagent_thinking,
|
||||
acp_section=acp_and_mounts_section,
|
||||
)
|
||||
|
||||
return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Memory module for DeerFlow.
|
||||
|
||||
This module provides a global memory mechanism that:
|
||||
- Stores user context and conversation history in memory.json
|
||||
- Uses LLM to summarize and extract facts from conversations
|
||||
- Injects relevant memory into system prompts for personalized responses
|
||||
"""
|
||||
|
||||
from deerflow.agents.memory.prompt import (
|
||||
FACT_EXTRACTION_PROMPT,
|
||||
MEMORY_UPDATE_PROMPT,
|
||||
format_conversation_for_update,
|
||||
format_memory_for_injection,
|
||||
)
|
||||
from deerflow.agents.memory.queue import (
|
||||
ConversationContext,
|
||||
MemoryUpdateQueue,
|
||||
get_memory_queue,
|
||||
reset_memory_queue,
|
||||
)
|
||||
from deerflow.agents.memory.storage import (
|
||||
FileMemoryStorage,
|
||||
MemoryStorage,
|
||||
get_memory_storage,
|
||||
)
|
||||
from deerflow.agents.memory.updater import (
|
||||
MemoryUpdater,
|
||||
clear_memory_data,
|
||||
delete_memory_fact,
|
||||
get_memory_data,
|
||||
reload_memory_data,
|
||||
update_memory_from_conversation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Prompt utilities
|
||||
"MEMORY_UPDATE_PROMPT",
|
||||
"FACT_EXTRACTION_PROMPT",
|
||||
"format_memory_for_injection",
|
||||
"format_conversation_for_update",
|
||||
# Queue
|
||||
"ConversationContext",
|
||||
"MemoryUpdateQueue",
|
||||
"get_memory_queue",
|
||||
"reset_memory_queue",
|
||||
# Storage
|
||||
"MemoryStorage",
|
||||
"FileMemoryStorage",
|
||||
"get_memory_storage",
|
||||
# Updater
|
||||
"MemoryUpdater",
|
||||
"clear_memory_data",
|
||||
"delete_memory_fact",
|
||||
"get_memory_data",
|
||||
"reload_memory_data",
|
||||
"update_memory_from_conversation",
|
||||
]
|
||||
@@ -0,0 +1,363 @@
|
||||
"""Prompt templates for memory update and injection."""
|
||||
|
||||
import math
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
import tiktoken
|
||||
|
||||
TIKTOKEN_AVAILABLE = True
|
||||
except ImportError:
|
||||
TIKTOKEN_AVAILABLE = False
|
||||
|
||||
# Prompt template for updating memory based on conversation
|
||||
MEMORY_UPDATE_PROMPT = """You are a memory management system. Your task is to analyze a conversation and update the user's memory profile.
|
||||
|
||||
Current Memory State:
|
||||
<current_memory>
|
||||
{current_memory}
|
||||
</current_memory>
|
||||
|
||||
New Conversation to Process:
|
||||
<conversation>
|
||||
{conversation}
|
||||
</conversation>
|
||||
|
||||
Instructions:
|
||||
1. Analyze the conversation for important information about the user
|
||||
2. Extract relevant facts, preferences, and context with specific details (numbers, names, technologies)
|
||||
3. Update the memory sections as needed following the detailed length guidelines below
|
||||
|
||||
Before extracting facts, perform a structured reflection on the conversation:
|
||||
1. Error/Retry Detection: Did the agent encounter errors, require retries, or produce incorrect results?
|
||||
If yes, record the root cause and correct approach as a high-confidence fact with category "correction".
|
||||
2. User Correction Detection: Did the user correct the agent's direction, understanding, or output?
|
||||
If yes, record the correct interpretation or approach as a high-confidence fact with category "correction".
|
||||
Include what went wrong in "sourceError" only when category is "correction" and the mistake is explicit in the conversation.
|
||||
3. Project Constraint Discovery: Were any project-specific constraints discovered during the conversation?
|
||||
If yes, record them as facts with the most appropriate category and confidence.
|
||||
|
||||
{correction_hint}
|
||||
|
||||
Memory Section Guidelines:
|
||||
|
||||
**User Context** (Current state - concise summaries):
|
||||
- workContext: Professional role, company, key projects, main technologies (2-3 sentences)
|
||||
Example: Core contributor, project names with metrics (16k+ stars), technical stack
|
||||
- personalContext: Languages, communication preferences, key interests (1-2 sentences)
|
||||
Example: Bilingual capabilities, specific interest areas, expertise domains
|
||||
- topOfMind: Multiple ongoing focus areas and priorities (3-5 sentences, detailed paragraph)
|
||||
Example: Primary project work, parallel technical investigations, ongoing learning/tracking
|
||||
Include: Active implementation work, troubleshooting issues, market/research interests
|
||||
Note: This captures SEVERAL concurrent focus areas, not just one task
|
||||
|
||||
**History** (Temporal context - rich paragraphs):
|
||||
- recentMonths: Detailed summary of recent activities (4-6 sentences or 1-2 paragraphs)
|
||||
Timeline: Last 1-3 months of interactions
|
||||
Include: Technologies explored, projects worked on, problems solved, interests demonstrated
|
||||
- earlierContext: Important historical patterns (3-5 sentences or 1 paragraph)
|
||||
Timeline: 3-12 months ago
|
||||
Include: Past projects, learning journeys, established patterns
|
||||
- longTermBackground: Persistent background and foundational context (2-4 sentences)
|
||||
Timeline: Overall/foundational information
|
||||
Include: Core expertise, longstanding interests, fundamental working style
|
||||
|
||||
**Facts Extraction**:
|
||||
- Extract specific, quantifiable details (e.g., "16k+ GitHub stars", "200+ datasets")
|
||||
- Include proper nouns (company names, project names, technology names)
|
||||
- Preserve technical terminology and version numbers
|
||||
- Categories:
|
||||
* preference: Tools, styles, approaches user prefers/dislikes
|
||||
* knowledge: Specific expertise, technologies mastered, domain knowledge
|
||||
* context: Background facts (job title, projects, locations, languages)
|
||||
* behavior: Working patterns, communication habits, problem-solving approaches
|
||||
* goal: Stated objectives, learning targets, project ambitions
|
||||
* correction: Explicit agent mistakes or user corrections, including the correct approach
|
||||
- Confidence levels:
|
||||
* 0.9-1.0: Explicitly stated facts ("I work on X", "My role is Y")
|
||||
* 0.7-0.8: Strongly implied from actions/discussions
|
||||
* 0.5-0.6: Inferred patterns (use sparingly, only for clear patterns)
|
||||
|
||||
**What Goes Where**:
|
||||
- workContext: Current job, active projects, primary tech stack
|
||||
- personalContext: Languages, personality, interests outside direct work tasks
|
||||
- topOfMind: Multiple ongoing priorities and focus areas user cares about recently (gets updated most frequently)
|
||||
Should capture 3-5 concurrent themes: main work, side explorations, learning/tracking interests
|
||||
- recentMonths: Detailed account of recent technical explorations and work
|
||||
- earlierContext: Patterns from slightly older interactions still relevant
|
||||
- longTermBackground: Unchanging foundational facts about the user
|
||||
|
||||
**Multilingual Content**:
|
||||
- Preserve original language for proper nouns and company names
|
||||
- Keep technical terms in their original form (DeepSeek, LangGraph, etc.)
|
||||
- Note language capabilities in personalContext
|
||||
|
||||
Output Format (JSON):
|
||||
{{
|
||||
"user": {{
|
||||
"workContext": {{ "summary": "...", "shouldUpdate": true/false }},
|
||||
"personalContext": {{ "summary": "...", "shouldUpdate": true/false }},
|
||||
"topOfMind": {{ "summary": "...", "shouldUpdate": true/false }}
|
||||
}},
|
||||
"history": {{
|
||||
"recentMonths": {{ "summary": "...", "shouldUpdate": true/false }},
|
||||
"earlierContext": {{ "summary": "...", "shouldUpdate": true/false }},
|
||||
"longTermBackground": {{ "summary": "...", "shouldUpdate": true/false }}
|
||||
}},
|
||||
"newFacts": [
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
|
||||
],
|
||||
"factsToRemove": ["fact_id_1", "fact_id_2"]
|
||||
}}
|
||||
|
||||
Important Rules:
|
||||
- Only set shouldUpdate=true if there's meaningful new information
|
||||
- Follow length guidelines: workContext/personalContext are concise (1-3 sentences), topOfMind and history sections are detailed (paragraphs)
|
||||
- Include specific metrics, version numbers, and proper nouns in facts
|
||||
- Only add facts that are clearly stated (0.9+) or strongly implied (0.7+)
|
||||
- Use category "correction" for explicit agent mistakes or user corrections; assign confidence >= 0.95 when the correction is explicit
|
||||
- Include "sourceError" only for explicit correction facts when the prior mistake or wrong approach is clearly stated; omit it otherwise
|
||||
- Remove facts that are contradicted by new information
|
||||
- When updating topOfMind, integrate new focus areas while removing completed/abandoned ones
|
||||
Keep 3-5 concurrent focus themes that are still active and relevant
|
||||
- For history sections, integrate new information chronologically into appropriate time period
|
||||
- Preserve technical accuracy - keep exact names of technologies, companies, projects
|
||||
- Focus on information useful for future interactions and personalization
|
||||
- IMPORTANT: Do NOT record file upload events in memory. Uploaded files are
|
||||
session-specific and ephemeral — they will not be accessible in future sessions.
|
||||
Recording upload events causes confusion in subsequent conversations.
|
||||
|
||||
Return ONLY valid JSON, no explanation or markdown."""
|
||||
|
||||
|
||||
# Prompt template for extracting facts from a single message
|
||||
FACT_EXTRACTION_PROMPT = """Extract factual information about the user from this message.
|
||||
|
||||
Message:
|
||||
{message}
|
||||
|
||||
Extract facts in this JSON format:
|
||||
{{
|
||||
"facts": [
|
||||
{{ "content": "...", "category": "preference|knowledge|context|behavior|goal|correction", "confidence": 0.0-1.0 }}
|
||||
]
|
||||
}}
|
||||
|
||||
Categories:
|
||||
- preference: User preferences (likes/dislikes, styles, tools)
|
||||
- knowledge: User's expertise or knowledge areas
|
||||
- context: Background context (location, job, projects)
|
||||
- behavior: Behavioral patterns
|
||||
- goal: User's goals or objectives
|
||||
- correction: Explicit corrections or mistakes to avoid repeating
|
||||
|
||||
Rules:
|
||||
- Only extract clear, specific facts
|
||||
- Confidence should reflect certainty (explicit statement = 0.9+, implied = 0.6-0.8)
|
||||
- Skip vague or temporary information
|
||||
|
||||
Return ONLY valid JSON."""
|
||||
|
||||
|
||||
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
|
||||
"""Count tokens in text using tiktoken.
|
||||
|
||||
Args:
|
||||
text: The text to count tokens for.
|
||||
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
|
||||
|
||||
Returns:
|
||||
The number of tokens in the text.
|
||||
"""
|
||||
if not TIKTOKEN_AVAILABLE:
|
||||
# Fallback to character-based estimation if tiktoken is not available
|
||||
return len(text) // 4
|
||||
|
||||
try:
|
||||
encoding = tiktoken.get_encoding(encoding_name)
|
||||
return len(encoding.encode(text))
|
||||
except Exception:
|
||||
# Fallback to character-based estimation on error
|
||||
return len(text) // 4
|
||||
|
||||
|
||||
def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
||||
"""Coerce a confidence-like value to a bounded float in [0, 1].
|
||||
|
||||
Non-finite values (NaN, inf, -inf) are treated as invalid and fall back
|
||||
to the default before clamping, preventing them from dominating ranking.
|
||||
The ``default`` parameter is assumed to be a finite value.
|
||||
"""
|
||||
try:
|
||||
confidence = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return max(0.0, min(1.0, default))
|
||||
if not math.isfinite(confidence):
|
||||
return max(0.0, min(1.0, default))
|
||||
return max(0.0, min(1.0, confidence))
|
||||
|
||||
|
||||
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
|
||||
"""Format memory data for injection into system prompt.
|
||||
|
||||
Args:
|
||||
memory_data: The memory data dictionary.
|
||||
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
|
||||
|
||||
Returns:
|
||||
Formatted memory string for system prompt injection.
|
||||
"""
|
||||
if not memory_data:
|
||||
return ""
|
||||
|
||||
sections = []
|
||||
|
||||
# Format user context
|
||||
user_data = memory_data.get("user", {})
|
||||
if user_data:
|
||||
user_sections = []
|
||||
|
||||
work_ctx = user_data.get("workContext", {})
|
||||
if work_ctx.get("summary"):
|
||||
user_sections.append(f"Work: {work_ctx['summary']}")
|
||||
|
||||
personal_ctx = user_data.get("personalContext", {})
|
||||
if personal_ctx.get("summary"):
|
||||
user_sections.append(f"Personal: {personal_ctx['summary']}")
|
||||
|
||||
top_of_mind = user_data.get("topOfMind", {})
|
||||
if top_of_mind.get("summary"):
|
||||
user_sections.append(f"Current Focus: {top_of_mind['summary']}")
|
||||
|
||||
if user_sections:
|
||||
sections.append("User Context:\n" + "\n".join(f"- {s}" for s in user_sections))
|
||||
|
||||
# Format history
|
||||
history_data = memory_data.get("history", {})
|
||||
if history_data:
|
||||
history_sections = []
|
||||
|
||||
recent = history_data.get("recentMonths", {})
|
||||
if recent.get("summary"):
|
||||
history_sections.append(f"Recent: {recent['summary']}")
|
||||
|
||||
earlier = history_data.get("earlierContext", {})
|
||||
if earlier.get("summary"):
|
||||
history_sections.append(f"Earlier: {earlier['summary']}")
|
||||
|
||||
background = history_data.get("longTermBackground", {})
|
||||
if background.get("summary"):
|
||||
history_sections.append(f"Background: {background['summary']}")
|
||||
|
||||
if history_sections:
|
||||
sections.append("History:\n" + "\n".join(f"- {s}" for s in history_sections))
|
||||
|
||||
# Format facts (sorted by confidence; include as many as token budget allows)
|
||||
facts_data = memory_data.get("facts", [])
|
||||
if isinstance(facts_data, list) and facts_data:
|
||||
ranked_facts = sorted(
|
||||
(f for f in facts_data if isinstance(f, dict) and isinstance(f.get("content"), str) and f.get("content").strip()),
|
||||
key=lambda fact: _coerce_confidence(fact.get("confidence"), default=0.0),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
# Compute token count for existing sections once, then account
|
||||
# incrementally for each fact line to avoid full-string re-tokenization.
|
||||
base_text = "\n\n".join(sections)
|
||||
base_tokens = _count_tokens(base_text) if base_text else 0
|
||||
# Account for the separator between existing sections and the facts section.
|
||||
facts_header = "Facts:\n"
|
||||
separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header)
|
||||
running_tokens = base_tokens + separator_tokens
|
||||
|
||||
fact_lines: list[str] = []
|
||||
for fact in ranked_facts:
|
||||
content_value = fact.get("content")
|
||||
if not isinstance(content_value, str):
|
||||
continue
|
||||
content = content_value.strip()
|
||||
if not content:
|
||||
continue
|
||||
category = str(fact.get("category", "context")).strip() or "context"
|
||||
confidence = _coerce_confidence(fact.get("confidence"), default=0.0)
|
||||
source_error = fact.get("sourceError")
|
||||
if category == "correction" and isinstance(source_error, str) and source_error.strip():
|
||||
line = f"- [{category} | {confidence:.2f}] {content} (avoid: {source_error.strip()})"
|
||||
else:
|
||||
line = f"- [{category} | {confidence:.2f}] {content}"
|
||||
|
||||
# Each additional line is preceded by a newline (except the first).
|
||||
line_text = ("\n" + line) if fact_lines else line
|
||||
line_tokens = _count_tokens(line_text)
|
||||
|
||||
if running_tokens + line_tokens <= max_tokens:
|
||||
fact_lines.append(line)
|
||||
running_tokens += line_tokens
|
||||
else:
|
||||
break
|
||||
|
||||
if fact_lines:
|
||||
sections.append("Facts:\n" + "\n".join(fact_lines))
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
result = "\n\n".join(sections)
|
||||
|
||||
# Use accurate token counting with tiktoken
|
||||
token_count = _count_tokens(result)
|
||||
if token_count > max_tokens:
|
||||
# Truncate to fit within token limit
|
||||
# Estimate characters to remove based on token ratio
|
||||
char_per_token = len(result) / token_count
|
||||
target_chars = int(max_tokens * char_per_token * 0.95) # 95% to leave margin
|
||||
result = result[:target_chars] + "\n..."
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def format_conversation_for_update(messages: list[Any]) -> str:
|
||||
"""Format conversation messages for memory update prompt.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
|
||||
Returns:
|
||||
Formatted conversation string.
|
||||
"""
|
||||
lines = []
|
||||
for msg in messages:
|
||||
role = getattr(msg, "type", "unknown")
|
||||
content = getattr(msg, "content", str(msg))
|
||||
|
||||
# Handle content that might be a list (multimodal)
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for p in content:
|
||||
if isinstance(p, str):
|
||||
text_parts.append(p)
|
||||
elif isinstance(p, dict):
|
||||
text_val = p.get("text")
|
||||
if isinstance(text_val, str):
|
||||
text_parts.append(text_val)
|
||||
content = " ".join(text_parts) if text_parts else str(content)
|
||||
|
||||
# Strip uploaded_files tags from human messages to avoid persisting
|
||||
# ephemeral file path info into long-term memory. Skip the turn entirely
|
||||
# when nothing remains after stripping (upload-only message).
|
||||
if role == "human":
|
||||
content = re.sub(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", "", str(content)).strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# Truncate very long messages
|
||||
if len(str(content)) > 1000:
|
||||
content = str(content)[:1000] + "..."
|
||||
|
||||
if role == "human":
|
||||
lines.append(f"User: {content}")
|
||||
elif role == "ai":
|
||||
lines.append(f"Assistant: {content}")
|
||||
|
||||
return "\n\n".join(lines)
|
||||
@@ -0,0 +1,219 @@
|
||||
"""Memory update queue with debounce mechanism."""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationContext:
|
||||
"""Context for a conversation to be processed for memory update."""
|
||||
|
||||
thread_id: str
|
||||
messages: list[Any]
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
agent_name: str | None = None
|
||||
correction_detected: bool = False
|
||||
reinforcement_detected: bool = False
|
||||
|
||||
|
||||
class MemoryUpdateQueue:
|
||||
"""Queue for memory updates with debounce mechanism.
|
||||
|
||||
This queue collects conversation contexts and processes them after
|
||||
a configurable debounce period. Multiple conversations received within
|
||||
the debounce window are batched together.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the memory update queue."""
|
||||
self._queue: list[ConversationContext] = []
|
||||
self._lock = threading.Lock()
|
||||
self._timer: threading.Timer | None = None
|
||||
self._processing = False
|
||||
|
||||
def add(
|
||||
self,
|
||||
thread_id: str,
|
||||
messages: list[Any],
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> None:
|
||||
"""Add a conversation to the update queue.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
messages: The conversation messages.
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
if not config.enabled:
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
existing_context = next(
|
||||
(context for context in self._queue if context.thread_id == thread_id),
|
||||
None,
|
||||
)
|
||||
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
|
||||
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
|
||||
context = ConversationContext(
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
correction_detected=merged_correction_detected,
|
||||
reinforcement_detected=merged_reinforcement_detected,
|
||||
)
|
||||
|
||||
# Check if this thread already has a pending update
|
||||
# If so, replace it with the newer one
|
||||
self._queue = [c for c in self._queue if c.thread_id != thread_id]
|
||||
self._queue.append(context)
|
||||
|
||||
# Reset or start the debounce timer
|
||||
self._reset_timer()
|
||||
|
||||
logger.info("Memory update queued for thread %s, queue size: %d", thread_id, len(self._queue))
|
||||
|
||||
def _reset_timer(self) -> None:
|
||||
"""Reset the debounce timer."""
|
||||
config = get_memory_config()
|
||||
|
||||
# Cancel existing timer if any
|
||||
if self._timer is not None:
|
||||
self._timer.cancel()
|
||||
|
||||
# Start new timer
|
||||
self._timer = threading.Timer(
|
||||
config.debounce_seconds,
|
||||
self._process_queue,
|
||||
)
|
||||
self._timer.daemon = True
|
||||
self._timer.start()
|
||||
|
||||
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
|
||||
|
||||
def _process_queue(self) -> None:
|
||||
"""Process all queued conversation contexts."""
|
||||
# Import here to avoid circular dependency
|
||||
from deerflow.agents.memory.updater import MemoryUpdater
|
||||
|
||||
with self._lock:
|
||||
if self._processing:
|
||||
# Already processing, reschedule
|
||||
self._reset_timer()
|
||||
return
|
||||
|
||||
if not self._queue:
|
||||
return
|
||||
|
||||
self._processing = True
|
||||
contexts_to_process = self._queue.copy()
|
||||
self._queue.clear()
|
||||
self._timer = None
|
||||
|
||||
logger.info("Processing %d queued memory updates", len(contexts_to_process))
|
||||
|
||||
try:
|
||||
updater = MemoryUpdater()
|
||||
|
||||
for context in contexts_to_process:
|
||||
try:
|
||||
logger.info("Updating memory for thread %s", context.thread_id)
|
||||
success = updater.update_memory(
|
||||
messages=context.messages,
|
||||
thread_id=context.thread_id,
|
||||
agent_name=context.agent_name,
|
||||
correction_detected=context.correction_detected,
|
||||
reinforcement_detected=context.reinforcement_detected,
|
||||
)
|
||||
if success:
|
||||
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
||||
else:
|
||||
logger.warning("Memory update skipped/failed for thread %s", context.thread_id)
|
||||
except Exception as e:
|
||||
logger.error("Error updating memory for thread %s: %s", context.thread_id, e)
|
||||
|
||||
# Small delay between updates to avoid rate limiting
|
||||
if len(contexts_to_process) > 1:
|
||||
time.sleep(0.5)
|
||||
|
||||
finally:
|
||||
with self._lock:
|
||||
self._processing = False
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Force immediate processing of the queue.
|
||||
|
||||
This is useful for testing or graceful shutdown.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._timer is not None:
|
||||
self._timer.cancel()
|
||||
self._timer = None
|
||||
|
||||
self._process_queue()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the queue without processing.
|
||||
|
||||
This is useful for testing.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._timer is not None:
|
||||
self._timer.cancel()
|
||||
self._timer = None
|
||||
self._queue.clear()
|
||||
self._processing = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
"""Get the number of pending updates."""
|
||||
with self._lock:
|
||||
return len(self._queue)
|
||||
|
||||
@property
|
||||
def is_processing(self) -> bool:
|
||||
"""Check if the queue is currently being processed."""
|
||||
with self._lock:
|
||||
return self._processing
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_memory_queue: MemoryUpdateQueue | None = None
|
||||
_queue_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_memory_queue() -> MemoryUpdateQueue:
|
||||
"""Get the global memory update queue singleton.
|
||||
|
||||
Returns:
|
||||
The memory update queue instance.
|
||||
"""
|
||||
global _memory_queue
|
||||
with _queue_lock:
|
||||
if _memory_queue is None:
|
||||
_memory_queue = MemoryUpdateQueue()
|
||||
return _memory_queue
|
||||
|
||||
|
||||
def reset_memory_queue() -> None:
|
||||
"""Reset the global memory queue.
|
||||
|
||||
This is useful for testing.
|
||||
"""
|
||||
global _memory_queue
|
||||
with _queue_lock:
|
||||
if _memory_queue is not None:
|
||||
_memory_queue.clear()
|
||||
_memory_queue = None
|
||||
@@ -0,0 +1,205 @@
|
||||
"""Memory storage providers."""
|
||||
|
||||
import abc
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from deerflow.config.agents_config import AGENT_NAME_PATTERN
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.config.paths import get_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def utc_now_iso_z() -> str:
|
||||
"""Current UTC time as ISO-8601 with ``Z`` suffix (matches prior naive-UTC output)."""
|
||||
return datetime.now(UTC).isoformat().removesuffix("+00:00") + "Z"
|
||||
|
||||
|
||||
def create_empty_memory() -> dict[str, Any]:
|
||||
"""Create an empty memory structure."""
|
||||
return {
|
||||
"version": "1.0",
|
||||
"lastUpdated": utc_now_iso_z(),
|
||||
"user": {
|
||||
"workContext": {"summary": "", "updatedAt": ""},
|
||||
"personalContext": {"summary": "", "updatedAt": ""},
|
||||
"topOfMind": {"summary": "", "updatedAt": ""},
|
||||
},
|
||||
"history": {
|
||||
"recentMonths": {"summary": "", "updatedAt": ""},
|
||||
"earlierContext": {"summary": "", "updatedAt": ""},
|
||||
"longTermBackground": {"summary": "", "updatedAt": ""},
|
||||
},
|
||||
"facts": [],
|
||||
}
|
||||
|
||||
|
||||
class MemoryStorage(abc.ABC):
|
||||
"""Abstract base class for memory storage providers."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def load(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Load memory data for the given agent."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Force reload memory data for the given agent."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
"""Save memory data for the given agent."""
|
||||
pass
|
||||
|
||||
|
||||
class FileMemoryStorage(MemoryStorage):
|
||||
"""File-based memory storage provider."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the file memory storage."""
|
||||
# Per-agent memory cache: keyed by agent_name (None = global)
|
||||
# Value: (memory_data, file_mtime)
|
||||
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
|
||||
|
||||
def _validate_agent_name(self, agent_name: str) -> None:
|
||||
"""Validate that the agent name is safe to use in filesystem paths.
|
||||
|
||||
Uses the repository's established AGENT_NAME_PATTERN to ensure consistency
|
||||
across the codebase and prevent path traversal or other problematic characters.
|
||||
"""
|
||||
if not agent_name:
|
||||
raise ValueError("Agent name must be a non-empty string.")
|
||||
if not AGENT_NAME_PATTERN.match(agent_name):
|
||||
raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")
|
||||
|
||||
def _get_memory_file_path(self, agent_name: str | None = None) -> Path:
|
||||
"""Get the path to the memory file."""
|
||||
if agent_name is not None:
|
||||
self._validate_agent_name(agent_name)
|
||||
return get_paths().agent_memory_file(agent_name)
|
||||
|
||||
config = get_memory_config()
|
||||
if config.storage_path:
|
||||
p = Path(config.storage_path)
|
||||
return p if p.is_absolute() else get_paths().base_dir / p
|
||||
return get_paths().memory_file
|
||||
|
||||
def _load_memory_from_file(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Load memory data from file."""
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
|
||||
if not file_path.exists():
|
||||
return create_empty_memory()
|
||||
|
||||
try:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning("Failed to load memory file: %s", e)
|
||||
return create_empty_memory()
|
||||
|
||||
def load(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Load memory data (cached with file modification time check)."""
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
|
||||
try:
|
||||
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||
except OSError:
|
||||
current_mtime = None
|
||||
|
||||
cached = self._memory_cache.get(agent_name)
|
||||
|
||||
if cached is None or cached[1] != current_mtime:
|
||||
memory_data = self._load_memory_from_file(agent_name)
|
||||
self._memory_cache[agent_name] = (memory_data, current_mtime)
|
||||
return memory_data
|
||||
|
||||
return cached[0]
|
||||
|
||||
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Reload memory data from file, forcing cache invalidation."""
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
memory_data = self._load_memory_from_file(agent_name)
|
||||
|
||||
try:
|
||||
mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||
except OSError:
|
||||
mtime = None
|
||||
|
||||
self._memory_cache[agent_name] = (memory_data, mtime)
|
||||
return memory_data
|
||||
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
"""Save memory data to file and update cache."""
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
|
||||
try:
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
memory_data["lastUpdated"] = utc_now_iso_z()
|
||||
|
||||
temp_path = file_path.with_suffix(".tmp")
|
||||
with open(temp_path, "w", encoding="utf-8") as f:
|
||||
json.dump(memory_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
temp_path.replace(file_path)
|
||||
|
||||
try:
|
||||
mtime = file_path.stat().st_mtime
|
||||
except OSError:
|
||||
mtime = None
|
||||
|
||||
self._memory_cache[agent_name] = (memory_data, mtime)
|
||||
logger.info("Memory saved to %s", file_path)
|
||||
return True
|
||||
except OSError as e:
|
||||
logger.error("Failed to save memory file: %s", e)
|
||||
return False
|
||||
|
||||
|
||||
_storage_instance: MemoryStorage | None = None
|
||||
_storage_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_memory_storage() -> MemoryStorage:
|
||||
"""Get the configured memory storage instance."""
|
||||
global _storage_instance
|
||||
if _storage_instance is not None:
|
||||
return _storage_instance
|
||||
|
||||
with _storage_lock:
|
||||
if _storage_instance is not None:
|
||||
return _storage_instance
|
||||
|
||||
config = get_memory_config()
|
||||
storage_class_path = config.storage_class
|
||||
|
||||
try:
|
||||
module_path, class_name = storage_class_path.rsplit(".", 1)
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
storage_class = getattr(module, class_name)
|
||||
|
||||
# Validate that the configured storage is a MemoryStorage implementation
|
||||
if not isinstance(storage_class, type):
|
||||
raise TypeError(f"Configured memory storage '{storage_class_path}' is not a class: {storage_class!r}")
|
||||
if not issubclass(storage_class, MemoryStorage):
|
||||
raise TypeError(f"Configured memory storage '{storage_class_path}' is not a subclass of MemoryStorage")
|
||||
|
||||
_storage_instance = storage_class()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load memory storage %s, falling back to FileMemoryStorage: %s",
|
||||
storage_class_path,
|
||||
e,
|
||||
)
|
||||
_storage_instance = FileMemoryStorage()
|
||||
|
||||
return _storage_instance
|
||||
@@ -0,0 +1,472 @@
|
||||
"""Memory updater for reading, writing, and updating memory data."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from deerflow.agents.memory.prompt import (
|
||||
MEMORY_UPDATE_PROMPT,
|
||||
format_conversation_for_update,
|
||||
)
|
||||
from deerflow.agents.memory.storage import (
|
||||
create_empty_memory,
|
||||
get_memory_storage,
|
||||
utc_now_iso_z,
|
||||
)
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_empty_memory() -> dict[str, Any]:
|
||||
"""Backward-compatible wrapper around the storage-layer empty-memory factory."""
|
||||
return create_empty_memory()
|
||||
|
||||
|
||||
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
"""Backward-compatible wrapper around the configured memory storage save path."""
|
||||
return get_memory_storage().save(memory_data, agent_name)
|
||||
|
||||
|
||||
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Get the current memory data via storage provider."""
|
||||
return get_memory_storage().load(agent_name)
|
||||
|
||||
|
||||
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Reload memory data via storage provider."""
|
||||
return get_memory_storage().reload(agent_name)
|
||||
|
||||
|
||||
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Persist imported memory data via storage provider.
|
||||
|
||||
Args:
|
||||
memory_data: Full memory payload to persist.
|
||||
agent_name: If provided, imports into per-agent memory.
|
||||
|
||||
Returns:
|
||||
The saved memory data after storage normalization.
|
||||
|
||||
Raises:
|
||||
OSError: If persisting the imported memory fails.
|
||||
"""
|
||||
storage = get_memory_storage()
|
||||
if not storage.save(memory_data, agent_name):
|
||||
raise OSError("Failed to save imported memory data")
|
||||
return storage.load(agent_name)
|
||||
|
||||
|
||||
def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Clear all stored memory data and persist an empty structure."""
|
||||
cleared_memory = create_empty_memory()
|
||||
if not _save_memory_to_file(cleared_memory, agent_name):
|
||||
raise OSError("Failed to save cleared memory data")
|
||||
return cleared_memory
|
||||
|
||||
|
||||
def _validate_confidence(confidence: float) -> float:
|
||||
"""Validate persisted fact confidence so stored JSON stays standards-compliant."""
|
||||
if not math.isfinite(confidence) or confidence < 0 or confidence > 1:
|
||||
raise ValueError("confidence")
|
||||
return confidence
|
||||
|
||||
|
||||
def create_memory_fact(
|
||||
content: str,
|
||||
category: str = "context",
|
||||
confidence: float = 0.5,
|
||||
agent_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new fact and persist the updated memory data."""
|
||||
normalized_content = content.strip()
|
||||
if not normalized_content:
|
||||
raise ValueError("content")
|
||||
|
||||
normalized_category = category.strip() or "context"
|
||||
validated_confidence = _validate_confidence(confidence)
|
||||
now = utc_now_iso_z()
|
||||
memory_data = get_memory_data(agent_name)
|
||||
updated_memory = dict(memory_data)
|
||||
facts = list(memory_data.get("facts", []))
|
||||
facts.append(
|
||||
{
|
||||
"id": f"fact_{uuid.uuid4().hex[:8]}",
|
||||
"content": normalized_content,
|
||||
"category": normalized_category,
|
||||
"confidence": validated_confidence,
|
||||
"createdAt": now,
|
||||
"source": "manual",
|
||||
}
|
||||
)
|
||||
updated_memory["facts"] = facts
|
||||
|
||||
if not _save_memory_to_file(updated_memory, agent_name):
|
||||
raise OSError("Failed to save memory data after creating fact")
|
||||
|
||||
return updated_memory
|
||||
|
||||
|
||||
def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Delete a fact by its id and persist the updated memory data."""
|
||||
memory_data = get_memory_data(agent_name)
|
||||
facts = memory_data.get("facts", [])
|
||||
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
|
||||
if len(updated_facts) == len(facts):
|
||||
raise KeyError(fact_id)
|
||||
|
||||
updated_memory = dict(memory_data)
|
||||
updated_memory["facts"] = updated_facts
|
||||
|
||||
if not _save_memory_to_file(updated_memory, agent_name):
|
||||
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
|
||||
|
||||
return updated_memory
|
||||
|
||||
|
||||
def update_memory_fact(
|
||||
fact_id: str,
|
||||
content: str | None = None,
|
||||
category: str | None = None,
|
||||
confidence: float | None = None,
|
||||
agent_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing fact and persist the updated memory data."""
|
||||
memory_data = get_memory_data(agent_name)
|
||||
updated_memory = dict(memory_data)
|
||||
updated_facts: list[dict[str, Any]] = []
|
||||
found = False
|
||||
|
||||
for fact in memory_data.get("facts", []):
|
||||
if fact.get("id") == fact_id:
|
||||
found = True
|
||||
updated_fact = dict(fact)
|
||||
if content is not None:
|
||||
normalized_content = content.strip()
|
||||
if not normalized_content:
|
||||
raise ValueError("content")
|
||||
updated_fact["content"] = normalized_content
|
||||
if category is not None:
|
||||
updated_fact["category"] = category.strip() or "context"
|
||||
if confidence is not None:
|
||||
updated_fact["confidence"] = _validate_confidence(confidence)
|
||||
updated_facts.append(updated_fact)
|
||||
else:
|
||||
updated_facts.append(fact)
|
||||
|
||||
if not found:
|
||||
raise KeyError(fact_id)
|
||||
|
||||
updated_memory["facts"] = updated_facts
|
||||
|
||||
if not _save_memory_to_file(updated_memory, agent_name):
|
||||
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
|
||||
|
||||
return updated_memory
|
||||
|
||||
|
||||
def _extract_text(content: Any) -> str:
|
||||
"""Extract plain text from LLM response content (str or list of content blocks).
|
||||
|
||||
Modern LLMs may return structured content as a list of blocks instead of a
|
||||
plain string, e.g. [{"type": "text", "text": "..."}]. Using str() on such
|
||||
content produces Python repr instead of the actual text, breaking JSON
|
||||
parsing downstream.
|
||||
|
||||
String chunks are concatenated without separators to avoid corrupting
|
||||
chunked JSON/text payloads. Dict-based text blocks are treated as full text
|
||||
blocks and joined with newlines for readability.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
pieces: list[str] = []
|
||||
pending_str_parts: list[str] = []
|
||||
|
||||
def flush_pending_str_parts() -> None:
|
||||
if pending_str_parts:
|
||||
pieces.append("".join(pending_str_parts))
|
||||
pending_str_parts.clear()
|
||||
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
pending_str_parts.append(block)
|
||||
elif isinstance(block, dict):
|
||||
flush_pending_str_parts()
|
||||
text_val = block.get("text")
|
||||
if isinstance(text_val, str):
|
||||
pieces.append(text_val)
|
||||
|
||||
flush_pending_str_parts()
|
||||
return "\n".join(pieces)
|
||||
return str(content)
|
||||
|
||||
|
||||
# Matches sentences that describe a file-upload *event* rather than general
|
||||
# file-related work. Deliberately narrow to avoid removing legitimate facts
|
||||
# such as "User works with CSV files" or "prefers PDF export".
|
||||
_UPLOAD_SENTENCE_RE = re.compile(
|
||||
r"[^.!?]*\b(?:"
|
||||
r"upload(?:ed|ing)?(?:\s+\w+){0,3}\s+(?:file|files?|document|documents?|attachment|attachments?)"
|
||||
r"|file\s+upload"
|
||||
r"|/mnt/user-data/uploads/"
|
||||
r"|<uploaded_files>"
|
||||
r")[^.!?]*[.!?]?\s*",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _strip_upload_mentions_from_memory(memory_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Remove sentences about file uploads from all memory summaries and facts.
|
||||
|
||||
Uploaded files are session-scoped; persisting upload events in long-term
|
||||
memory causes the agent to search for non-existent files in future sessions.
|
||||
"""
|
||||
# Scrub summaries in user/history sections
|
||||
for section in ("user", "history"):
|
||||
section_data = memory_data.get(section, {})
|
||||
for _key, val in section_data.items():
|
||||
if isinstance(val, dict) and "summary" in val:
|
||||
cleaned = _UPLOAD_SENTENCE_RE.sub("", val["summary"]).strip()
|
||||
cleaned = re.sub(r" +", " ", cleaned)
|
||||
val["summary"] = cleaned
|
||||
|
||||
# Also remove any facts that describe upload events
|
||||
facts = memory_data.get("facts", [])
|
||||
if facts:
|
||||
memory_data["facts"] = [f for f in facts if not _UPLOAD_SENTENCE_RE.search(f.get("content", ""))]
|
||||
|
||||
return memory_data
|
||||
|
||||
|
||||
def _fact_content_key(content: Any) -> str | None:
|
||||
if not isinstance(content, str):
|
||||
return None
|
||||
stripped = content.strip()
|
||||
if not stripped:
|
||||
return None
|
||||
return stripped.casefold()
|
||||
|
||||
|
||||
class MemoryUpdater:
|
||||
"""Updates memory using LLM based on conversation context."""
|
||||
|
||||
def __init__(self, model_name: str | None = None):
|
||||
"""Initialize the memory updater.
|
||||
|
||||
Args:
|
||||
model_name: Optional model name to use. If None, uses config or default.
|
||||
"""
|
||||
self._model_name = model_name
|
||||
|
||||
def _get_model(self):
|
||||
"""Get the model for memory updates."""
|
||||
config = get_memory_config()
|
||||
model_name = self._model_name or config.model_name
|
||||
return create_chat_model(name=model_name, thinking_enabled=False)
|
||||
|
||||
def update_memory(
|
||||
self,
|
||||
messages: list[Any],
|
||||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Update memory based on conversation messages.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
thread_id: Optional thread ID for tracking source.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
if not config.enabled:
|
||||
return False
|
||||
|
||||
if not messages:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Get current memory
|
||||
current_memory = get_memory_data(agent_name)
|
||||
|
||||
# Format conversation for prompt
|
||||
conversation_text = format_conversation_for_update(messages)
|
||||
|
||||
if not conversation_text.strip():
|
||||
return False
|
||||
|
||||
# Build prompt
|
||||
correction_hint = ""
|
||||
if correction_detected:
|
||||
correction_hint = (
|
||||
"IMPORTANT: Explicit correction signals were detected in this conversation. "
|
||||
"Pay special attention to what the agent got wrong, what the user corrected, "
|
||||
"and record the correct approach as a fact with category "
|
||||
'"correction" and confidence >= 0.95 when appropriate.'
|
||||
)
|
||||
if reinforcement_detected:
|
||||
reinforcement_hint = (
|
||||
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
|
||||
"The user explicitly confirmed the agent's approach was correct or helpful. "
|
||||
"Record the confirmed approach, style, or preference as a fact with category "
|
||||
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
|
||||
)
|
||||
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
|
||||
|
||||
prompt = MEMORY_UPDATE_PROMPT.format(
|
||||
current_memory=json.dumps(current_memory, indent=2),
|
||||
conversation=conversation_text,
|
||||
correction_hint=correction_hint,
|
||||
)
|
||||
|
||||
# Call LLM
|
||||
model = self._get_model()
|
||||
response = model.invoke(prompt)
|
||||
response_text = _extract_text(response.content).strip()
|
||||
|
||||
# Parse response
|
||||
# Remove markdown code blocks if present
|
||||
if response_text.startswith("```"):
|
||||
lines = response_text.split("\n")
|
||||
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
||||
|
||||
update_data = json.loads(response_text)
|
||||
|
||||
# Apply updates
|
||||
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
|
||||
|
||||
# Strip file-upload mentions from all summaries before saving.
|
||||
# Uploaded files are session-scoped and won't exist in future sessions,
|
||||
# so recording upload events in long-term memory causes the agent to
|
||||
# try (and fail) to locate those files in subsequent conversations.
|
||||
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
||||
|
||||
# Save
|
||||
return get_memory_storage().save(updated_memory, agent_name)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.exception("Memory update failed: %s", e)
|
||||
return False
|
||||
|
||||
def _apply_updates(
|
||||
self,
|
||||
current_memory: dict[str, Any],
|
||||
update_data: dict[str, Any],
|
||||
thread_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Apply LLM-generated updates to memory.
|
||||
|
||||
Args:
|
||||
current_memory: Current memory data.
|
||||
update_data: Updates from LLM.
|
||||
thread_id: Optional thread ID for tracking.
|
||||
|
||||
Returns:
|
||||
Updated memory data.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
now = utc_now_iso_z()
|
||||
|
||||
# Update user sections
|
||||
user_updates = update_data.get("user", {})
|
||||
for section in ["workContext", "personalContext", "topOfMind"]:
|
||||
section_data = user_updates.get(section, {})
|
||||
if section_data.get("shouldUpdate") and section_data.get("summary"):
|
||||
current_memory["user"][section] = {
|
||||
"summary": section_data["summary"],
|
||||
"updatedAt": now,
|
||||
}
|
||||
|
||||
# Update history sections
|
||||
history_updates = update_data.get("history", {})
|
||||
for section in ["recentMonths", "earlierContext", "longTermBackground"]:
|
||||
section_data = history_updates.get(section, {})
|
||||
if section_data.get("shouldUpdate") and section_data.get("summary"):
|
||||
current_memory["history"][section] = {
|
||||
"summary": section_data["summary"],
|
||||
"updatedAt": now,
|
||||
}
|
||||
|
||||
# Remove facts
|
||||
facts_to_remove = set(update_data.get("factsToRemove", []))
|
||||
if facts_to_remove:
|
||||
current_memory["facts"] = [f for f in current_memory.get("facts", []) if f.get("id") not in facts_to_remove]
|
||||
|
||||
# Add new facts
|
||||
existing_fact_keys = {fact_key for fact_key in (_fact_content_key(fact.get("content")) for fact in current_memory.get("facts", [])) if fact_key is not None}
|
||||
new_facts = update_data.get("newFacts", [])
|
||||
for fact in new_facts:
|
||||
confidence = fact.get("confidence", 0.5)
|
||||
if confidence >= config.fact_confidence_threshold:
|
||||
raw_content = fact.get("content", "")
|
||||
if not isinstance(raw_content, str):
|
||||
continue
|
||||
normalized_content = raw_content.strip()
|
||||
fact_key = _fact_content_key(normalized_content)
|
||||
if fact_key is not None and fact_key in existing_fact_keys:
|
||||
continue
|
||||
|
||||
fact_entry = {
|
||||
"id": f"fact_{uuid.uuid4().hex[:8]}",
|
||||
"content": normalized_content,
|
||||
"category": fact.get("category", "context"),
|
||||
"confidence": confidence,
|
||||
"createdAt": now,
|
||||
"source": thread_id or "unknown",
|
||||
}
|
||||
source_error = fact.get("sourceError")
|
||||
if isinstance(source_error, str):
|
||||
normalized_source_error = source_error.strip()
|
||||
if normalized_source_error:
|
||||
fact_entry["sourceError"] = normalized_source_error
|
||||
current_memory["facts"].append(fact_entry)
|
||||
if fact_key is not None:
|
||||
existing_fact_keys.add(fact_key)
|
||||
|
||||
# Enforce max facts limit
|
||||
if len(current_memory["facts"]) > config.max_facts:
|
||||
# Sort by confidence and keep top ones
|
||||
current_memory["facts"] = sorted(
|
||||
current_memory["facts"],
|
||||
key=lambda f: f.get("confidence", 0),
|
||||
reverse=True,
|
||||
)[: config.max_facts]
|
||||
|
||||
return current_memory
|
||||
|
||||
|
||||
def update_memory_from_conversation(
|
||||
messages: list[Any],
|
||||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Convenience function to update memory from a conversation.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
thread_id: Optional thread ID.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
"""Middleware for intercepting clarification requests and presenting them to the user."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.graph import END
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClarificationMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
|
||||
"""Intercepts clarification tool calls and interrupts execution to present questions to the user.
|
||||
|
||||
When the model calls the `ask_clarification` tool, this middleware:
|
||||
1. Intercepts the tool call before execution
|
||||
2. Extracts the clarification question and metadata
|
||||
3. Formats a user-friendly message
|
||||
4. Returns a Command that interrupts execution and presents the question
|
||||
5. Waits for user response before continuing
|
||||
|
||||
This replaces the tool-based approach where clarification continued the conversation flow.
|
||||
"""
|
||||
|
||||
state_schema = ClarificationMiddlewareState
|
||||
|
||||
def _is_chinese(self, text: str) -> bool:
|
||||
"""Check if text contains Chinese characters.
|
||||
|
||||
Args:
|
||||
text: Text to check
|
||||
|
||||
Returns:
|
||||
True if text contains Chinese characters
|
||||
"""
|
||||
return any("\u4e00" <= char <= "\u9fff" for char in text)
|
||||
|
||||
def _format_clarification_message(self, args: dict) -> str:
|
||||
"""Format the clarification arguments into a user-friendly message.
|
||||
|
||||
Args:
|
||||
args: The tool call arguments containing clarification details
|
||||
|
||||
Returns:
|
||||
Formatted message string
|
||||
"""
|
||||
question = args.get("question", "")
|
||||
clarification_type = args.get("clarification_type", "missing_info")
|
||||
context = args.get("context")
|
||||
options = args.get("options", [])
|
||||
|
||||
# Some models (e.g. Qwen3-Max) serialize array parameters as JSON strings
|
||||
# instead of native arrays. Deserialize and normalize so `options`
|
||||
# is always a list for the rendering logic below.
|
||||
if isinstance(options, str):
|
||||
try:
|
||||
options = json.loads(options)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
options = [options]
|
||||
|
||||
if options is None:
|
||||
options = []
|
||||
elif not isinstance(options, list):
|
||||
options = [options]
|
||||
|
||||
# Type-specific icons
|
||||
type_icons = {
|
||||
"missing_info": "❓",
|
||||
"ambiguous_requirement": "🤔",
|
||||
"approach_choice": "🔀",
|
||||
"risk_confirmation": "⚠️",
|
||||
"suggestion": "💡",
|
||||
}
|
||||
|
||||
icon = type_icons.get(clarification_type, "❓")
|
||||
|
||||
# Build the message naturally
|
||||
message_parts = []
|
||||
|
||||
# Add icon and question together for a more natural flow
|
||||
if context:
|
||||
# If there's context, present it first as background
|
||||
message_parts.append(f"{icon} {context}")
|
||||
message_parts.append(f"\n{question}")
|
||||
else:
|
||||
# Just the question with icon
|
||||
message_parts.append(f"{icon} {question}")
|
||||
|
||||
# Add options in a cleaner format
|
||||
if options and len(options) > 0:
|
||||
message_parts.append("") # blank line for spacing
|
||||
for i, option in enumerate(options, 1):
|
||||
message_parts.append(f" {i}. {option}")
|
||||
|
||||
return "\n".join(message_parts)
|
||||
|
||||
def _handle_clarification(self, request: ToolCallRequest) -> Command:
|
||||
"""Handle clarification request and return command to interrupt execution.
|
||||
|
||||
Args:
|
||||
request: Tool call request
|
||||
|
||||
Returns:
|
||||
Command that interrupts execution with the formatted clarification message
|
||||
"""
|
||||
# Extract clarification arguments
|
||||
args = request.tool_call.get("args", {})
|
||||
question = args.get("question", "")
|
||||
|
||||
logger.info("Intercepted clarification request")
|
||||
logger.debug("Clarification question: %s", question)
|
||||
|
||||
# Format the clarification message
|
||||
formatted_message = self._format_clarification_message(args)
|
||||
|
||||
# Get the tool call ID
|
||||
tool_call_id = request.tool_call.get("id", "")
|
||||
|
||||
# Create a ToolMessage with the formatted question
|
||||
# This will be added to the message history
|
||||
tool_message = ToolMessage(
|
||||
content=formatted_message,
|
||||
tool_call_id=tool_call_id,
|
||||
name="ask_clarification",
|
||||
)
|
||||
|
||||
# Return a Command that:
|
||||
# 1. Adds the formatted tool message
|
||||
# 2. Interrupts execution by going to __end__
|
||||
# Note: We don't add an extra AIMessage here - the frontend will detect
|
||||
# and display ask_clarification tool messages directly
|
||||
return Command(
|
||||
update={"messages": [tool_message]},
|
||||
goto=END,
|
||||
)
|
||||
|
||||
@override
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept ask_clarification tool calls and interrupt execution (sync version).
|
||||
|
||||
Args:
|
||||
request: Tool call request
|
||||
handler: Original tool execution handler
|
||||
|
||||
Returns:
|
||||
Command that interrupts execution with the formatted clarification message
|
||||
"""
|
||||
# Check if this is an ask_clarification tool call
|
||||
if request.tool_call.get("name") != "ask_clarification":
|
||||
# Not a clarification call, execute normally
|
||||
return handler(request)
|
||||
|
||||
return self._handle_clarification(request)
|
||||
|
||||
@override
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept ask_clarification tool calls and interrupt execution (async version).
|
||||
|
||||
Args:
|
||||
request: Tool call request
|
||||
handler: Original tool execution handler (async)
|
||||
|
||||
Returns:
|
||||
Command that interrupts execution with the formatted clarification message
|
||||
"""
|
||||
# Check if this is an ask_clarification tool call
|
||||
if request.tool_call.get("name") != "ask_clarification":
|
||||
# Not a clarification call, execute normally
|
||||
return await handler(request)
|
||||
|
||||
return self._handle_clarification(request)
|
||||
@@ -0,0 +1,110 @@
|
||||
"""Middleware to fix dangling tool calls in message history.
|
||||
|
||||
A dangling tool call occurs when an AIMessage contains tool_calls but there are
|
||||
no corresponding ToolMessages in the history (e.g., due to user interruption or
|
||||
request cancellation). This causes LLM errors due to incomplete message format.
|
||||
|
||||
This middleware intercepts the model call to detect and patch such gaps by
|
||||
inserting synthetic ToolMessages with an error indicator immediately after the
|
||||
AIMessage that made the tool calls, ensuring correct message ordering.
|
||||
|
||||
Note: Uses wrap_model_call instead of before_model to ensure patches are inserted
|
||||
at the correct positions (immediately after each dangling AIMessage), not appended
|
||||
to the end of the message list as before_model + add_messages reducer would do.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Inserts placeholder ToolMessages for dangling tool calls before model invocation.
|
||||
|
||||
Scans the message history for AIMessages whose tool_calls lack corresponding
|
||||
ToolMessages, and injects synthetic error responses immediately after the
|
||||
offending AIMessage so the LLM receives a well-formed conversation.
|
||||
"""
|
||||
|
||||
def _build_patched_messages(self, messages: list) -> list | None:
|
||||
"""Return a new message list with patches inserted at the correct positions.
|
||||
|
||||
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
|
||||
a synthetic ToolMessage is inserted immediately after that AIMessage.
|
||||
Returns None if no patches are needed.
|
||||
"""
|
||||
# Collect IDs of all existing ToolMessages
|
||||
existing_tool_msg_ids: set[str] = set()
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage):
|
||||
existing_tool_msg_ids.add(msg.tool_call_id)
|
||||
|
||||
# Check if any patching is needed
|
||||
needs_patch = False
|
||||
for msg in messages:
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
for tc in getattr(msg, "tool_calls", None) or []:
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in existing_tool_msg_ids:
|
||||
needs_patch = True
|
||||
break
|
||||
if needs_patch:
|
||||
break
|
||||
|
||||
if not needs_patch:
|
||||
return None
|
||||
|
||||
# Build new list with patches inserted right after each dangling AIMessage
|
||||
patched: list = []
|
||||
patched_ids: set[str] = set()
|
||||
patch_count = 0
|
||||
for msg in messages:
|
||||
patched.append(msg)
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
for tc in getattr(msg, "tool_calls", None) or []:
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
|
||||
patched.append(
|
||||
ToolMessage(
|
||||
content="[Tool call was interrupted and did not return a result.]",
|
||||
tool_call_id=tc_id,
|
||||
name=tc.get("name", "unknown"),
|
||||
status="error",
|
||||
)
|
||||
)
|
||||
patched_ids.add(tc_id)
|
||||
patch_count += 1
|
||||
|
||||
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
||||
return patched
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
patched = self._build_patched_messages(request.messages)
|
||||
if patched is not None:
|
||||
request = request.override(messages=patched)
|
||||
return handler(request)
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
patched = self._build_patched_messages(request.messages)
|
||||
if patched is not None:
|
||||
request = request.override(messages=patched)
|
||||
return await handler(request)
|
||||
@@ -0,0 +1,60 @@
|
||||
"""Middleware to filter deferred tool schemas from model binding.
|
||||
|
||||
When tool_search is enabled, MCP tools are registered in the DeferredToolRegistry
|
||||
and passed to ToolNode for execution, but their schemas should NOT be sent to the
|
||||
LLM via bind_tools (that's the whole point of deferral — saving context tokens).
|
||||
|
||||
This middleware intercepts wrap_model_call and removes deferred tools from
|
||||
request.tools so that model.bind_tools only receives active tool schemas.
|
||||
The agent discovers deferred tools at runtime via the tool_search tool.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Remove deferred tools from request.tools before model binding.
|
||||
|
||||
ToolNode still holds all tools (including deferred) for execution routing,
|
||||
but the LLM only sees active tool schemas — deferred tools are discoverable
|
||||
via tool_search at runtime.
|
||||
"""
|
||||
|
||||
def _filter_tools(self, request: ModelRequest) -> ModelRequest:
|
||||
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
||||
|
||||
registry = get_deferred_registry()
|
||||
if not registry:
|
||||
return request
|
||||
|
||||
deferred_names = {e.name for e in registry.entries}
|
||||
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
|
||||
|
||||
if len(active_tools) < len(request.tools):
|
||||
logger.debug(f"Filtered {len(request.tools) - len(active_tools)} deferred tool schema(s) from model binding")
|
||||
|
||||
return request.override(tools=active_tools)
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(self._filter_tools(request))
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
return await handler(self._filter_tools(request))
|
||||
@@ -0,0 +1,275 @@
|
||||
"""LLM error handling middleware with retry/backoff and user-facing fallbacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
|
||||
_BUSY_PATTERNS = (
|
||||
"server busy",
|
||||
"temporarily unavailable",
|
||||
"try again later",
|
||||
"please retry",
|
||||
"please try again",
|
||||
"overloaded",
|
||||
"high demand",
|
||||
"rate limit",
|
||||
"负载较高",
|
||||
"服务繁忙",
|
||||
"稍后重试",
|
||||
"请稍后重试",
|
||||
)
|
||||
_QUOTA_PATTERNS = (
|
||||
"insufficient_quota",
|
||||
"quota",
|
||||
"billing",
|
||||
"credit",
|
||||
"payment",
|
||||
"余额不足",
|
||||
"超出限额",
|
||||
"额度不足",
|
||||
"欠费",
|
||||
)
|
||||
_AUTH_PATTERNS = (
|
||||
"authentication",
|
||||
"unauthorized",
|
||||
"invalid api key",
|
||||
"invalid_api_key",
|
||||
"permission",
|
||||
"forbidden",
|
||||
"access denied",
|
||||
"无权",
|
||||
"未授权",
|
||||
)
|
||||
|
||||
|
||||
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Retry transient LLM errors and surface graceful assistant messages."""
|
||||
|
||||
retry_max_attempts: int = 3
|
||||
retry_base_delay_ms: int = 1000
|
||||
retry_cap_delay_ms: int = 8000
|
||||
|
||||
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
|
||||
detail = _extract_error_detail(exc)
|
||||
lowered = detail.lower()
|
||||
error_code = _extract_error_code(exc)
|
||||
status_code = _extract_status_code(exc)
|
||||
|
||||
if _matches_any(lowered, _QUOTA_PATTERNS) or _matches_any(str(error_code).lower(), _QUOTA_PATTERNS):
|
||||
return False, "quota"
|
||||
if _matches_any(lowered, _AUTH_PATTERNS):
|
||||
return False, "auth"
|
||||
|
||||
exc_name = exc.__class__.__name__
|
||||
if exc_name in {
|
||||
"APITimeoutError",
|
||||
"APIConnectionError",
|
||||
"InternalServerError",
|
||||
}:
|
||||
return True, "transient"
|
||||
if status_code in _RETRIABLE_STATUS_CODES:
|
||||
return True, "transient"
|
||||
if _matches_any(lowered, _BUSY_PATTERNS):
|
||||
return True, "busy"
|
||||
|
||||
return False, "generic"
|
||||
|
||||
def _build_retry_delay_ms(self, attempt: int, exc: BaseException) -> int:
|
||||
retry_after = _extract_retry_after_ms(exc)
|
||||
if retry_after is not None:
|
||||
return retry_after
|
||||
backoff = self.retry_base_delay_ms * (2 ** max(0, attempt - 1))
|
||||
return min(backoff, self.retry_cap_delay_ms)
|
||||
|
||||
def _build_retry_message(self, attempt: int, wait_ms: int, reason: str) -> str:
|
||||
seconds = max(1, round(wait_ms / 1000))
|
||||
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
|
||||
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
|
||||
|
||||
def _build_user_message(self, exc: BaseException, reason: str) -> str:
|
||||
detail = _extract_error_detail(exc)
|
||||
if reason == "quota":
|
||||
return "The configured LLM provider rejected the request because the account is out of quota, billing is unavailable, or usage is restricted. Please fix the provider account and try again."
|
||||
if reason == "auth":
|
||||
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
|
||||
if reason in {"busy", "transient"}:
|
||||
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
|
||||
return f"LLM request failed: {detail}"
|
||||
|
||||
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
|
||||
try:
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
writer(
|
||||
{
|
||||
"type": "llm_retry",
|
||||
"attempt": attempt,
|
||||
"max_attempts": self.retry_max_attempts,
|
||||
"wait_ms": wait_ms,
|
||||
"reason": reason,
|
||||
"message": self._build_retry_message(attempt, wait_ms, reason),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to emit llm_retry event", exc_info=True)
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
if retriable and attempt < self.retry_max_attempts:
|
||||
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
||||
logger.warning(
|
||||
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
||||
attempt,
|
||||
self.retry_max_attempts,
|
||||
wait_ms,
|
||||
_extract_error_detail(exc),
|
||||
)
|
||||
self._emit_retry_event(attempt, wait_ms, reason)
|
||||
time.sleep(wait_ms / 1000)
|
||||
attempt += 1
|
||||
continue
|
||||
logger.warning(
|
||||
"LLM call failed after %d attempt(s): %s",
|
||||
attempt,
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return await handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
if retriable and attempt < self.retry_max_attempts:
|
||||
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
||||
logger.warning(
|
||||
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
||||
attempt,
|
||||
self.retry_max_attempts,
|
||||
wait_ms,
|
||||
_extract_error_detail(exc),
|
||||
)
|
||||
self._emit_retry_event(attempt, wait_ms, reason)
|
||||
await asyncio.sleep(wait_ms / 1000)
|
||||
attempt += 1
|
||||
continue
|
||||
logger.warning(
|
||||
"LLM call failed after %d attempt(s): %s",
|
||||
attempt,
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
|
||||
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
|
||||
return any(pattern in detail for pattern in patterns)
|
||||
|
||||
|
||||
def _extract_error_code(exc: BaseException) -> Any:
|
||||
for attr in ("code", "error_code"):
|
||||
value = getattr(exc, attr, None)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
|
||||
body = getattr(exc, "body", None)
|
||||
if isinstance(body, dict):
|
||||
error = body.get("error")
|
||||
if isinstance(error, dict):
|
||||
for key in ("code", "type"):
|
||||
value = error.get(key)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _extract_status_code(exc: BaseException) -> int | None:
|
||||
for attr in ("status_code", "status"):
|
||||
value = getattr(exc, attr, None)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
response = getattr(exc, "response", None)
|
||||
status = getattr(response, "status_code", None)
|
||||
return status if isinstance(status, int) else None
|
||||
|
||||
|
||||
def _extract_retry_after_ms(exc: BaseException) -> int | None:
|
||||
response = getattr(exc, "response", None)
|
||||
headers = getattr(response, "headers", None)
|
||||
if headers is None:
|
||||
return None
|
||||
|
||||
raw = None
|
||||
header_name = ""
|
||||
for key in ("retry-after-ms", "Retry-After-Ms", "retry-after", "Retry-After"):
|
||||
header_name = key
|
||||
if hasattr(headers, "get"):
|
||||
raw = headers.get(key)
|
||||
if raw:
|
||||
break
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
try:
|
||||
multiplier = 1 if "ms" in header_name.lower() else 1000
|
||||
return max(0, int(float(raw) * multiplier))
|
||||
except (TypeError, ValueError):
|
||||
try:
|
||||
target = parsedate_to_datetime(str(raw))
|
||||
delta = target.timestamp() - time.time()
|
||||
return max(0, int(delta * 1000))
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
return None
|
||||
|
||||
|
||||
def _extract_error_detail(exc: BaseException) -> str:
|
||||
detail = str(exc).strip()
|
||||
if detail:
|
||||
return detail
|
||||
message = getattr(exc, "message", None)
|
||||
if isinstance(message, str) and message.strip():
|
||||
return message.strip()
|
||||
return exc.__class__.__name__
|
||||
@@ -0,0 +1,372 @@
|
||||
"""Middleware to detect and break repetitive tool call loops.
|
||||
|
||||
P0 safety: prevents the agent from calling the same tool with the same
|
||||
arguments indefinitely until the recursion limit kills the run.
|
||||
|
||||
Detection strategy:
|
||||
1. After each model response, hash the tool calls (name + args).
|
||||
2. Track recent hashes in a sliding window.
|
||||
3. If the same hash appears >= warn_threshold times, inject a
|
||||
"you are repeating yourself — wrap up" system message (once per hash).
|
||||
4. If it appears >= hard_limit times, strip all tool_calls from the
|
||||
response so the agent is forced to produce a final text answer.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict, defaultdict
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Defaults — can be overridden via constructor
|
||||
_DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls
|
||||
_DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls
|
||||
_DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
|
||||
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
|
||||
_DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type
|
||||
_DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type
|
||||
|
||||
|
||||
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
|
||||
"""Normalize tool call args to a dict plus an optional fallback key.
|
||||
|
||||
Some providers serialize ``args`` as a JSON string instead of a dict.
|
||||
We defensively parse those cases so loop detection does not crash while
|
||||
still preserving a stable fallback key for non-dict payloads.
|
||||
"""
|
||||
if isinstance(raw_args, dict):
|
||||
return raw_args, None
|
||||
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
parsed = json.loads(raw_args)
|
||||
except (TypeError, ValueError, json.JSONDecodeError):
|
||||
return {}, raw_args
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
return parsed, None
|
||||
return {}, json.dumps(parsed, sort_keys=True, default=str)
|
||||
|
||||
if raw_args is None:
|
||||
return {}, None
|
||||
|
||||
return {}, json.dumps(raw_args, sort_keys=True, default=str)
|
||||
|
||||
|
||||
def _stable_tool_key(name: str, args: dict, fallback_key: str | None) -> str:
|
||||
"""Derive a stable key from salient args without overfitting to noise."""
|
||||
if name == "read_file" and fallback_key is None:
|
||||
path = args.get("path") or ""
|
||||
start_line = args.get("start_line")
|
||||
end_line = args.get("end_line")
|
||||
|
||||
bucket_size = 200
|
||||
try:
|
||||
start_line = int(start_line) if start_line is not None else 1
|
||||
except (TypeError, ValueError):
|
||||
start_line = 1
|
||||
try:
|
||||
end_line = int(end_line) if end_line is not None else start_line
|
||||
except (TypeError, ValueError):
|
||||
end_line = start_line
|
||||
|
||||
start_line, end_line = sorted((start_line, end_line))
|
||||
bucket_start = max(start_line, 1)
|
||||
bucket_end = max(end_line, 1)
|
||||
bucket_start = (bucket_start - 1) // bucket_size
|
||||
bucket_end = (bucket_end - 1) // bucket_size
|
||||
return f"{path}:{bucket_start}-{bucket_end}"
|
||||
|
||||
# write_file / str_replace are content-sensitive: same path may be updated
|
||||
# with different payloads during iteration. Using only salient fields (path)
|
||||
# can collapse distinct calls, so we hash full args to reduce false positives.
|
||||
if name in {"write_file", "str_replace"}:
|
||||
if fallback_key is not None:
|
||||
return fallback_key
|
||||
return json.dumps(args, sort_keys=True, default=str)
|
||||
|
||||
salient_fields = ("path", "url", "query", "command", "pattern", "glob", "cmd")
|
||||
stable_args = {field: args[field] for field in salient_fields if args.get(field) is not None}
|
||||
if stable_args:
|
||||
return json.dumps(stable_args, sort_keys=True, default=str)
|
||||
|
||||
if fallback_key is not None:
|
||||
return fallback_key
|
||||
|
||||
return json.dumps(args, sort_keys=True, default=str)
|
||||
|
||||
|
||||
def _hash_tool_calls(tool_calls: list[dict]) -> str:
|
||||
"""Deterministic hash of a set of tool calls (name + stable key).
|
||||
|
||||
This is intended to be order-independent: the same multiset of tool calls
|
||||
should always produce the same hash, regardless of their input order.
|
||||
"""
|
||||
# Normalize each tool call to a stable (name, key) structure.
|
||||
normalized: list[str] = []
|
||||
for tc in tool_calls:
|
||||
name = tc.get("name", "")
|
||||
args, fallback_key = _normalize_tool_call_args(tc.get("args", {}))
|
||||
key = _stable_tool_key(name, args, fallback_key)
|
||||
|
||||
normalized.append(f"{name}:{key}")
|
||||
|
||||
# Sort so permutations of the same multiset of calls yield the same ordering.
|
||||
normalized.sort()
|
||||
blob = json.dumps(normalized, sort_keys=True, default=str)
|
||||
return hashlib.md5(blob.encode()).hexdigest()[:12]
|
||||
|
||||
|
||||
_WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
|
||||
|
||||
_TOOL_FREQ_WARNING_MSG = (
|
||||
"[LOOP DETECTED] You have called {tool_name} {count} times without producing a final answer. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
|
||||
)
|
||||
|
||||
_HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far."
|
||||
|
||||
_TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times — exceeded the per-tool safety limit. Producing final answer with results collected so far."
|
||||
|
||||
|
||||
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Detects and breaks repetitive tool call loops.
|
||||
|
||||
Args:
|
||||
warn_threshold: Number of identical tool call sets before injecting
|
||||
a warning message. Default: 3.
|
||||
hard_limit: Number of identical tool call sets before stripping
|
||||
tool_calls entirely. Default: 5.
|
||||
window_size: Size of the sliding window for tracking calls.
|
||||
Default: 20.
|
||||
max_tracked_threads: Maximum number of threads to track before
|
||||
evicting the least recently used. Default: 100.
|
||||
tool_freq_warn: Number of calls to the same tool *type* (regardless
|
||||
of arguments) before injecting a frequency warning. Catches
|
||||
cross-file read loops that hash-based detection misses.
|
||||
Default: 30.
|
||||
tool_freq_hard_limit: Number of calls to the same tool type before
|
||||
forcing a stop. Default: 50.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
warn_threshold: int = _DEFAULT_WARN_THRESHOLD,
|
||||
hard_limit: int = _DEFAULT_HARD_LIMIT,
|
||||
window_size: int = _DEFAULT_WINDOW_SIZE,
|
||||
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
|
||||
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
|
||||
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
|
||||
):
|
||||
super().__init__()
|
||||
self.warn_threshold = warn_threshold
|
||||
self.hard_limit = hard_limit
|
||||
self.window_size = window_size
|
||||
self.max_tracked_threads = max_tracked_threads
|
||||
self.tool_freq_warn = tool_freq_warn
|
||||
self.tool_freq_hard_limit = tool_freq_hard_limit
|
||||
self._lock = threading.Lock()
|
||||
# Per-thread tracking using OrderedDict for LRU eviction
|
||||
self._history: OrderedDict[str, list[str]] = OrderedDict()
|
||||
self._warned: dict[str, set[str]] = defaultdict(set)
|
||||
# Per-thread, per-tool-type cumulative call counts
|
||||
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
def _get_thread_id(self, runtime: Runtime) -> str:
|
||||
"""Extract thread_id from runtime context for per-thread tracking."""
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id:
|
||||
return thread_id
|
||||
return "default"
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict least recently used threads if over the limit.
|
||||
|
||||
Must be called while holding self._lock.
|
||||
"""
|
||||
while len(self._history) > self.max_tracked_threads:
|
||||
evicted_id, _ = self._history.popitem(last=False)
|
||||
self._warned.pop(evicted_id, None)
|
||||
self._tool_freq.pop(evicted_id, None)
|
||||
self._tool_freq_warned.pop(evicted_id, None)
|
||||
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
|
||||
|
||||
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
|
||||
"""Track tool calls and check for loops.
|
||||
|
||||
Two detection layers:
|
||||
1. **Hash-based** (existing): catches identical tool call sets.
|
||||
2. **Frequency-based** (new): catches the same *tool type* being
|
||||
called many times with varying arguments (e.g. ``read_file``
|
||||
on 40 different files).
|
||||
|
||||
Returns:
|
||||
(warning_message_or_none, should_hard_stop)
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None, False
|
||||
|
||||
last_msg = messages[-1]
|
||||
if getattr(last_msg, "type", None) != "ai":
|
||||
return None, False
|
||||
|
||||
tool_calls = getattr(last_msg, "tool_calls", None)
|
||||
if not tool_calls:
|
||||
return None, False
|
||||
|
||||
thread_id = self._get_thread_id(runtime)
|
||||
call_hash = _hash_tool_calls(tool_calls)
|
||||
|
||||
with self._lock:
|
||||
# Touch / create entry (move to end for LRU)
|
||||
if thread_id in self._history:
|
||||
self._history.move_to_end(thread_id)
|
||||
else:
|
||||
self._history[thread_id] = []
|
||||
self._evict_if_needed()
|
||||
|
||||
history = self._history[thread_id]
|
||||
history.append(call_hash)
|
||||
if len(history) > self.window_size:
|
||||
history[:] = history[-self.window_size :]
|
||||
|
||||
count = history.count(call_hash)
|
||||
tool_names = [tc.get("name", "?") for tc in tool_calls]
|
||||
|
||||
# --- Layer 1: hash-based (identical call sets) ---
|
||||
if count >= self.hard_limit:
|
||||
logger.error(
|
||||
"Loop hard limit reached — forcing stop",
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"call_hash": call_hash,
|
||||
"count": count,
|
||||
"tools": tool_names,
|
||||
},
|
||||
)
|
||||
return _HARD_STOP_MSG, True
|
||||
|
||||
if count >= self.warn_threshold:
|
||||
warned = self._warned[thread_id]
|
||||
if call_hash not in warned:
|
||||
warned.add(call_hash)
|
||||
logger.warning(
|
||||
"Repetitive tool calls detected — injecting warning",
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"call_hash": call_hash,
|
||||
"count": count,
|
||||
"tools": tool_names,
|
||||
},
|
||||
)
|
||||
return _WARNING_MSG, False
|
||||
|
||||
# --- Layer 2: per-tool-type frequency ---
|
||||
freq = self._tool_freq[thread_id]
|
||||
for tc in tool_calls:
|
||||
name = tc.get("name", "")
|
||||
if not name:
|
||||
continue
|
||||
freq[name] += 1
|
||||
tc_count = freq[name]
|
||||
|
||||
if tc_count >= self.tool_freq_hard_limit:
|
||||
logger.error(
|
||||
"Tool frequency hard limit reached — forcing stop",
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"tool_name": name,
|
||||
"count": tc_count,
|
||||
},
|
||||
)
|
||||
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
|
||||
|
||||
if tc_count >= self.tool_freq_warn:
|
||||
warned = self._tool_freq_warned[thread_id]
|
||||
if name not in warned:
|
||||
warned.add(name)
|
||||
logger.warning(
|
||||
"Tool frequency warning — too many calls to same tool type",
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"tool_name": name,
|
||||
"count": tc_count,
|
||||
},
|
||||
)
|
||||
return _TOOL_FREQ_WARNING_MSG.format(tool_name=name, count=tc_count), False
|
||||
|
||||
return None, False
|
||||
|
||||
@staticmethod
|
||||
def _append_text(content: str | list | None, text: str) -> str | list:
|
||||
"""Append *text* to AIMessage content, handling str, list, and None.
|
||||
|
||||
When content is a list of content blocks (e.g. Anthropic thinking mode),
|
||||
we append a new ``{"type": "text", ...}`` block instead of concatenating
|
||||
a string to a list, which would raise ``TypeError``.
|
||||
"""
|
||||
if content is None:
|
||||
return text
|
||||
if isinstance(content, list):
|
||||
return [*content, {"type": "text", "text": f"\n\n{text}"}]
|
||||
if isinstance(content, str):
|
||||
return content + f"\n\n{text}"
|
||||
# Fallback: coerce unexpected types to str to avoid TypeError
|
||||
return str(content) + f"\n\n{text}"
|
||||
|
||||
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
warning, hard_stop = self._track_and_check(state, runtime)
|
||||
|
||||
if hard_stop:
|
||||
# Strip tool_calls from the last AIMessage to force text output
|
||||
messages = state.get("messages", [])
|
||||
last_msg = messages[-1]
|
||||
stripped_msg = last_msg.model_copy(
|
||||
update={
|
||||
"tool_calls": [],
|
||||
"content": self._append_text(last_msg.content, warning),
|
||||
}
|
||||
)
|
||||
return {"messages": [stripped_msg]}
|
||||
|
||||
if warning:
|
||||
# Inject as HumanMessage instead of SystemMessage to avoid
|
||||
# Anthropic's "multiple non-consecutive system messages" error.
|
||||
# Anthropic models require system messages only at the start of
|
||||
# the conversation; injecting one mid-conversation crashes
|
||||
# langchain_anthropic's _format_messages(). HumanMessage works
|
||||
# with all providers. See #1299.
|
||||
return {"messages": [HumanMessage(content=warning)]}
|
||||
|
||||
return None
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
|
||||
def reset(self, thread_id: str | None = None) -> None:
|
||||
"""Clear tracking state. If thread_id given, clear only that thread."""
|
||||
with self._lock:
|
||||
if thread_id:
|
||||
self._history.pop(thread_id, None)
|
||||
self._warned.pop(thread_id, None)
|
||||
self._tool_freq.pop(thread_id, None)
|
||||
self._tool_freq_warned.pop(thread_id, None)
|
||||
else:
|
||||
self._history.clear()
|
||||
self._warned.clear()
|
||||
self._tool_freq.clear()
|
||||
self._tool_freq_warned.clear()
|
||||
@@ -0,0 +1,248 @@
|
||||
"""Middleware for memory mechanism."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.memory.queue import get_memory_queue
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||
_CORRECTION_PATTERNS = (
|
||||
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
|
||||
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
|
||||
re.compile(r"\btry again\b", re.IGNORECASE),
|
||||
re.compile(r"\bredo\b", re.IGNORECASE),
|
||||
re.compile(r"不对"),
|
||||
re.compile(r"你理解错了"),
|
||||
re.compile(r"你理解有误"),
|
||||
re.compile(r"重试"),
|
||||
re.compile(r"重新来"),
|
||||
re.compile(r"换一种"),
|
||||
re.compile(r"改用"),
|
||||
)
|
||||
|
||||
_REINFORCEMENT_PATTERNS = (
|
||||
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
|
||||
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
|
||||
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
|
||||
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
|
||||
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
|
||||
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
|
||||
re.compile(r"完全正确(?:[。!?!?.]|$)"),
|
||||
re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
|
||||
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
|
||||
re.compile(r"继续保持(?:[。!?!?.]|$)"),
|
||||
)
|
||||
|
||||
|
||||
class MemoryMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _extract_message_text(message: Any) -> str:
|
||||
"""Extract plain text from message content for filtering and signal detection."""
|
||||
content = getattr(message, "content", "")
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
text_parts.append(part)
|
||||
elif isinstance(part, dict):
|
||||
text_val = part.get("text")
|
||||
if isinstance(text_val, str):
|
||||
text_parts.append(text_val)
|
||||
return " ".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
"""Filter messages to keep only user inputs and final assistant responses.
|
||||
|
||||
This filters out:
|
||||
- Tool messages (intermediate tool call results)
|
||||
- AI messages with tool_calls (intermediate steps, not final responses)
|
||||
- The <uploaded_files> block injected by UploadsMiddleware into human messages
|
||||
(file paths are session-scoped and must not persist in long-term memory).
|
||||
The user's actual question is preserved; only turns whose content is entirely
|
||||
the upload block (nothing remains after stripping) are dropped along with
|
||||
their paired assistant response.
|
||||
|
||||
Only keeps:
|
||||
- Human messages (with the ephemeral upload block removed)
|
||||
- AI messages without tool_calls (final assistant responses), unless the
|
||||
paired human turn was upload-only and had no real user text.
|
||||
|
||||
Args:
|
||||
messages: List of all conversation messages.
|
||||
|
||||
Returns:
|
||||
Filtered list containing only user inputs and final assistant responses.
|
||||
"""
|
||||
filtered = []
|
||||
skip_next_ai = False
|
||||
for msg in messages:
|
||||
msg_type = getattr(msg, "type", None)
|
||||
|
||||
if msg_type == "human":
|
||||
content_str = _extract_message_text(msg)
|
||||
if "<uploaded_files>" in content_str:
|
||||
# Strip the ephemeral upload block; keep the user's real question.
|
||||
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
|
||||
if not stripped:
|
||||
# Nothing left — the entire turn was upload bookkeeping;
|
||||
# skip it and the paired assistant response.
|
||||
skip_next_ai = True
|
||||
continue
|
||||
# Rebuild the message with cleaned content so the user's question
|
||||
# is still available for memory summarisation.
|
||||
from copy import copy
|
||||
|
||||
clean_msg = copy(msg)
|
||||
clean_msg.content = stripped
|
||||
filtered.append(clean_msg)
|
||||
skip_next_ai = False
|
||||
else:
|
||||
filtered.append(msg)
|
||||
skip_next_ai = False
|
||||
elif msg_type == "ai":
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
if not tool_calls:
|
||||
if skip_next_ai:
|
||||
skip_next_ai = False
|
||||
continue
|
||||
filtered.append(msg)
|
||||
# Skip tool messages and AI messages with tool_calls
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
def detect_correction(messages: list[Any]) -> bool:
|
||||
"""Detect explicit user corrections in recent conversation turns.
|
||||
|
||||
The queue keeps only one pending context per thread, so callers pass the
|
||||
latest filtered message list. Checking only recent user turns keeps signal
|
||||
detection conservative while avoiding stale corrections from long histories.
|
||||
"""
|
||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||
|
||||
for msg in recent_user_msgs:
|
||||
content = _extract_message_text(msg).strip()
|
||||
if not content:
|
||||
continue
|
||||
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def detect_reinforcement(messages: list[Any]) -> bool:
|
||||
"""Detect explicit positive reinforcement signals in recent conversation turns.
|
||||
|
||||
Complements detect_correction() by identifying when the user confirms the
|
||||
agent's approach was correct. This allows the memory system to record what
|
||||
worked well, not just what went wrong.
|
||||
|
||||
The queue keeps only one pending context per thread, so callers pass the
|
||||
latest filtered message list. Checking only recent user turns keeps signal
|
||||
detection conservative while avoiding stale signals from long histories.
|
||||
"""
|
||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||
|
||||
for msg in recent_user_msgs:
|
||||
content = _extract_message_text(msg).strip()
|
||||
if not content:
|
||||
continue
|
||||
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
"""Middleware that queues conversation for memory update after agent execution.
|
||||
|
||||
This middleware:
|
||||
1. After each agent execution, queues the conversation for memory update
|
||||
2. Only includes user inputs and final assistant responses (ignores tool calls)
|
||||
3. The queue uses debouncing to batch multiple updates together
|
||||
4. Memory is updated asynchronously via LLM summarization
|
||||
"""
|
||||
|
||||
state_schema = MemoryMiddlewareState
|
||||
|
||||
def __init__(self, agent_name: str | None = None):
|
||||
"""Initialize the MemoryMiddleware.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
"""
|
||||
super().__init__()
|
||||
self._agent_name = agent_name
|
||||
|
||||
@override
|
||||
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
"""Queue conversation for memory update after agent completes.
|
||||
|
||||
Args:
|
||||
state: The current agent state.
|
||||
runtime: The runtime context.
|
||||
|
||||
Returns:
|
||||
None (no state changes needed from this middleware).
|
||||
"""
|
||||
config = get_memory_config()
|
||||
if not config.enabled:
|
||||
return None
|
||||
|
||||
# Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id is None:
|
||||
config_data = get_config()
|
||||
thread_id = config_data.get("configurable", {}).get("thread_id")
|
||||
if not thread_id:
|
||||
logger.debug("No thread_id in context, skipping memory update")
|
||||
return None
|
||||
|
||||
# Get messages from state
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
logger.debug("No messages in state, skipping memory update")
|
||||
return None
|
||||
|
||||
# Filter to only keep user inputs and final assistant responses
|
||||
filtered_messages = _filter_messages_for_memory(messages)
|
||||
|
||||
# Only queue if there's meaningful conversation
|
||||
# At minimum need one user message and one assistant response
|
||||
user_messages = [m for m in filtered_messages if getattr(m, "type", None) == "human"]
|
||||
assistant_messages = [m for m in filtered_messages if getattr(m, "type", None) == "ai"]
|
||||
|
||||
if not user_messages or not assistant_messages:
|
||||
return None
|
||||
|
||||
# Queue the filtered conversation for memory update
|
||||
correction_detected = detect_correction(filtered_messages)
|
||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||
queue = get_memory_queue()
|
||||
queue.add(
|
||||
thread_id=thread_id,
|
||||
messages=filtered_messages,
|
||||
agent_name=self._agent_name,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,363 @@
|
||||
"""SandboxAuditMiddleware - bash command security auditing."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import shlex
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime
|
||||
from typing import override
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command classification rules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Each pattern is compiled once at import time.
|
||||
_HIGH_RISK_PATTERNS: list[re.Pattern[str]] = [
|
||||
# --- original rules (retained) ---
|
||||
re.compile(r"rm\s+-[^\s]*r[^\s]*\s+(/\*?|~/?\*?|/home\b|/root\b)\s*$"),
|
||||
re.compile(r"dd\s+if="),
|
||||
re.compile(r"mkfs"),
|
||||
re.compile(r"cat\s+/etc/shadow"),
|
||||
re.compile(r">+\s*/etc/"),
|
||||
# --- pipe to sh/bash (generalised, replaces old curl|sh rule) ---
|
||||
re.compile(r"\|\s*(ba)?sh\b"),
|
||||
# --- command substitution (targeted – only dangerous executables) ---
|
||||
re.compile(r"[`$]\(?\s*(curl|wget|bash|sh|python|ruby|perl|base64)"),
|
||||
# --- base64 decode piped to execution ---
|
||||
re.compile(r"base64\s+.*-d.*\|"),
|
||||
# --- overwrite system binaries ---
|
||||
re.compile(r">+\s*(/usr/bin/|/bin/|/sbin/)"),
|
||||
# --- overwrite shell startup files ---
|
||||
re.compile(r">+\s*~/?\.(bashrc|profile|zshrc|bash_profile)"),
|
||||
# --- process environment leakage ---
|
||||
re.compile(r"/proc/[^/]+/environ"),
|
||||
# --- dynamic linker hijack (one-step escalation) ---
|
||||
re.compile(r"\b(LD_PRELOAD|LD_LIBRARY_PATH)\s*="),
|
||||
# --- bash built-in networking (bypasses tool allowlists) ---
|
||||
re.compile(r"/dev/tcp/"),
|
||||
# --- fork bomb ---
|
||||
re.compile(r"\S+\(\)\s*\{[^}]*\|\s*\S+\s*&"), # :(){ :|:& };:
|
||||
re.compile(r"while\s+true.*&\s*done"), # while true; do bash & done
|
||||
]
|
||||
|
||||
_MEDIUM_RISK_PATTERNS: list[re.Pattern[str]] = [
|
||||
re.compile(r"chmod\s+777"),
|
||||
re.compile(r"pip3?\s+install"),
|
||||
re.compile(r"apt(-get)?\s+install"),
|
||||
# sudo/su: no-op under Docker root; warn so LLM is aware
|
||||
re.compile(r"\b(sudo|su)\b"),
|
||||
# PATH modification: long attack chain, warn rather than block
|
||||
re.compile(r"\bPATH\s*="),
|
||||
]
|
||||
|
||||
|
||||
def _split_compound_command(command: str) -> list[str]:
|
||||
"""Split a compound command into sub-commands (quote-aware).
|
||||
|
||||
Scans the raw command string so unquoted shell control operators are
|
||||
recognised even when they are not surrounded by whitespace
|
||||
(e.g. ``safe;rm -rf /`` or ``rm -rf /&&echo ok``). Operators inside
|
||||
quotes are ignored. If the command ends with an unclosed quote or a
|
||||
dangling escape, return the whole command unchanged (fail-closed —
|
||||
safer to classify the unsplit string than silently drop parts).
|
||||
"""
|
||||
parts: list[str] = []
|
||||
current: list[str] = []
|
||||
in_single_quote = False
|
||||
in_double_quote = False
|
||||
escaping = False
|
||||
index = 0
|
||||
|
||||
while index < len(command):
|
||||
char = command[index]
|
||||
|
||||
if escaping:
|
||||
current.append(char)
|
||||
escaping = False
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if char == "\\" and not in_single_quote:
|
||||
current.append(char)
|
||||
escaping = True
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if char == "'" and not in_double_quote:
|
||||
in_single_quote = not in_single_quote
|
||||
current.append(char)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if char == '"' and not in_single_quote:
|
||||
in_double_quote = not in_double_quote
|
||||
current.append(char)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if not in_single_quote and not in_double_quote:
|
||||
if command.startswith("&&", index) or command.startswith("||", index):
|
||||
part = "".join(current).strip()
|
||||
if part:
|
||||
parts.append(part)
|
||||
current = []
|
||||
index += 2
|
||||
continue
|
||||
if char == ";":
|
||||
part = "".join(current).strip()
|
||||
if part:
|
||||
parts.append(part)
|
||||
current = []
|
||||
index += 1
|
||||
continue
|
||||
|
||||
current.append(char)
|
||||
index += 1
|
||||
|
||||
# Unclosed quote or dangling escape → fail-closed, return whole command
|
||||
if in_single_quote or in_double_quote or escaping:
|
||||
return [command]
|
||||
|
||||
part = "".join(current).strip()
|
||||
if part:
|
||||
parts.append(part)
|
||||
return parts if parts else [command]
|
||||
|
||||
|
||||
def _classify_single_command(command: str) -> str:
|
||||
"""Classify a single (non-compound) command. Return 'block', 'warn', or 'pass'."""
|
||||
normalized = " ".join(command.split())
|
||||
|
||||
for pattern in _HIGH_RISK_PATTERNS:
|
||||
if pattern.search(normalized):
|
||||
return "block"
|
||||
|
||||
# Also try shlex-parsed tokens for high-risk detection
|
||||
try:
|
||||
tokens = shlex.split(command)
|
||||
joined = " ".join(tokens)
|
||||
for pattern in _HIGH_RISK_PATTERNS:
|
||||
if pattern.search(joined):
|
||||
return "block"
|
||||
except ValueError:
|
||||
# shlex.split fails on unclosed quotes — treat as suspicious
|
||||
return "block"
|
||||
|
||||
for pattern in _MEDIUM_RISK_PATTERNS:
|
||||
if pattern.search(normalized):
|
||||
return "warn"
|
||||
|
||||
return "pass"
|
||||
|
||||
|
||||
def _classify_command(command: str) -> str:
|
||||
"""Return 'block', 'warn', or 'pass'.
|
||||
|
||||
Strategy:
|
||||
1. First scan the *whole* raw command against high-risk patterns. This
|
||||
catches structural attacks like ``while true; do bash & done`` or
|
||||
``:(){ :|:& };:`` that span multiple shell statements — splitting them
|
||||
on ``;`` would destroy the pattern context.
|
||||
2. Then split compound commands (e.g. ``cmd1 && cmd2 ; cmd3``) and
|
||||
classify each sub-command independently. The most severe verdict wins.
|
||||
"""
|
||||
# Pass 1: whole-command high-risk scan (catches multi-statement patterns)
|
||||
normalized = " ".join(command.split())
|
||||
for pattern in _HIGH_RISK_PATTERNS:
|
||||
if pattern.search(normalized):
|
||||
return "block"
|
||||
|
||||
# Pass 2: per-sub-command classification
|
||||
sub_commands = _split_compound_command(command)
|
||||
worst = "pass"
|
||||
for sub in sub_commands:
|
||||
verdict = _classify_single_command(sub)
|
||||
if verdict == "block":
|
||||
return "block" # short-circuit: can't get worse
|
||||
if verdict == "warn":
|
||||
worst = "warn"
|
||||
return worst
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
|
||||
"""Bash command security auditing middleware.
|
||||
|
||||
For every ``bash`` tool call:
|
||||
1. **Command classification**: regex + shlex analysis grades commands as
|
||||
high-risk (block), medium-risk (warn), or safe (pass).
|
||||
2. **Audit log**: every bash call is recorded as a structured JSON entry
|
||||
via the standard logger (visible in langgraph.log).
|
||||
|
||||
High-risk commands (e.g. ``rm -rf /``, ``curl url | bash``) are blocked:
|
||||
the handler is not called and an error ``ToolMessage`` is returned so the
|
||||
agent loop can continue gracefully.
|
||||
|
||||
Medium-risk commands (e.g. ``pip install``, ``chmod 777``) are executed
|
||||
normally; a warning is appended to the tool result so the LLM is aware.
|
||||
"""
|
||||
|
||||
state_schema = ThreadState
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_thread_id(self, request: ToolCallRequest) -> str | None:
|
||||
runtime = request.runtime # ToolRuntime; may be None-like in tests
|
||||
if runtime is None:
|
||||
return None
|
||||
ctx = getattr(runtime, "context", None) or {}
|
||||
thread_id = ctx.get("thread_id") if isinstance(ctx, dict) else None
|
||||
if thread_id is None:
|
||||
cfg = getattr(runtime, "config", None) or {}
|
||||
thread_id = cfg.get("configurable", {}).get("thread_id")
|
||||
return thread_id
|
||||
|
||||
_AUDIT_COMMAND_LIMIT = 200
|
||||
|
||||
def _write_audit(self, thread_id: str | None, command: str, verdict: str, *, truncate: bool = False) -> None:
|
||||
audited_command = command
|
||||
if truncate and len(command) > self._AUDIT_COMMAND_LIMIT:
|
||||
audited_command = f"{command[: self._AUDIT_COMMAND_LIMIT]}... ({len(command)} chars)"
|
||||
record = {
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"thread_id": thread_id or "unknown",
|
||||
"command": audited_command,
|
||||
"verdict": verdict,
|
||||
}
|
||||
logger.info("[SandboxAudit] %s", json.dumps(record, ensure_ascii=False))
|
||||
|
||||
def _build_block_message(self, request: ToolCallRequest, reason: str) -> ToolMessage:
|
||||
tool_call_id = str(request.tool_call.get("id") or "missing_id")
|
||||
return ToolMessage(
|
||||
content=f"Command blocked: {reason}. Please use a safer alternative approach.",
|
||||
tool_call_id=tool_call_id,
|
||||
name="bash",
|
||||
status="error",
|
||||
)
|
||||
|
||||
def _append_warn_to_result(self, result: ToolMessage | Command, command: str) -> ToolMessage | Command:
|
||||
"""Append a warning note to the tool result for medium-risk commands."""
|
||||
if not isinstance(result, ToolMessage):
|
||||
return result
|
||||
warning = f"\n\n⚠️ Warning: `{command}` is a medium-risk command that may modify the runtime environment."
|
||||
if isinstance(result.content, list):
|
||||
new_content = list(result.content) + [{"type": "text", "text": warning}]
|
||||
else:
|
||||
new_content = str(result.content) + warning
|
||||
return ToolMessage(
|
||||
content=new_content,
|
||||
tool_call_id=result.tool_call_id,
|
||||
name=result.name,
|
||||
status=result.status,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Input sanitisation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# Normal bash commands rarely exceed a few hundred characters. 10 000 is
|
||||
# well above any legitimate use case yet a tiny fraction of Linux ARG_MAX.
|
||||
# Anything longer is almost certainly a payload injection or base64-encoded
|
||||
# attack string.
|
||||
_MAX_COMMAND_LENGTH = 10_000
|
||||
|
||||
def _validate_input(self, command: str) -> str | None:
|
||||
"""Return ``None`` if *command* is acceptable, else a rejection reason."""
|
||||
if not command.strip():
|
||||
return "empty command"
|
||||
if len(command) > self._MAX_COMMAND_LENGTH:
|
||||
return "command too long"
|
||||
if "\x00" in command:
|
||||
return "null byte detected"
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core logic (shared between sync and async paths)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _pre_process(self, request: ToolCallRequest) -> tuple[str, str | None, str, str | None]:
|
||||
"""
|
||||
Returns (command, thread_id, verdict, reject_reason).
|
||||
verdict is 'block', 'warn', or 'pass'.
|
||||
reject_reason is non-None only for input sanitisation rejections.
|
||||
"""
|
||||
args = request.tool_call.get("args", {})
|
||||
raw_command = args.get("command")
|
||||
command = raw_command if isinstance(raw_command, str) else ""
|
||||
thread_id = self._get_thread_id(request)
|
||||
|
||||
# ① input sanitisation — reject malformed input before regex analysis
|
||||
reject_reason = self._validate_input(command)
|
||||
if reject_reason:
|
||||
self._write_audit(thread_id, command, "block", truncate=True)
|
||||
logger.warning("[SandboxAudit] INVALID INPUT thread=%s reason=%s", thread_id, reject_reason)
|
||||
return command, thread_id, "block", reject_reason
|
||||
|
||||
# ② classify command
|
||||
verdict = _classify_command(command)
|
||||
|
||||
# ③ audit log
|
||||
self._write_audit(thread_id, command, verdict)
|
||||
|
||||
if verdict == "block":
|
||||
logger.warning("[SandboxAudit] BLOCKED thread=%s cmd=%r", thread_id, command)
|
||||
elif verdict == "warn":
|
||||
logger.warning("[SandboxAudit] WARN (medium-risk) thread=%s cmd=%r", thread_id, command)
|
||||
|
||||
return command, thread_id, verdict, None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# wrap_tool_call hooks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@override
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
if request.tool_call.get("name") != "bash":
|
||||
return handler(request)
|
||||
|
||||
command, _, verdict, reject_reason = self._pre_process(request)
|
||||
if verdict == "block":
|
||||
reason = reject_reason or "security violation detected"
|
||||
return self._build_block_message(request, reason)
|
||||
result = handler(request)
|
||||
if verdict == "warn":
|
||||
result = self._append_warn_to_result(result, command)
|
||||
return result
|
||||
|
||||
@override
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
if request.tool_call.get("name") != "bash":
|
||||
return await handler(request)
|
||||
|
||||
command, _, verdict, reject_reason = self._pre_process(request)
|
||||
if verdict == "block":
|
||||
reason = reject_reason or "security violation detected"
|
||||
return self._build_block_message(request, reason)
|
||||
result = await handler(request)
|
||||
if verdict == "warn":
|
||||
result = self._append_warn_to_result(result, command)
|
||||
return result
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Middleware to enforce maximum concurrent subagent tool calls per model response."""
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Valid range for max_concurrent_subagents
|
||||
MIN_SUBAGENT_LIMIT = 2
|
||||
MAX_SUBAGENT_LIMIT = 4
|
||||
|
||||
|
||||
def _clamp_subagent_limit(value: int) -> int:
|
||||
"""Clamp subagent limit to valid range [2, 4]."""
|
||||
return max(MIN_SUBAGENT_LIMIT, min(MAX_SUBAGENT_LIMIT, value))
|
||||
|
||||
|
||||
class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Truncates excess 'task' tool calls from a single model response.
|
||||
|
||||
When an LLM generates more than max_concurrent parallel task tool calls
|
||||
in one response, this middleware keeps only the first max_concurrent and
|
||||
discards the rest. This is more reliable than prompt-based limits.
|
||||
|
||||
Args:
|
||||
max_concurrent: Maximum number of concurrent subagent calls allowed.
|
||||
Defaults to MAX_CONCURRENT_SUBAGENTS (3). Clamped to [2, 4].
|
||||
"""
|
||||
|
||||
def __init__(self, max_concurrent: int = MAX_CONCURRENT_SUBAGENTS):
|
||||
super().__init__()
|
||||
self.max_concurrent = _clamp_subagent_limit(max_concurrent)
|
||||
|
||||
def _truncate_task_calls(self, state: AgentState) -> dict | None:
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_msg = messages[-1]
|
||||
if getattr(last_msg, "type", None) != "ai":
|
||||
return None
|
||||
|
||||
tool_calls = getattr(last_msg, "tool_calls", None)
|
||||
if not tool_calls:
|
||||
return None
|
||||
|
||||
# Count task tool calls
|
||||
task_indices = [i for i, tc in enumerate(tool_calls) if tc.get("name") == "task"]
|
||||
if len(task_indices) <= self.max_concurrent:
|
||||
return None
|
||||
|
||||
# Build set of indices to drop (excess task calls beyond the limit)
|
||||
indices_to_drop = set(task_indices[self.max_concurrent :])
|
||||
truncated_tool_calls = [tc for i, tc in enumerate(tool_calls) if i not in indices_to_drop]
|
||||
|
||||
dropped_count = len(indices_to_drop)
|
||||
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
|
||||
|
||||
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
|
||||
updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
|
||||
return {"messages": [updated_msg]}
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._truncate_task_calls(state)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._truncate_task_calls(state)
|
||||
@@ -0,0 +1,99 @@
|
||||
import logging
|
||||
from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThreadDataMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
thread_data: NotRequired[ThreadDataState | None]
|
||||
|
||||
|
||||
class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
"""Create thread data directories for each thread execution.
|
||||
|
||||
Creates the following directory structure:
|
||||
- {base_dir}/threads/{thread_id}/user-data/workspace
|
||||
- {base_dir}/threads/{thread_id}/user-data/uploads
|
||||
- {base_dir}/threads/{thread_id}/user-data/outputs
|
||||
|
||||
Lifecycle Management:
|
||||
- With lazy_init=True (default): Only compute paths, directories created on-demand
|
||||
- With lazy_init=False: Eagerly create directories in before_agent()
|
||||
"""
|
||||
|
||||
state_schema = ThreadDataMiddlewareState
|
||||
|
||||
def __init__(self, base_dir: str | None = None, lazy_init: bool = True):
|
||||
"""Initialize the middleware.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory for thread data. Defaults to Paths resolution.
|
||||
lazy_init: If True, defer directory creation until needed.
|
||||
If False, create directories eagerly in before_agent().
|
||||
Default is True for optimal performance.
|
||||
"""
|
||||
super().__init__()
|
||||
self._paths = Paths(base_dir) if base_dir else get_paths()
|
||||
self._lazy_init = lazy_init
|
||||
|
||||
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
|
||||
"""Get the paths for a thread's data directories.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
|
||||
Returns:
|
||||
Dictionary with workspace_path, uploads_path, and outputs_path.
|
||||
"""
|
||||
return {
|
||||
"workspace_path": str(self._paths.sandbox_work_dir(thread_id)),
|
||||
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)),
|
||||
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)),
|
||||
}
|
||||
|
||||
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
|
||||
"""Create the thread data directories.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
|
||||
Returns:
|
||||
Dictionary with the created directory paths.
|
||||
"""
|
||||
self._paths.ensure_thread_dirs(thread_id)
|
||||
return self._get_thread_paths(thread_id)
|
||||
|
||||
@override
|
||||
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
context = runtime.context or {}
|
||||
thread_id = context.get("thread_id")
|
||||
if thread_id is None:
|
||||
config = get_config()
|
||||
thread_id = config.get("configurable", {}).get("thread_id")
|
||||
|
||||
if thread_id is None:
|
||||
raise ValueError("Thread ID is required in runtime context or config.configurable")
|
||||
|
||||
if self._lazy_init:
|
||||
# Lazy initialization: only compute paths, don't create directories
|
||||
paths = self._get_thread_paths(thread_id)
|
||||
else:
|
||||
# Eager initialization: create directories immediately
|
||||
paths = self._create_thread_directories(thread_id)
|
||||
logger.debug("Created thread data directories for thread %s", thread_id)
|
||||
|
||||
return {
|
||||
"thread_data": {
|
||||
**paths,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
"""Middleware for automatic thread title generation."""
|
||||
|
||||
import logging
|
||||
from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.title_config import get_title_config
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TitleMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
title: NotRequired[str | None]
|
||||
|
||||
|
||||
class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
"""Automatically generate a title for the thread after the first user message."""
|
||||
|
||||
state_schema = TitleMiddlewareState
|
||||
|
||||
def _normalize_content(self, content: object) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
if isinstance(content, list):
|
||||
parts = [self._normalize_content(item) for item in content]
|
||||
return "\n".join(part for part in parts if part)
|
||||
|
||||
if isinstance(content, dict):
|
||||
text_value = content.get("text")
|
||||
if isinstance(text_value, str):
|
||||
return text_value
|
||||
|
||||
nested_content = content.get("content")
|
||||
if nested_content is not None:
|
||||
return self._normalize_content(nested_content)
|
||||
|
||||
return ""
|
||||
|
||||
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
|
||||
"""Check if we should generate a title for this thread."""
|
||||
config = get_title_config()
|
||||
if not config.enabled:
|
||||
return False
|
||||
|
||||
# Check if thread already has a title in state
|
||||
if state.get("title"):
|
||||
return False
|
||||
|
||||
# Check if this is the first turn (has at least one user message and one assistant response)
|
||||
messages = state.get("messages", [])
|
||||
if len(messages) < 2:
|
||||
return False
|
||||
|
||||
# Count user and assistant messages
|
||||
user_messages = [m for m in messages if m.type == "human"]
|
||||
assistant_messages = [m for m in messages if m.type == "ai"]
|
||||
|
||||
# Generate title after first complete exchange
|
||||
return len(user_messages) == 1 and len(assistant_messages) >= 1
|
||||
|
||||
def _build_title_prompt(self, state: TitleMiddlewareState) -> tuple[str, str]:
|
||||
"""Extract user/assistant messages and build the title prompt.
|
||||
|
||||
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
|
||||
"""
|
||||
config = get_title_config()
|
||||
messages = state.get("messages", [])
|
||||
|
||||
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
|
||||
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
|
||||
|
||||
user_msg = self._normalize_content(user_msg_content)
|
||||
assistant_msg = self._normalize_content(assistant_msg_content)
|
||||
|
||||
prompt = config.prompt_template.format(
|
||||
max_words=config.max_words,
|
||||
user_msg=user_msg[:500],
|
||||
assistant_msg=assistant_msg[:500],
|
||||
)
|
||||
return prompt, user_msg
|
||||
|
||||
def _parse_title(self, content: object) -> str:
|
||||
"""Normalize model output into a clean title string."""
|
||||
config = get_title_config()
|
||||
title_content = self._normalize_content(content)
|
||||
title = title_content.strip().strip('"').strip("'")
|
||||
return title[: config.max_chars] if len(title) > config.max_chars else title
|
||||
|
||||
def _fallback_title(self, user_msg: str) -> str:
|
||||
config = get_title_config()
|
||||
fallback_chars = min(config.max_chars, 50)
|
||||
if len(user_msg) > fallback_chars:
|
||||
return user_msg[:fallback_chars].rstrip() + "..."
|
||||
return user_msg if user_msg else "New Conversation"
|
||||
|
||||
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
"""Generate a local fallback title without blocking on an LLM call."""
|
||||
if not self._should_generate_title(state):
|
||||
return None
|
||||
|
||||
_, user_msg = self._build_title_prompt(state)
|
||||
return {"title": self._fallback_title(user_msg)}
|
||||
|
||||
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
"""Generate a title asynchronously and fall back locally on failure."""
|
||||
if not self._should_generate_title(state):
|
||||
return None
|
||||
|
||||
config = get_title_config()
|
||||
prompt, user_msg = self._build_title_prompt(state)
|
||||
|
||||
try:
|
||||
if config.model_name:
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
else:
|
||||
model = create_chat_model(thinking_enabled=False)
|
||||
response = await model.ainvoke(prompt)
|
||||
title = self._parse_title(response.content)
|
||||
if title:
|
||||
return {"title": title}
|
||||
except Exception:
|
||||
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
|
||||
return {"title": self._fallback_title(user_msg)}
|
||||
|
||||
@override
|
||||
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
return self._generate_title_result(state)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
return await self._agenerate_title_result(state)
|
||||
@@ -0,0 +1,100 @@
|
||||
"""Middleware that extends TodoListMiddleware with context-loss detection.
|
||||
|
||||
When the message history is truncated (e.g., by SummarizationMiddleware), the
|
||||
original `write_todos` tool call and its ToolMessage can be scrolled out of the
|
||||
active context window. This middleware detects that situation and injects a
|
||||
reminder message so the model still knows about the outstanding todo list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents.middleware import TodoListMiddleware
|
||||
from langchain.agents.middleware.todo import PlanningState, Todo
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
|
||||
def _todos_in_messages(messages: list[Any]) -> bool:
|
||||
"""Return True if any AIMessage in *messages* contains a write_todos tool call."""
|
||||
for msg in messages:
|
||||
if isinstance(msg, AIMessage) and msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
if tc.get("name") == "write_todos":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _reminder_in_messages(messages: list[Any]) -> bool:
|
||||
"""Return True if a todo_reminder HumanMessage is already present in *messages*."""
|
||||
for msg in messages:
|
||||
if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_reminder":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _format_todos(todos: list[Todo]) -> str:
|
||||
"""Format a list of Todo items into a human-readable string."""
|
||||
lines: list[str] = []
|
||||
for todo in todos:
|
||||
status = todo.get("status", "pending")
|
||||
content = todo.get("content", "")
|
||||
lines.append(f"- [{status}] {content}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class TodoMiddleware(TodoListMiddleware):
|
||||
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
|
||||
|
||||
When the original `write_todos` tool call has been truncated from the message
|
||||
history (e.g., after summarization), the model loses awareness of the current
|
||||
todo list. This middleware detects that gap in `before_model` / `abefore_model`
|
||||
and injects a reminder message so the model can continue tracking progress.
|
||||
"""
|
||||
|
||||
@override
|
||||
def before_model(
|
||||
self,
|
||||
state: PlanningState,
|
||||
runtime: Runtime, # noqa: ARG002
|
||||
) -> dict[str, Any] | None:
|
||||
"""Inject a todo-list reminder when write_todos has left the context window."""
|
||||
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
|
||||
if not todos:
|
||||
return None
|
||||
|
||||
messages = state.get("messages") or []
|
||||
if _todos_in_messages(messages):
|
||||
# write_todos is still visible in context — nothing to do.
|
||||
return None
|
||||
|
||||
if _reminder_in_messages(messages):
|
||||
# A reminder was already injected and hasn't been truncated yet.
|
||||
return None
|
||||
|
||||
# The todo list exists in state but the original write_todos call is gone.
|
||||
# Inject a reminder as a HumanMessage so the model stays aware.
|
||||
formatted = _format_todos(todos)
|
||||
reminder = HumanMessage(
|
||||
name="todo_reminder",
|
||||
content=(
|
||||
"<system_reminder>\n"
|
||||
"Your todo list from earlier is no longer visible in the current context window, "
|
||||
"but it is still active. Here is the current state:\n\n"
|
||||
f"{formatted}\n\n"
|
||||
"Continue tracking and updating this todo list as you work. "
|
||||
"Call `write_todos` whenever the status of any item changes.\n"
|
||||
"</system_reminder>"
|
||||
),
|
||||
)
|
||||
return {"messages": [reminder]}
|
||||
|
||||
@override
|
||||
async def abefore_model(
|
||||
self,
|
||||
state: PlanningState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async version of before_model."""
|
||||
return self.before_model(state, runtime)
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Middleware for logging LLM token usage."""
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenUsageMiddleware(AgentMiddleware):
|
||||
"""Logs token usage from model response usage_metadata."""
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._log_usage(state)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._log_usage(state)
|
||||
|
||||
def _log_usage(self, state: AgentState) -> None:
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
last = messages[-1]
|
||||
usage = getattr(last, "usage_metadata", None)
|
||||
if usage:
|
||||
logger.info(
|
||||
"LLM token usage: input=%s output=%s total=%s",
|
||||
usage.get("input_tokens", "?"),
|
||||
usage.get("output_tokens", "?"),
|
||||
usage.get("total_tokens", "?"),
|
||||
)
|
||||
return None
|
||||
@@ -0,0 +1,143 @@
|
||||
"""Tool error handling middleware and shared runtime middleware builders."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
|
||||
|
||||
|
||||
class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Convert tool exceptions into error ToolMessages so the run can continue."""
|
||||
|
||||
def _build_error_message(self, request: ToolCallRequest, exc: Exception) -> ToolMessage:
|
||||
tool_name = str(request.tool_call.get("name") or "unknown_tool")
|
||||
tool_call_id = str(request.tool_call.get("id") or _MISSING_TOOL_CALL_ID)
|
||||
detail = str(exc).strip() or exc.__class__.__name__
|
||||
if len(detail) > 500:
|
||||
detail = detail[:497] + "..."
|
||||
|
||||
content = f"Error: Tool '{tool_name}' failed with {exc.__class__.__name__}: {detail}. Continue with available context, or choose an alternative tool."
|
||||
return ToolMessage(
|
||||
content=content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
|
||||
@override
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
try:
|
||||
return handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("Tool execution failed (sync): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
|
||||
return self._build_error_message(request, exc)
|
||||
|
||||
@override
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
try:
|
||||
return await handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("Tool execution failed (async): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
|
||||
return self._build_error_message(request, exc)
|
||||
|
||||
|
||||
def _build_runtime_middlewares(
|
||||
*,
|
||||
include_uploads: bool,
|
||||
include_dangling_tool_call_patch: bool,
|
||||
lazy_init: bool = True,
|
||||
) -> list[AgentMiddleware]:
|
||||
"""Build shared base middlewares for agent execution."""
|
||||
from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
|
||||
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from deerflow.sandbox.middleware import SandboxMiddleware
|
||||
|
||||
middlewares: list[AgentMiddleware] = [
|
||||
ThreadDataMiddleware(lazy_init=lazy_init),
|
||||
SandboxMiddleware(lazy_init=lazy_init),
|
||||
]
|
||||
|
||||
if include_uploads:
|
||||
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
|
||||
|
||||
middlewares.insert(1, UploadsMiddleware())
|
||||
|
||||
if include_dangling_tool_call_patch:
|
||||
from deerflow.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
|
||||
|
||||
middlewares.append(DanglingToolCallMiddleware())
|
||||
|
||||
middlewares.append(LLMErrorHandlingMiddleware())
|
||||
|
||||
# Guardrail middleware (if configured)
|
||||
from deerflow.config.guardrails_config import get_guardrails_config
|
||||
|
||||
guardrails_config = get_guardrails_config()
|
||||
if guardrails_config.enabled and guardrails_config.provider:
|
||||
import inspect
|
||||
|
||||
from deerflow.guardrails.middleware import GuardrailMiddleware
|
||||
from deerflow.reflection import resolve_variable
|
||||
|
||||
provider_cls = resolve_variable(guardrails_config.provider.use)
|
||||
provider_kwargs = dict(guardrails_config.provider.config) if guardrails_config.provider.config else {}
|
||||
# Pass framework hint if the provider accepts it (e.g. for config discovery).
|
||||
# Built-in providers like AllowlistProvider don't need it, so only inject
|
||||
# when the constructor accepts 'framework' or '**kwargs'.
|
||||
if "framework" not in provider_kwargs:
|
||||
try:
|
||||
sig = inspect.signature(provider_cls.__init__)
|
||||
if "framework" in sig.parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
|
||||
provider_kwargs["framework"] = "deerflow"
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
provider = provider_cls(**provider_kwargs)
|
||||
middlewares.append(GuardrailMiddleware(provider, fail_closed=guardrails_config.fail_closed, passport=guardrails_config.passport))
|
||||
|
||||
from deerflow.agents.middlewares.sandbox_audit_middleware import SandboxAuditMiddleware
|
||||
|
||||
middlewares.append(SandboxAuditMiddleware())
|
||||
middlewares.append(ToolErrorHandlingMiddleware())
|
||||
return middlewares
|
||||
|
||||
|
||||
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
|
||||
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
|
||||
return _build_runtime_middlewares(
|
||||
include_uploads=True,
|
||||
include_dangling_tool_call_patch=True,
|
||||
lazy_init=lazy_init,
|
||||
)
|
||||
|
||||
|
||||
def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
|
||||
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
||||
return _build_runtime_middlewares(
|
||||
include_uploads=False,
|
||||
include_dangling_tool_call_patch=True,
|
||||
lazy_init=lazy_init,
|
||||
)
|
||||
@@ -0,0 +1,293 @@
|
||||
"""Middleware to inject uploaded files information into agent context."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.utils.file_conversion import extract_outline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_OUTLINE_PREVIEW_LINES = 5
|
||||
|
||||
|
||||
def _extract_outline_for_file(file_path: Path) -> tuple[list[dict], list[str]]:
|
||||
"""Return the document outline and fallback preview for *file_path*.
|
||||
|
||||
Looks for a sibling ``<stem>.md`` file produced by the upload conversion
|
||||
pipeline.
|
||||
|
||||
Returns:
|
||||
(outline, preview) where:
|
||||
- outline: list of ``{title, line}`` dicts (plus optional sentinel).
|
||||
Empty when no headings are found or no .md exists.
|
||||
- preview: first few non-empty lines of the .md, used as a content
|
||||
anchor when outline is empty so the agent has some context.
|
||||
Empty when outline is non-empty (no fallback needed).
|
||||
"""
|
||||
md_path = file_path.with_suffix(".md")
|
||||
if not md_path.is_file():
|
||||
return [], []
|
||||
|
||||
outline = extract_outline(md_path)
|
||||
if outline:
|
||||
logger.debug("Extracted %d outline entries from %s", len(outline), file_path.name)
|
||||
return outline, []
|
||||
|
||||
# outline is empty — read the first few non-empty lines as a content preview
|
||||
preview: list[str] = []
|
||||
try:
|
||||
with md_path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
stripped = line.strip()
|
||||
if stripped:
|
||||
preview.append(stripped)
|
||||
if len(preview) >= _OUTLINE_PREVIEW_LINES:
|
||||
break
|
||||
except Exception:
|
||||
logger.debug("Failed to read preview lines from %s", md_path, exc_info=True)
|
||||
return [], preview
|
||||
|
||||
|
||||
class UploadsMiddlewareState(AgentState):
|
||||
"""State schema for uploads middleware."""
|
||||
|
||||
uploaded_files: NotRequired[list[dict] | None]
|
||||
|
||||
|
||||
class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
"""Middleware to inject uploaded files information into the agent context.
|
||||
|
||||
Reads file metadata from the current message's additional_kwargs.files
|
||||
(set by the frontend after upload) and prepends an <uploaded_files> block
|
||||
to the last human message so the model knows which files are available.
|
||||
"""
|
||||
|
||||
state_schema = UploadsMiddlewareState
|
||||
|
||||
def __init__(self, base_dir: str | None = None):
|
||||
"""Initialize the middleware.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory for thread data. Defaults to Paths resolution.
|
||||
"""
|
||||
super().__init__()
|
||||
self._paths = Paths(base_dir) if base_dir else get_paths()
|
||||
|
||||
def _format_file_entry(self, file: dict, lines: list[str]) -> None:
|
||||
"""Append a single file entry (name, size, path, optional outline) to lines."""
|
||||
size_kb = file["size"] / 1024
|
||||
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
|
||||
lines.append(f"- {file['filename']} ({size_str})")
|
||||
lines.append(f" Path: {file['path']}")
|
||||
outline = file.get("outline") or []
|
||||
if outline:
|
||||
truncated = outline[-1].get("truncated", False)
|
||||
visible = [e for e in outline if not e.get("truncated")]
|
||||
lines.append(" Document outline (use `read_file` with line ranges to read sections):")
|
||||
for entry in visible:
|
||||
lines.append(f" L{entry['line']}: {entry['title']}")
|
||||
if truncated:
|
||||
lines.append(f" ... (showing first {len(visible)} headings; use `read_file` to explore further)")
|
||||
else:
|
||||
preview = file.get("outline_preview") or []
|
||||
if preview:
|
||||
lines.append(" No structural headings detected. Document begins with:")
|
||||
for text in preview:
|
||||
lines.append(f" > {text}")
|
||||
lines.append(" Use `grep` to search for keywords (e.g. `grep(pattern='keyword', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("")
|
||||
|
||||
def _create_files_message(self, new_files: list[dict], historical_files: list[dict]) -> str:
|
||||
"""Create a formatted message listing uploaded files.
|
||||
|
||||
Args:
|
||||
new_files: Files uploaded in the current message.
|
||||
historical_files: Files uploaded in previous messages.
|
||||
Each file dict may contain an optional ``outline`` key — a list of
|
||||
``{title, line}`` dicts extracted from the converted Markdown file.
|
||||
|
||||
Returns:
|
||||
Formatted string inside <uploaded_files> tags.
|
||||
"""
|
||||
lines = ["<uploaded_files>"]
|
||||
|
||||
lines.append("The following files were uploaded in this message:")
|
||||
lines.append("")
|
||||
if new_files:
|
||||
for file in new_files:
|
||||
self._format_file_entry(file, lines)
|
||||
else:
|
||||
lines.append("(empty)")
|
||||
lines.append("")
|
||||
|
||||
if historical_files:
|
||||
lines.append("The following files were uploaded in previous messages and are still available:")
|
||||
lines.append("")
|
||||
for file in historical_files:
|
||||
self._format_file_entry(file, lines)
|
||||
|
||||
lines.append("To work with these files:")
|
||||
lines.append("- Read from the file first — use the outline line numbers and `read_file` to locate relevant sections.")
|
||||
lines.append("- Use `grep` to search for keywords when you are not sure which section to look at")
|
||||
lines.append(" (e.g. `grep(pattern='revenue', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("- Use `glob` to find files by name pattern")
|
||||
lines.append(" (e.g. `glob(pattern='**/*.md', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("- Only fall back to web search if the file content is clearly insufficient to answer the question.")
|
||||
lines.append("</uploaded_files>")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _files_from_kwargs(self, message: HumanMessage, uploads_dir: Path | None = None) -> list[dict] | None:
|
||||
"""Extract file info from message additional_kwargs.files.
|
||||
|
||||
The frontend sends uploaded file metadata in additional_kwargs.files
|
||||
after a successful upload. Each entry has: filename, size (bytes),
|
||||
path (virtual path), status.
|
||||
|
||||
Args:
|
||||
message: The human message to inspect.
|
||||
uploads_dir: Physical uploads directory used to verify file existence.
|
||||
When provided, entries whose files no longer exist are skipped.
|
||||
|
||||
Returns:
|
||||
List of file dicts with virtual paths, or None if the field is absent or empty.
|
||||
"""
|
||||
kwargs_files = (message.additional_kwargs or {}).get("files")
|
||||
if not isinstance(kwargs_files, list) or not kwargs_files:
|
||||
return None
|
||||
|
||||
files = []
|
||||
for f in kwargs_files:
|
||||
if not isinstance(f, dict):
|
||||
continue
|
||||
filename = f.get("filename") or ""
|
||||
if not filename or Path(filename).name != filename:
|
||||
continue
|
||||
if uploads_dir is not None and not (uploads_dir / filename).is_file():
|
||||
continue
|
||||
files.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"size": int(f.get("size") or 0),
|
||||
"path": f"/mnt/user-data/uploads/{filename}",
|
||||
"extension": Path(filename).suffix,
|
||||
}
|
||||
)
|
||||
return files if files else None
|
||||
|
||||
@override
|
||||
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
"""Inject uploaded files information before agent execution.
|
||||
|
||||
New files come from the current message's additional_kwargs.files.
|
||||
Historical files are scanned from the thread's uploads directory,
|
||||
excluding the new ones.
|
||||
|
||||
Prepends <uploaded_files> context to the last human message content.
|
||||
The original additional_kwargs (including files metadata) is preserved
|
||||
on the updated message so the frontend can read it from the stream.
|
||||
|
||||
Args:
|
||||
state: Current agent state.
|
||||
runtime: Runtime context containing thread_id.
|
||||
|
||||
Returns:
|
||||
State updates including uploaded files list.
|
||||
"""
|
||||
messages = list(state.get("messages", []))
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_message_index = len(messages) - 1
|
||||
last_message = messages[last_message_index]
|
||||
|
||||
if not isinstance(last_message, HumanMessage):
|
||||
return None
|
||||
|
||||
# Resolve uploads directory for existence checks
|
||||
thread_id = (runtime.context or {}).get("thread_id")
|
||||
if thread_id is None:
|
||||
try:
|
||||
from langgraph.config import get_config
|
||||
|
||||
thread_id = get_config().get("configurable", {}).get("thread_id")
|
||||
except RuntimeError:
|
||||
pass # get_config() raises outside a runnable context (e.g. unit tests)
|
||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
|
||||
|
||||
# Get newly uploaded files from the current message's additional_kwargs.files
|
||||
new_files = self._files_from_kwargs(last_message, uploads_dir) or []
|
||||
|
||||
# Collect historical files from the uploads directory (all except the new ones)
|
||||
new_filenames = {f["filename"] for f in new_files}
|
||||
historical_files: list[dict] = []
|
||||
if uploads_dir and uploads_dir.exists():
|
||||
for file_path in sorted(uploads_dir.iterdir()):
|
||||
if file_path.is_file() and file_path.name not in new_filenames:
|
||||
stat = file_path.stat()
|
||||
outline, preview = _extract_outline_for_file(file_path)
|
||||
historical_files.append(
|
||||
{
|
||||
"filename": file_path.name,
|
||||
"size": stat.st_size,
|
||||
"path": f"/mnt/user-data/uploads/{file_path.name}",
|
||||
"extension": file_path.suffix,
|
||||
"outline": outline,
|
||||
"outline_preview": preview,
|
||||
}
|
||||
)
|
||||
|
||||
# Attach outlines to new files as well
|
||||
if uploads_dir:
|
||||
for file in new_files:
|
||||
phys_path = uploads_dir / file["filename"]
|
||||
outline, preview = _extract_outline_for_file(phys_path)
|
||||
file["outline"] = outline
|
||||
file["outline_preview"] = preview
|
||||
|
||||
if not new_files and not historical_files:
|
||||
return None
|
||||
|
||||
logger.debug(f"New files: {[f['filename'] for f in new_files]}, historical: {[f['filename'] for f in historical_files]}")
|
||||
|
||||
# Create files message and prepend to the last human message content
|
||||
files_message = self._create_files_message(new_files, historical_files)
|
||||
|
||||
# Extract original content - handle both string and list formats
|
||||
original_content = last_message.content
|
||||
if isinstance(original_content, str):
|
||||
# Simple case: string content, just prepend files message
|
||||
updated_content = f"{files_message}\n\n{original_content}"
|
||||
elif isinstance(original_content, list):
|
||||
# Complex case: list content (multimodal), preserve all blocks
|
||||
# Prepend files message as the first text block
|
||||
files_block = {"type": "text", "text": f"{files_message}\n\n"}
|
||||
# Keep all original blocks (including images)
|
||||
updated_content = [files_block, *original_content]
|
||||
else:
|
||||
# Other types, preserve as-is
|
||||
updated_content = original_content
|
||||
|
||||
# Create new message with combined content.
|
||||
# Preserve additional_kwargs (including files metadata) so the frontend
|
||||
# can read structured file info from the streamed message.
|
||||
updated_message = HumanMessage(
|
||||
content=updated_content,
|
||||
id=last_message.id,
|
||||
additional_kwargs=last_message.additional_kwargs,
|
||||
)
|
||||
|
||||
messages[last_message_index] = updated_message
|
||||
|
||||
return {
|
||||
"uploaded_files": new_files,
|
||||
"messages": messages,
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
"""Middleware for injecting image details into conversation before LLM call."""
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ViewImageMiddlewareState(ThreadState):
|
||||
"""Reuse the thread state so reducer-backed keys keep their annotations."""
|
||||
|
||||
|
||||
class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]):
|
||||
"""Injects image details as a human message before LLM calls when view_image tools have completed.
|
||||
|
||||
This middleware:
|
||||
1. Runs before each LLM call
|
||||
2. Checks if the last assistant message contains view_image tool calls
|
||||
3. Verifies all tool calls in that message have been completed (have corresponding ToolMessages)
|
||||
4. If conditions are met, creates a human message with all viewed image details (including base64 data)
|
||||
5. Adds the message to state so the LLM can see and analyze the images
|
||||
|
||||
This enables the LLM to automatically receive and analyze images that were loaded via view_image tool,
|
||||
without requiring explicit user prompts to describe the images.
|
||||
"""
|
||||
|
||||
state_schema = ViewImageMiddlewareState
|
||||
|
||||
def _get_last_assistant_message(self, messages: list) -> AIMessage | None:
|
||||
"""Get the last assistant message from the message list.
|
||||
|
||||
Args:
|
||||
messages: List of messages
|
||||
|
||||
Returns:
|
||||
Last AIMessage or None if not found
|
||||
"""
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
return msg
|
||||
return None
|
||||
|
||||
def _has_view_image_tool(self, message: AIMessage) -> bool:
|
||||
"""Check if the assistant message contains view_image tool calls.
|
||||
|
||||
Args:
|
||||
message: Assistant message to check
|
||||
|
||||
Returns:
|
||||
True if message contains view_image tool calls
|
||||
"""
|
||||
if not hasattr(message, "tool_calls") or not message.tool_calls:
|
||||
return False
|
||||
|
||||
return any(tool_call.get("name") == "view_image" for tool_call in message.tool_calls)
|
||||
|
||||
def _all_tools_completed(self, messages: list, assistant_msg: AIMessage) -> bool:
|
||||
"""Check if all tool calls in the assistant message have been completed.
|
||||
|
||||
Args:
|
||||
messages: List of all messages
|
||||
assistant_msg: The assistant message containing tool calls
|
||||
|
||||
Returns:
|
||||
True if all tool calls have corresponding ToolMessages
|
||||
"""
|
||||
if not hasattr(assistant_msg, "tool_calls") or not assistant_msg.tool_calls:
|
||||
return False
|
||||
|
||||
# Get all tool call IDs from the assistant message
|
||||
tool_call_ids = {tool_call.get("id") for tool_call in assistant_msg.tool_calls if tool_call.get("id")}
|
||||
|
||||
# Find the index of the assistant message
|
||||
try:
|
||||
assistant_idx = messages.index(assistant_msg)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
# Get all ToolMessages after the assistant message
|
||||
completed_tool_ids = set()
|
||||
for msg in messages[assistant_idx + 1 :]:
|
||||
if isinstance(msg, ToolMessage) and msg.tool_call_id:
|
||||
completed_tool_ids.add(msg.tool_call_id)
|
||||
|
||||
# Check if all tool calls have been completed
|
||||
return tool_call_ids.issubset(completed_tool_ids)
|
||||
|
||||
def _create_image_details_message(self, state: ViewImageMiddlewareState) -> list[str | dict]:
|
||||
"""Create a formatted message with all viewed image details.
|
||||
|
||||
Args:
|
||||
state: Current state containing viewed_images
|
||||
|
||||
Returns:
|
||||
List of content blocks (text and images) for the HumanMessage
|
||||
"""
|
||||
viewed_images = state.get("viewed_images", {})
|
||||
if not viewed_images:
|
||||
# Return a properly formatted text block, not a plain string array
|
||||
return [{"type": "text", "text": "No images have been viewed."}]
|
||||
|
||||
# Build the message with image information
|
||||
content_blocks: list[str | dict] = [{"type": "text", "text": "Here are the images you've viewed:"}]
|
||||
|
||||
for image_path, image_data in viewed_images.items():
|
||||
mime_type = image_data.get("mime_type", "unknown")
|
||||
base64_data = image_data.get("base64", "")
|
||||
|
||||
# Add text description
|
||||
content_blocks.append({"type": "text", "text": f"\n- **{image_path}** ({mime_type})"})
|
||||
|
||||
# Add the actual image data so LLM can "see" it
|
||||
if base64_data:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{mime_type};base64,{base64_data}"},
|
||||
}
|
||||
)
|
||||
|
||||
return content_blocks
|
||||
|
||||
def _should_inject_image_message(self, state: ViewImageMiddlewareState) -> bool:
|
||||
"""Determine if we should inject an image details message.
|
||||
|
||||
Args:
|
||||
state: Current state
|
||||
|
||||
Returns:
|
||||
True if we should inject the message
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return False
|
||||
|
||||
# Get the last assistant message
|
||||
last_assistant_msg = self._get_last_assistant_message(messages)
|
||||
if not last_assistant_msg:
|
||||
return False
|
||||
|
||||
# Check if it has view_image tool calls
|
||||
if not self._has_view_image_tool(last_assistant_msg):
|
||||
return False
|
||||
|
||||
# Check if all tools have been completed
|
||||
if not self._all_tools_completed(messages, last_assistant_msg):
|
||||
return False
|
||||
|
||||
# Check if we've already added an image details message
|
||||
# Look for a human message after the last assistant message that contains image details
|
||||
assistant_idx = messages.index(last_assistant_msg)
|
||||
for msg in messages[assistant_idx + 1 :]:
|
||||
if isinstance(msg, HumanMessage):
|
||||
content_str = str(msg.content)
|
||||
if "Here are the images you've viewed" in content_str or "Here are the details of the images you've viewed" in content_str:
|
||||
# Already added, don't add again
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _inject_image_message(self, state: ViewImageMiddlewareState) -> dict | None:
|
||||
"""Internal helper to inject image details message.
|
||||
|
||||
Args:
|
||||
state: Current state
|
||||
|
||||
Returns:
|
||||
State update with additional human message, or None if no update needed
|
||||
"""
|
||||
if not self._should_inject_image_message(state):
|
||||
return None
|
||||
|
||||
# Create the image details message with text and image content
|
||||
image_content = self._create_image_details_message(state)
|
||||
|
||||
# Create a new human message with mixed content (text + images)
|
||||
human_msg = HumanMessage(content=image_content)
|
||||
|
||||
logger.debug("Injecting image details message with images before LLM call")
|
||||
|
||||
# Return state update with the new message
|
||||
return {"messages": [human_msg]}
|
||||
|
||||
@override
|
||||
def before_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
"""Inject image details message before LLM call if view_image tools have completed (sync version).
|
||||
|
||||
This runs before each LLM call, checking if the previous turn included view_image
|
||||
tool calls that have all completed. If so, it injects a human message with the image
|
||||
details so the LLM can see and analyze the images.
|
||||
|
||||
Args:
|
||||
state: Current state
|
||||
runtime: Runtime context (unused but required by interface)
|
||||
|
||||
Returns:
|
||||
State update with additional human message, or None if no update needed
|
||||
"""
|
||||
return self._inject_image_message(state)
|
||||
|
||||
@override
|
||||
async def abefore_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
"""Inject image details message before LLM call if view_image tools have completed (async version).
|
||||
|
||||
This runs before each LLM call, checking if the previous turn included view_image
|
||||
tool calls that have all completed. If so, it injects a human message with the image
|
||||
details so the LLM can see and analyze the images.
|
||||
|
||||
Args:
|
||||
state: Current state
|
||||
runtime: Runtime context (unused but required by interface)
|
||||
|
||||
Returns:
|
||||
State update with additional human message, or None if no update needed
|
||||
"""
|
||||
return self._inject_image_message(state)
|
||||
@@ -0,0 +1,55 @@
|
||||
from typing import Annotated, NotRequired, TypedDict
|
||||
|
||||
from langchain.agents import AgentState
|
||||
|
||||
|
||||
class SandboxState(TypedDict):
|
||||
sandbox_id: NotRequired[str | None]
|
||||
|
||||
|
||||
class ThreadDataState(TypedDict):
|
||||
workspace_path: NotRequired[str | None]
|
||||
uploads_path: NotRequired[str | None]
|
||||
outputs_path: NotRequired[str | None]
|
||||
|
||||
|
||||
class ViewedImageData(TypedDict):
|
||||
base64: str
|
||||
mime_type: str
|
||||
|
||||
|
||||
def merge_artifacts(existing: list[str] | None, new: list[str] | None) -> list[str]:
|
||||
"""Reducer for artifacts list - merges and deduplicates artifacts."""
|
||||
if existing is None:
|
||||
return new or []
|
||||
if new is None:
|
||||
return existing
|
||||
# Use dict.fromkeys to deduplicate while preserving order
|
||||
return list(dict.fromkeys(existing + new))
|
||||
|
||||
|
||||
def merge_viewed_images(existing: dict[str, ViewedImageData] | None, new: dict[str, ViewedImageData] | None) -> dict[str, ViewedImageData]:
|
||||
"""Reducer for viewed_images dict - merges image dictionaries.
|
||||
|
||||
Special case: If new is an empty dict {}, it clears the existing images.
|
||||
This allows middlewares to clear the viewed_images state after processing.
|
||||
"""
|
||||
if existing is None:
|
||||
return new or {}
|
||||
if new is None:
|
||||
return existing
|
||||
# Special case: empty dict means clear all viewed images
|
||||
if len(new) == 0:
|
||||
return {}
|
||||
# Merge dictionaries, new values override existing ones for same keys
|
||||
return {**existing, **new}
|
||||
|
||||
|
||||
class ThreadState(AgentState):
|
||||
sandbox: NotRequired[SandboxState | None]
|
||||
thread_data: NotRequired[ThreadDataState | None]
|
||||
title: NotRequired[str | None]
|
||||
artifacts: Annotated[list[str], merge_artifacts]
|
||||
todos: NotRequired[list | None]
|
||||
uploaded_files: NotRequired[list[dict] | None]
|
||||
viewed_images: Annotated[dict[str, ViewedImageData], merge_viewed_images] # image_path -> {base64, mime_type}
|
||||
1195
deer-flow/backend/packages/harness/deerflow/client.py
Normal file
1195
deer-flow/backend/packages/harness/deerflow/client.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,39 @@
|
||||
"""Hard-disable shim for native (unhardened) community web tools.
|
||||
|
||||
In this hardened DeerFlow build, the only allowed web access surface is
|
||||
``deerflow.community.searx.tools`` (web_search, web_fetch, image_search).
|
||||
The legacy providers below have been deliberately stubbed out and will
|
||||
raise on import so that misconfiguration cannot silently fall back to
|
||||
unsanitized output:
|
||||
|
||||
- ddg_search (DuckDuckGo)
|
||||
- tavily (Tavily)
|
||||
- exa (Exa)
|
||||
- firecrawl (Firecrawl)
|
||||
- jina_ai (Jina Reader)
|
||||
- infoquest (InfoQuest)
|
||||
- image_search (DDG image fallback)
|
||||
|
||||
If you really need one of these back, undo the change in the matching
|
||||
``community/<name>/tools.py`` and audit the call site for prompt-injection
|
||||
hardening first.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class NativeWebToolDisabledError(RuntimeError):
|
||||
"""Raised when a hard-disabled native web tool is imported or invoked."""
|
||||
|
||||
|
||||
_MESSAGE_TEMPLATE = (
|
||||
"Native web tool '{provider}' is disabled in this hardened DeerFlow build. "
|
||||
"Use 'deerflow.community.searx.tools' (web_search / web_fetch / image_search) instead. "
|
||||
"If you really need '{provider}', re-enable it in "
|
||||
"deerflow/community/{provider}/tools.py and harden it first."
|
||||
)
|
||||
|
||||
|
||||
def reject_native_provider(provider: str) -> None:
|
||||
"""Raise a clear error pointing the operator at the searx replacement."""
|
||||
raise NativeWebToolDisabledError(_MESSAGE_TEMPLATE.format(provider=provider))
|
||||
@@ -0,0 +1,15 @@
|
||||
from .aio_sandbox import AioSandbox
|
||||
from .aio_sandbox_provider import AioSandboxProvider
|
||||
from .backend import SandboxBackend
|
||||
from .local_backend import LocalContainerBackend
|
||||
from .remote_backend import RemoteSandboxBackend
|
||||
from .sandbox_info import SandboxInfo
|
||||
|
||||
__all__ = [
|
||||
"AioSandbox",
|
||||
"AioSandboxProvider",
|
||||
"LocalContainerBackend",
|
||||
"RemoteSandboxBackend",
|
||||
"SandboxBackend",
|
||||
"SandboxInfo",
|
||||
]
|
||||
@@ -0,0 +1,232 @@
|
||||
import base64
|
||||
import logging
|
||||
import shlex
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
from agent_sandbox import Sandbox as AioSandboxClient
|
||||
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.search import GrepMatch, path_matches, should_ignore_path, truncate_line
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ERROR_OBSERVATION_SIGNATURE = "'ErrorObservation' object has no attribute 'exit_code'"
|
||||
|
||||
|
||||
class AioSandbox(Sandbox):
|
||||
"""Sandbox implementation using the agent-infra/sandbox Docker container.
|
||||
|
||||
This sandbox connects to a running AIO sandbox container via HTTP API.
|
||||
A threading lock serializes shell commands to prevent concurrent requests
|
||||
from corrupting the container's single persistent session (see #1433).
|
||||
"""
|
||||
|
||||
def __init__(self, id: str, base_url: str, home_dir: str | None = None):
|
||||
"""Initialize the AIO sandbox.
|
||||
|
||||
Args:
|
||||
id: Unique identifier for this sandbox instance.
|
||||
base_url: URL of the sandbox API (e.g., http://localhost:8080).
|
||||
home_dir: Home directory inside the sandbox. If None, will be fetched from the sandbox.
|
||||
"""
|
||||
super().__init__(id)
|
||||
self._base_url = base_url
|
||||
self._client = AioSandboxClient(base_url=base_url, timeout=600)
|
||||
self._home_dir = home_dir
|
||||
self._lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def base_url(self) -> str:
|
||||
return self._base_url
|
||||
|
||||
@property
|
||||
def home_dir(self) -> str:
|
||||
"""Get the home directory inside the sandbox."""
|
||||
if self._home_dir is None:
|
||||
context = self._client.sandbox.get_context()
|
||||
self._home_dir = context.home_dir
|
||||
return self._home_dir
|
||||
|
||||
def execute_command(self, command: str) -> str:
|
||||
"""Execute a shell command in the sandbox.
|
||||
|
||||
Uses a lock to serialize concurrent requests. The AIO sandbox
|
||||
container maintains a single persistent shell session that
|
||||
corrupts when hit with concurrent exec_command calls (returns
|
||||
``ErrorObservation`` instead of real output). If corruption is
|
||||
detected despite the lock (e.g. multiple processes sharing a
|
||||
sandbox), the command is retried on a fresh session.
|
||||
|
||||
Args:
|
||||
command: The command to execute.
|
||||
|
||||
Returns:
|
||||
The output of the command.
|
||||
"""
|
||||
with self._lock:
|
||||
try:
|
||||
result = self._client.shell.exec_command(command=command)
|
||||
output = result.data.output if result.data else ""
|
||||
|
||||
if output and _ERROR_OBSERVATION_SIGNATURE in output:
|
||||
logger.warning("ErrorObservation detected in sandbox output, retrying with a fresh session")
|
||||
fresh_id = str(uuid.uuid4())
|
||||
result = self._client.shell.exec_command(command=command, id=fresh_id)
|
||||
output = result.data.output if result.data else ""
|
||||
|
||||
return output if output else "(no output)"
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute command in sandbox: {e}")
|
||||
return f"Error: {e}"
|
||||
|
||||
def read_file(self, path: str) -> str:
|
||||
"""Read the content of a file in the sandbox.
|
||||
|
||||
Args:
|
||||
path: The absolute path of the file to read.
|
||||
|
||||
Returns:
|
||||
The content of the file.
|
||||
"""
|
||||
try:
|
||||
result = self._client.file.read_file(file=path)
|
||||
return result.data.content if result.data else ""
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read file in sandbox: {e}")
|
||||
return f"Error: {e}"
|
||||
|
||||
def list_dir(self, path: str, max_depth: int = 2) -> list[str]:
|
||||
"""List the contents of a directory in the sandbox.
|
||||
|
||||
Args:
|
||||
path: The absolute path of the directory to list.
|
||||
max_depth: The maximum depth to traverse. Default is 2.
|
||||
|
||||
Returns:
|
||||
The contents of the directory.
|
||||
"""
|
||||
with self._lock:
|
||||
try:
|
||||
result = self._client.shell.exec_command(command=f"find {shlex.quote(path)} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500")
|
||||
output = result.data.output if result.data else ""
|
||||
if output:
|
||||
return [line.strip() for line in output.strip().split("\n") if line.strip()]
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list directory in sandbox: {e}")
|
||||
return []
|
||||
|
||||
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
||||
"""Write content to a file in the sandbox.
|
||||
|
||||
Args:
|
||||
path: The absolute path of the file to write to.
|
||||
content: The text content to write to the file.
|
||||
append: Whether to append the content to the file.
|
||||
"""
|
||||
with self._lock:
|
||||
try:
|
||||
if append:
|
||||
existing = self.read_file(path)
|
||||
if not existing.startswith("Error:"):
|
||||
content = existing + content
|
||||
self._client.file.write_file(file=path, content=content)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write file in sandbox: {e}")
|
||||
raise
|
||||
|
||||
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
|
||||
if not include_dirs:
|
||||
result = self._client.file.find_files(path=path, glob=pattern)
|
||||
files = result.data.files if result.data and result.data.files else []
|
||||
filtered = [file_path for file_path in files if not should_ignore_path(file_path)]
|
||||
truncated = len(filtered) > max_results
|
||||
return filtered[:max_results], truncated
|
||||
|
||||
result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
|
||||
entries = result.data.files if result.data and result.data.files else []
|
||||
matches: list[str] = []
|
||||
root_path = path.rstrip("/") or "/"
|
||||
root_prefix = root_path if root_path == "/" else f"{root_path}/"
|
||||
for entry in entries:
|
||||
if entry.path != root_path and not entry.path.startswith(root_prefix):
|
||||
continue
|
||||
if should_ignore_path(entry.path):
|
||||
continue
|
||||
rel_path = entry.path[len(root_path) :].lstrip("/")
|
||||
if path_matches(pattern, rel_path):
|
||||
matches.append(entry.path)
|
||||
if len(matches) >= max_results:
|
||||
return matches, True
|
||||
return matches, False
|
||||
|
||||
def grep(
|
||||
self,
|
||||
path: str,
|
||||
pattern: str,
|
||||
*,
|
||||
glob: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = 100,
|
||||
) -> tuple[list[GrepMatch], bool]:
|
||||
import re as _re
|
||||
|
||||
regex_source = _re.escape(pattern) if literal else pattern
|
||||
# Validate the pattern locally so an invalid regex raises re.error
|
||||
# (caught by grep_tool's except re.error handler) rather than a
|
||||
# generic remote API error.
|
||||
_re.compile(regex_source, 0 if case_sensitive else _re.IGNORECASE)
|
||||
regex = regex_source if case_sensitive else f"(?i){regex_source}"
|
||||
|
||||
if glob is not None:
|
||||
find_result = self._client.file.find_files(path=path, glob=glob)
|
||||
candidate_paths = find_result.data.files if find_result.data and find_result.data.files else []
|
||||
else:
|
||||
list_result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
|
||||
entries = list_result.data.files if list_result.data and list_result.data.files else []
|
||||
candidate_paths = [entry.path for entry in entries if not entry.is_directory]
|
||||
|
||||
matches: list[GrepMatch] = []
|
||||
truncated = False
|
||||
|
||||
for file_path in candidate_paths:
|
||||
if should_ignore_path(file_path):
|
||||
continue
|
||||
|
||||
search_result = self._client.file.search_in_file(file=file_path, regex=regex)
|
||||
data = search_result.data
|
||||
if data is None:
|
||||
continue
|
||||
|
||||
line_numbers = data.line_numbers or []
|
||||
matched_lines = data.matches or []
|
||||
for line_number, line in zip(line_numbers, matched_lines):
|
||||
matches.append(
|
||||
GrepMatch(
|
||||
path=file_path,
|
||||
line_number=line_number if isinstance(line_number, int) else 0,
|
||||
line=truncate_line(line),
|
||||
)
|
||||
)
|
||||
if len(matches) >= max_results:
|
||||
truncated = True
|
||||
return matches, truncated
|
||||
|
||||
return matches, truncated
|
||||
|
||||
def update_file(self, path: str, content: bytes) -> None:
|
||||
"""Update a file with binary content in the sandbox.
|
||||
|
||||
Args:
|
||||
path: The absolute path of the file to update.
|
||||
content: The binary content to write to the file.
|
||||
"""
|
||||
with self._lock:
|
||||
try:
|
||||
base64_content = base64.b64encode(content).decode("utf-8")
|
||||
self._client.file.write_file(file=path, content=base64_content, encoding="base64")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update file in sandbox: {e}")
|
||||
raise
|
||||
@@ -0,0 +1,694 @@
|
||||
"""AIO Sandbox Provider — orchestrates sandbox lifecycle with pluggable backends.
|
||||
|
||||
This provider composes:
|
||||
- SandboxBackend: how sandboxes are provisioned (local container vs remote/K8s)
|
||||
|
||||
The provider itself handles:
|
||||
- In-process caching for fast repeated access
|
||||
- Idle timeout management
|
||||
- Graceful shutdown with signal handling
|
||||
- Mount computation (thread-specific, skills)
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
except ImportError: # pragma: no cover - Windows fallback
|
||||
fcntl = None # type: ignore[assignment]
|
||||
import msvcrt
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.sandbox_provider import SandboxProvider
|
||||
|
||||
from .aio_sandbox import AioSandbox
|
||||
from .backend import SandboxBackend, wait_for_sandbox_ready
|
||||
from .local_backend import LocalContainerBackend
|
||||
from .remote_backend import RemoteSandboxBackend
|
||||
from .sandbox_info import SandboxInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default configuration
|
||||
DEFAULT_IMAGE = "enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest"
|
||||
DEFAULT_PORT = 8080
|
||||
DEFAULT_CONTAINER_PREFIX = "deer-flow-sandbox"
|
||||
DEFAULT_IDLE_TIMEOUT = 600 # 10 minutes in seconds
|
||||
DEFAULT_REPLICAS = 3 # Maximum concurrent sandbox containers
|
||||
IDLE_CHECK_INTERVAL = 60 # Check every 60 seconds
|
||||
|
||||
|
||||
def _lock_file_exclusive(lock_file) -> None:
|
||||
if fcntl is not None:
|
||||
fcntl.flock(lock_file, fcntl.LOCK_EX)
|
||||
return
|
||||
|
||||
lock_file.seek(0)
|
||||
msvcrt.locking(lock_file.fileno(), msvcrt.LK_LOCK, 1)
|
||||
|
||||
|
||||
def _unlock_file(lock_file) -> None:
|
||||
if fcntl is not None:
|
||||
fcntl.flock(lock_file, fcntl.LOCK_UN)
|
||||
return
|
||||
|
||||
lock_file.seek(0)
|
||||
msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
|
||||
|
||||
class AioSandboxProvider(SandboxProvider):
|
||||
"""Sandbox provider that manages containers running the AIO sandbox.
|
||||
|
||||
Architecture:
|
||||
This provider composes a SandboxBackend (how to provision), enabling:
|
||||
- Local Docker/Apple Container mode (auto-start containers)
|
||||
- Remote/K8s mode (connect to pre-existing sandbox URL)
|
||||
|
||||
Configuration options in config.yaml under sandbox:
|
||||
use: deerflow.community.aio_sandbox:AioSandboxProvider
|
||||
image: <container image>
|
||||
port: 8080 # Base port for local containers
|
||||
container_prefix: deer-flow-sandbox
|
||||
idle_timeout: 600 # Idle timeout in seconds (0 to disable)
|
||||
replicas: 3 # Max concurrent sandbox containers (LRU eviction when exceeded)
|
||||
mounts: # Volume mounts for local containers
|
||||
- host_path: /path/on/host
|
||||
container_path: /path/in/container
|
||||
read_only: false
|
||||
environment: # Environment variables for containers
|
||||
NODE_ENV: production
|
||||
API_KEY: $MY_API_KEY
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self._sandboxes: dict[str, AioSandbox] = {} # sandbox_id -> AioSandbox instance
|
||||
self._sandbox_infos: dict[str, SandboxInfo] = {} # sandbox_id -> SandboxInfo (for destroy)
|
||||
self._thread_sandboxes: dict[str, str] = {} # thread_id -> sandbox_id
|
||||
self._thread_locks: dict[str, threading.Lock] = {} # thread_id -> in-process lock
|
||||
self._last_activity: dict[str, float] = {} # sandbox_id -> last activity timestamp
|
||||
# Warm pool: released sandboxes whose containers are still running.
|
||||
# Maps sandbox_id -> (SandboxInfo, release_timestamp).
|
||||
# Containers here can be reclaimed quickly (no cold-start) or destroyed
|
||||
# when replicas capacity is exhausted.
|
||||
self._warm_pool: dict[str, tuple[SandboxInfo, float]] = {}
|
||||
self._shutdown_called = False
|
||||
self._idle_checker_stop = threading.Event()
|
||||
self._idle_checker_thread: threading.Thread | None = None
|
||||
|
||||
self._config = self._load_config()
|
||||
self._backend: SandboxBackend = self._create_backend()
|
||||
|
||||
# Register shutdown handler
|
||||
atexit.register(self.shutdown)
|
||||
self._register_signal_handlers()
|
||||
|
||||
# Reconcile orphaned containers from previous process lifecycles
|
||||
self._reconcile_orphans()
|
||||
|
||||
# Start idle checker if enabled
|
||||
if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0:
|
||||
self._start_idle_checker()
|
||||
|
||||
# ── Factory methods ──────────────────────────────────────────────────
|
||||
|
||||
def _create_backend(self) -> SandboxBackend:
|
||||
"""Create the appropriate backend based on configuration.
|
||||
|
||||
Selection logic (checked in order):
|
||||
1. ``provisioner_url`` set → RemoteSandboxBackend (provisioner mode)
|
||||
Provisioner dynamically creates Pods + Services in k3s.
|
||||
2. Default → LocalContainerBackend (local mode)
|
||||
Local provider manages container lifecycle directly (start/stop).
|
||||
"""
|
||||
provisioner_url = self._config.get("provisioner_url")
|
||||
if provisioner_url:
|
||||
logger.info(f"Using remote sandbox backend with provisioner at {provisioner_url}")
|
||||
return RemoteSandboxBackend(provisioner_url=provisioner_url)
|
||||
|
||||
logger.info("Using local container sandbox backend")
|
||||
return LocalContainerBackend(
|
||||
image=self._config["image"],
|
||||
base_port=self._config["port"],
|
||||
container_prefix=self._config["container_prefix"],
|
||||
config_mounts=self._config["mounts"],
|
||||
environment=self._config["environment"],
|
||||
)
|
||||
|
||||
# ── Configuration ────────────────────────────────────────────────────
|
||||
|
||||
def _load_config(self) -> dict:
|
||||
"""Load sandbox configuration from app config."""
|
||||
config = get_app_config()
|
||||
sandbox_config = config.sandbox
|
||||
|
||||
idle_timeout = getattr(sandbox_config, "idle_timeout", None)
|
||||
replicas = getattr(sandbox_config, "replicas", None)
|
||||
|
||||
return {
|
||||
"image": sandbox_config.image or DEFAULT_IMAGE,
|
||||
"port": sandbox_config.port or DEFAULT_PORT,
|
||||
"container_prefix": sandbox_config.container_prefix or DEFAULT_CONTAINER_PREFIX,
|
||||
"idle_timeout": idle_timeout if idle_timeout is not None else DEFAULT_IDLE_TIMEOUT,
|
||||
"replicas": replicas if replicas is not None else DEFAULT_REPLICAS,
|
||||
"mounts": sandbox_config.mounts or [],
|
||||
"environment": self._resolve_env_vars(sandbox_config.environment or {}),
|
||||
# provisioner URL for dynamic pod management (e.g. http://provisioner:8002)
|
||||
"provisioner_url": getattr(sandbox_config, "provisioner_url", None) or "",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _resolve_env_vars(env_config: dict[str, str]) -> dict[str, str]:
|
||||
"""Resolve environment variable references (values starting with $)."""
|
||||
resolved = {}
|
||||
for key, value in env_config.items():
|
||||
if isinstance(value, str) and value.startswith("$"):
|
||||
env_name = value[1:]
|
||||
resolved[key] = os.environ.get(env_name, "")
|
||||
else:
|
||||
resolved[key] = str(value)
|
||||
return resolved
|
||||
|
||||
# ── Startup reconciliation ────────────────────────────────────────────
|
||||
|
||||
def _reconcile_orphans(self) -> None:
|
||||
"""Reconcile orphaned containers left by previous process lifecycles.
|
||||
|
||||
On startup, enumerate all running containers matching our prefix
|
||||
and adopt them all into the warm pool. The idle checker will reclaim
|
||||
containers that nobody re-acquires within ``idle_timeout``.
|
||||
|
||||
All containers are adopted unconditionally because we cannot
|
||||
distinguish "orphaned" from "actively used by another process"
|
||||
based on age alone — ``idle_timeout`` represents inactivity, not
|
||||
uptime. Adopting into the warm pool and letting the idle checker
|
||||
decide avoids destroying containers that a concurrent process may
|
||||
still be using.
|
||||
|
||||
This closes the fundamental gap where in-memory state loss (process
|
||||
restart, crash, SIGKILL) leaves Docker containers running forever.
|
||||
"""
|
||||
try:
|
||||
running = self._backend.list_running()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to enumerate running containers during startup reconciliation: {e}")
|
||||
return
|
||||
|
||||
if not running:
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
adopted = 0
|
||||
|
||||
for info in running:
|
||||
age = current_time - info.created_at if info.created_at > 0 else float("inf")
|
||||
# Single lock acquisition per container: atomic check-and-insert.
|
||||
# Avoids a TOCTOU window between the "already tracked?" check and
|
||||
# the warm-pool insert.
|
||||
with self._lock:
|
||||
if info.sandbox_id in self._sandboxes or info.sandbox_id in self._warm_pool:
|
||||
continue
|
||||
self._warm_pool[info.sandbox_id] = (info, current_time)
|
||||
adopted += 1
|
||||
logger.info(f"Adopted container {info.sandbox_id} into warm pool (age: {age:.0f}s)")
|
||||
|
||||
logger.info(f"Startup reconciliation complete: {adopted} adopted into warm pool, {len(running)} total found")
|
||||
|
||||
# ── Deterministic ID ─────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _deterministic_sandbox_id(thread_id: str) -> str:
|
||||
"""Generate a deterministic sandbox ID from a thread ID.
|
||||
|
||||
Ensures all processes derive the same sandbox_id for a given thread,
|
||||
enabling cross-process sandbox discovery without shared memory.
|
||||
"""
|
||||
return hashlib.sha256(thread_id.encode()).hexdigest()[:8]
|
||||
|
||||
# ── Mount helpers ────────────────────────────────────────────────────
|
||||
|
||||
def _get_extra_mounts(self, thread_id: str | None) -> list[tuple[str, str, bool]]:
|
||||
"""Collect all extra mounts for a sandbox (thread-specific + skills)."""
|
||||
mounts: list[tuple[str, str, bool]] = []
|
||||
|
||||
if thread_id:
|
||||
mounts.extend(self._get_thread_mounts(thread_id))
|
||||
logger.info(f"Adding thread mounts for thread {thread_id}: {mounts}")
|
||||
|
||||
skills_mount = self._get_skills_mount()
|
||||
if skills_mount:
|
||||
mounts.append(skills_mount)
|
||||
logger.info(f"Adding skills mount: {skills_mount}")
|
||||
|
||||
return mounts
|
||||
|
||||
@staticmethod
|
||||
def _get_thread_mounts(thread_id: str) -> list[tuple[str, str, bool]]:
|
||||
"""Get volume mounts for a thread's data directories.
|
||||
|
||||
Creates directories if they don't exist (lazy initialization).
|
||||
Mount sources use host_base_dir so that when running inside Docker with a
|
||||
mounted Docker socket (DooD), the host Docker daemon can resolve the paths.
|
||||
"""
|
||||
paths = get_paths()
|
||||
paths.ensure_thread_dirs(thread_id)
|
||||
|
||||
return [
|
||||
(paths.host_sandbox_work_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
|
||||
(paths.host_sandbox_uploads_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
|
||||
(paths.host_sandbox_outputs_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
|
||||
# ACP workspace: read-only inside the sandbox (lead agent reads results;
|
||||
# the ACP subprocess writes from the host side, not from within the container).
|
||||
(paths.host_acp_workspace_dir(thread_id), "/mnt/acp-workspace", True),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _get_skills_mount() -> tuple[str, str, bool] | None:
|
||||
"""Get the skills directory mount configuration.
|
||||
|
||||
Mount source uses DEER_FLOW_HOST_SKILLS_PATH when running inside Docker (DooD)
|
||||
so the host Docker daemon can resolve the path.
|
||||
"""
|
||||
try:
|
||||
config = get_app_config()
|
||||
skills_path = config.skills.get_skills_path()
|
||||
container_path = config.skills.container_path
|
||||
|
||||
if skills_path.exists():
|
||||
# When running inside Docker with DooD, use host-side skills path.
|
||||
host_skills = os.environ.get("DEER_FLOW_HOST_SKILLS_PATH") or str(skills_path)
|
||||
return (host_skills, container_path, True) # Read-only for security
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not setup skills mount: {e}")
|
||||
return None
|
||||
|
||||
# ── Idle timeout management ──────────────────────────────────────────
|
||||
|
||||
def _start_idle_checker(self) -> None:
|
||||
"""Start the background thread that checks for idle sandboxes."""
|
||||
self._idle_checker_thread = threading.Thread(
|
||||
target=self._idle_checker_loop,
|
||||
name="sandbox-idle-checker",
|
||||
daemon=True,
|
||||
)
|
||||
self._idle_checker_thread.start()
|
||||
logger.info(f"Started idle checker thread (timeout: {self._config.get('idle_timeout', DEFAULT_IDLE_TIMEOUT)}s)")
|
||||
|
||||
def _idle_checker_loop(self) -> None:
|
||||
idle_timeout = self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT)
|
||||
while not self._idle_checker_stop.wait(timeout=IDLE_CHECK_INTERVAL):
|
||||
try:
|
||||
self._cleanup_idle_sandboxes(idle_timeout)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in idle checker loop: {e}")
|
||||
|
||||
def _cleanup_idle_sandboxes(self, idle_timeout: float) -> None:
|
||||
current_time = time.time()
|
||||
active_to_destroy = []
|
||||
warm_to_destroy: list[tuple[str, SandboxInfo]] = []
|
||||
|
||||
with self._lock:
|
||||
# Active sandboxes: tracked via _last_activity
|
||||
for sandbox_id, last_activity in self._last_activity.items():
|
||||
idle_duration = current_time - last_activity
|
||||
if idle_duration > idle_timeout:
|
||||
active_to_destroy.append(sandbox_id)
|
||||
logger.info(f"Sandbox {sandbox_id} idle for {idle_duration:.1f}s, marking for destroy")
|
||||
|
||||
# Warm pool: tracked via release_timestamp stored in _warm_pool
|
||||
for sandbox_id, (info, release_ts) in list(self._warm_pool.items()):
|
||||
warm_duration = current_time - release_ts
|
||||
if warm_duration > idle_timeout:
|
||||
warm_to_destroy.append((sandbox_id, info))
|
||||
del self._warm_pool[sandbox_id]
|
||||
logger.info(f"Warm-pool sandbox {sandbox_id} idle for {warm_duration:.1f}s, marking for destroy")
|
||||
|
||||
# Destroy active sandboxes (re-verify still idle before acting)
|
||||
for sandbox_id in active_to_destroy:
|
||||
try:
|
||||
# Re-verify the sandbox is still idle under the lock before destroying.
|
||||
# Between the snapshot above and here, the sandbox may have been
|
||||
# re-acquired (last_activity updated) or already released/destroyed.
|
||||
with self._lock:
|
||||
last_activity = self._last_activity.get(sandbox_id)
|
||||
if last_activity is None:
|
||||
# Already released or destroyed by another path — skip.
|
||||
logger.info(f"Sandbox {sandbox_id} already gone before idle destroy, skipping")
|
||||
continue
|
||||
if (time.time() - last_activity) < idle_timeout:
|
||||
# Re-acquired (activity updated) since the snapshot — skip.
|
||||
logger.info(f"Sandbox {sandbox_id} was re-acquired before idle destroy, skipping")
|
||||
continue
|
||||
logger.info(f"Destroying idle sandbox {sandbox_id}")
|
||||
self.destroy(sandbox_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to destroy idle sandbox {sandbox_id}: {e}")
|
||||
|
||||
# Destroy warm-pool sandboxes (already removed from _warm_pool under lock above)
|
||||
for sandbox_id, info in warm_to_destroy:
|
||||
try:
|
||||
self._backend.destroy(info)
|
||||
logger.info(f"Destroyed idle warm-pool sandbox {sandbox_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to destroy idle warm-pool sandbox {sandbox_id}: {e}")
|
||||
|
||||
# ── Signal handling ──────────────────────────────────────────────────
|
||||
|
||||
def _register_signal_handlers(self) -> None:
|
||||
"""Register signal handlers for graceful shutdown.
|
||||
|
||||
Handles SIGTERM, SIGINT, and SIGHUP (terminal close) to ensure
|
||||
sandbox containers are cleaned up even when the user closes the terminal.
|
||||
"""
|
||||
self._original_sigterm = signal.getsignal(signal.SIGTERM)
|
||||
self._original_sigint = signal.getsignal(signal.SIGINT)
|
||||
self._original_sighup = signal.getsignal(signal.SIGHUP) if hasattr(signal, "SIGHUP") else None
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
self.shutdown()
|
||||
if signum == signal.SIGTERM:
|
||||
original = self._original_sigterm
|
||||
elif hasattr(signal, "SIGHUP") and signum == signal.SIGHUP:
|
||||
original = self._original_sighup
|
||||
else:
|
||||
original = self._original_sigint
|
||||
if callable(original):
|
||||
original(signum, frame)
|
||||
elif original == signal.SIG_DFL:
|
||||
signal.signal(signum, signal.SIG_DFL)
|
||||
signal.raise_signal(signum)
|
||||
|
||||
try:
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
if hasattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, signal_handler)
|
||||
except ValueError:
|
||||
logger.debug("Could not register signal handlers (not main thread)")
|
||||
|
||||
# ── Thread locking (in-process) ──────────────────────────────────────
|
||||
|
||||
def _get_thread_lock(self, thread_id: str) -> threading.Lock:
|
||||
"""Get or create an in-process lock for a specific thread_id."""
|
||||
with self._lock:
|
||||
if thread_id not in self._thread_locks:
|
||||
self._thread_locks[thread_id] = threading.Lock()
|
||||
return self._thread_locks[thread_id]
|
||||
|
||||
# ── Core: acquire / get / release / shutdown ─────────────────────────
|
||||
|
||||
def acquire(self, thread_id: str | None = None) -> str:
|
||||
"""Acquire a sandbox environment and return its ID.
|
||||
|
||||
For the same thread_id, this method will return the same sandbox_id
|
||||
across multiple turns, multiple processes, and (with shared storage)
|
||||
multiple pods.
|
||||
|
||||
Thread-safe with both in-process and cross-process locking.
|
||||
|
||||
Args:
|
||||
thread_id: Optional thread ID for thread-specific configurations.
|
||||
|
||||
Returns:
|
||||
The ID of the acquired sandbox environment.
|
||||
"""
|
||||
if thread_id:
|
||||
thread_lock = self._get_thread_lock(thread_id)
|
||||
with thread_lock:
|
||||
return self._acquire_internal(thread_id)
|
||||
else:
|
||||
return self._acquire_internal(thread_id)
|
||||
|
||||
def _acquire_internal(self, thread_id: str | None) -> str:
|
||||
"""Internal sandbox acquisition with two-layer consistency.
|
||||
|
||||
Layer 1: In-process cache (fastest, covers same-process repeated access)
|
||||
Layer 2: Backend discovery (covers containers started by other processes;
|
||||
sandbox_id is deterministic from thread_id so no shared state file
|
||||
is needed — any process can derive the same container name)
|
||||
"""
|
||||
# ── Layer 1: In-process cache (fast path) ──
|
||||
if thread_id:
|
||||
with self._lock:
|
||||
if thread_id in self._thread_sandboxes:
|
||||
existing_id = self._thread_sandboxes[thread_id]
|
||||
if existing_id in self._sandboxes:
|
||||
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}")
|
||||
self._last_activity[existing_id] = time.time()
|
||||
return existing_id
|
||||
else:
|
||||
del self._thread_sandboxes[thread_id]
|
||||
|
||||
# Deterministic ID for thread-specific, random for anonymous
|
||||
sandbox_id = self._deterministic_sandbox_id(thread_id) if thread_id else str(uuid.uuid4())[:8]
|
||||
|
||||
# ── Layer 1.5: Warm pool (container still running, no cold-start) ──
|
||||
if thread_id:
|
||||
with self._lock:
|
||||
if sandbox_id in self._warm_pool:
|
||||
info, _ = self._warm_pool.pop(sandbox_id)
|
||||
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
|
||||
self._sandboxes[sandbox_id] = sandbox
|
||||
self._sandbox_infos[sandbox_id] = info
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
self._thread_sandboxes[thread_id] = sandbox_id
|
||||
logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
|
||||
return sandbox_id
|
||||
|
||||
# ── Layer 2: Backend discovery + create (protected by cross-process lock) ──
|
||||
# Use a file lock so that two processes racing to create the same sandbox
|
||||
# for the same thread_id serialize here: the second process will discover
|
||||
# the container started by the first instead of hitting a name-conflict.
|
||||
if thread_id:
|
||||
return self._discover_or_create_with_lock(thread_id, sandbox_id)
|
||||
|
||||
return self._create_sandbox(thread_id, sandbox_id)
|
||||
|
||||
def _discover_or_create_with_lock(self, thread_id: str, sandbox_id: str) -> str:
|
||||
"""Discover an existing sandbox or create a new one under a cross-process file lock.
|
||||
|
||||
The file lock serializes concurrent sandbox creation for the same thread_id
|
||||
across multiple processes, preventing container-name conflicts.
|
||||
"""
|
||||
paths = get_paths()
|
||||
paths.ensure_thread_dirs(thread_id)
|
||||
lock_path = paths.thread_dir(thread_id) / f"{sandbox_id}.lock"
|
||||
|
||||
with open(lock_path, "a", encoding="utf-8") as lock_file:
|
||||
locked = False
|
||||
try:
|
||||
_lock_file_exclusive(lock_file)
|
||||
locked = True
|
||||
# Re-check in-process caches under the file lock in case another
|
||||
# thread in this process won the race while we were waiting.
|
||||
with self._lock:
|
||||
if thread_id in self._thread_sandboxes:
|
||||
existing_id = self._thread_sandboxes[thread_id]
|
||||
if existing_id in self._sandboxes:
|
||||
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id} (post-lock check)")
|
||||
self._last_activity[existing_id] = time.time()
|
||||
return existing_id
|
||||
if sandbox_id in self._warm_pool:
|
||||
info, _ = self._warm_pool.pop(sandbox_id)
|
||||
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
|
||||
self._sandboxes[sandbox_id] = sandbox
|
||||
self._sandbox_infos[sandbox_id] = info
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
self._thread_sandboxes[thread_id] = sandbox_id
|
||||
logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id} (post-lock check)")
|
||||
return sandbox_id
|
||||
|
||||
# Backend discovery: another process may have created the container.
|
||||
discovered = self._backend.discover(sandbox_id)
|
||||
if discovered is not None:
|
||||
sandbox = AioSandbox(id=discovered.sandbox_id, base_url=discovered.sandbox_url)
|
||||
with self._lock:
|
||||
self._sandboxes[discovered.sandbox_id] = sandbox
|
||||
self._sandbox_infos[discovered.sandbox_id] = discovered
|
||||
self._last_activity[discovered.sandbox_id] = time.time()
|
||||
self._thread_sandboxes[thread_id] = discovered.sandbox_id
|
||||
logger.info(f"Discovered existing sandbox {discovered.sandbox_id} for thread {thread_id} at {discovered.sandbox_url}")
|
||||
return discovered.sandbox_id
|
||||
|
||||
return self._create_sandbox(thread_id, sandbox_id)
|
||||
finally:
|
||||
if locked:
|
||||
_unlock_file(lock_file)
|
||||
|
||||
def _evict_oldest_warm(self) -> str | None:
|
||||
"""Destroy the oldest container in the warm pool to free capacity.
|
||||
|
||||
Returns:
|
||||
The evicted sandbox_id, or None if warm pool is empty.
|
||||
"""
|
||||
with self._lock:
|
||||
if not self._warm_pool:
|
||||
return None
|
||||
oldest_id = min(self._warm_pool, key=lambda sid: self._warm_pool[sid][1])
|
||||
info, _ = self._warm_pool.pop(oldest_id)
|
||||
|
||||
try:
|
||||
self._backend.destroy(info)
|
||||
logger.info(f"Destroyed warm-pool sandbox {oldest_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to destroy warm-pool sandbox {oldest_id}: {e}")
|
||||
return None
|
||||
return oldest_id
|
||||
|
||||
def _create_sandbox(self, thread_id: str | None, sandbox_id: str) -> str:
|
||||
"""Create a new sandbox via the backend.
|
||||
|
||||
Args:
|
||||
thread_id: Optional thread ID.
|
||||
sandbox_id: The sandbox ID to use.
|
||||
|
||||
Returns:
|
||||
The sandbox_id.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sandbox creation or readiness check fails.
|
||||
"""
|
||||
extra_mounts = self._get_extra_mounts(thread_id)
|
||||
|
||||
# Enforce replicas: only warm-pool containers count toward eviction budget.
|
||||
# Active sandboxes are in use by live threads and must not be forcibly stopped.
|
||||
replicas = self._config.get("replicas", DEFAULT_REPLICAS)
|
||||
with self._lock:
|
||||
total = len(self._sandboxes) + len(self._warm_pool)
|
||||
if total >= replicas:
|
||||
evicted = self._evict_oldest_warm()
|
||||
if evicted:
|
||||
logger.info(f"Evicted warm-pool sandbox {evicted} to stay within replicas={replicas}")
|
||||
else:
|
||||
# All slots are occupied by active sandboxes — proceed anyway and log.
|
||||
# The replicas limit is a soft cap; we never forcibly stop a container
|
||||
# that is actively serving a thread.
|
||||
logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit")
|
||||
|
||||
info = self._backend.create(thread_id, sandbox_id, extra_mounts=extra_mounts or None)
|
||||
|
||||
# Wait for sandbox to be ready
|
||||
if not wait_for_sandbox_ready(info.sandbox_url, timeout=60):
|
||||
self._backend.destroy(info)
|
||||
raise RuntimeError(f"Sandbox {sandbox_id} failed to become ready within timeout at {info.sandbox_url}")
|
||||
|
||||
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
|
||||
with self._lock:
|
||||
self._sandboxes[sandbox_id] = sandbox
|
||||
self._sandbox_infos[sandbox_id] = info
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
if thread_id:
|
||||
self._thread_sandboxes[thread_id] = sandbox_id
|
||||
|
||||
logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
|
||||
return sandbox_id
|
||||
|
||||
def get(self, sandbox_id: str) -> Sandbox | None:
|
||||
"""Get a sandbox by ID. Updates last activity timestamp.
|
||||
|
||||
Args:
|
||||
sandbox_id: The ID of the sandbox.
|
||||
|
||||
Returns:
|
||||
The sandbox instance if found, None otherwise.
|
||||
"""
|
||||
with self._lock:
|
||||
sandbox = self._sandboxes.get(sandbox_id)
|
||||
if sandbox is not None:
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
return sandbox
|
||||
|
||||
def release(self, sandbox_id: str) -> None:
|
||||
"""Release a sandbox from active use into the warm pool.
|
||||
|
||||
The container is kept running so it can be reclaimed quickly by the same
|
||||
thread on its next turn without a cold-start. The container will only be
|
||||
stopped when the replicas limit forces eviction or during shutdown.
|
||||
|
||||
Args:
|
||||
sandbox_id: The ID of the sandbox to release.
|
||||
"""
|
||||
info = None
|
||||
thread_ids_to_remove: list[str] = []
|
||||
|
||||
with self._lock:
|
||||
self._sandboxes.pop(sandbox_id, None)
|
||||
info = self._sandbox_infos.pop(sandbox_id, None)
|
||||
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
||||
for tid in thread_ids_to_remove:
|
||||
del self._thread_sandboxes[tid]
|
||||
self._last_activity.pop(sandbox_id, None)
|
||||
# Park in warm pool — container keeps running
|
||||
if info and sandbox_id not in self._warm_pool:
|
||||
self._warm_pool[sandbox_id] = (info, time.time())
|
||||
|
||||
logger.info(f"Released sandbox {sandbox_id} to warm pool (container still running)")
|
||||
|
||||
def destroy(self, sandbox_id: str) -> None:
|
||||
"""Destroy a sandbox: stop the container and free all resources.
|
||||
|
||||
Unlike release(), this actually stops the container. Use this for
|
||||
explicit cleanup, capacity-driven eviction, or shutdown.
|
||||
|
||||
Args:
|
||||
sandbox_id: The ID of the sandbox to destroy.
|
||||
"""
|
||||
info = None
|
||||
thread_ids_to_remove: list[str] = []
|
||||
|
||||
with self._lock:
|
||||
self._sandboxes.pop(sandbox_id, None)
|
||||
info = self._sandbox_infos.pop(sandbox_id, None)
|
||||
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
||||
for tid in thread_ids_to_remove:
|
||||
del self._thread_sandboxes[tid]
|
||||
self._last_activity.pop(sandbox_id, None)
|
||||
# Also pull from warm pool if it was parked there
|
||||
if info is None and sandbox_id in self._warm_pool:
|
||||
info, _ = self._warm_pool.pop(sandbox_id)
|
||||
else:
|
||||
self._warm_pool.pop(sandbox_id, None)
|
||||
|
||||
if info:
|
||||
self._backend.destroy(info)
|
||||
logger.info(f"Destroyed sandbox {sandbox_id}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown all sandboxes. Thread-safe and idempotent."""
|
||||
with self._lock:
|
||||
if self._shutdown_called:
|
||||
return
|
||||
self._shutdown_called = True
|
||||
sandbox_ids = list(self._sandboxes.keys())
|
||||
warm_items = list(self._warm_pool.items())
|
||||
self._warm_pool.clear()
|
||||
|
||||
# Stop idle checker
|
||||
self._idle_checker_stop.set()
|
||||
if self._idle_checker_thread is not None and self._idle_checker_thread.is_alive():
|
||||
self._idle_checker_thread.join(timeout=5)
|
||||
logger.info("Stopped idle checker thread")
|
||||
|
||||
logger.info(f"Shutting down {len(sandbox_ids)} active + {len(warm_items)} warm-pool sandbox(es)")
|
||||
|
||||
for sandbox_id in sandbox_ids:
|
||||
try:
|
||||
self.destroy(sandbox_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to destroy sandbox {sandbox_id} during shutdown: {e}")
|
||||
|
||||
for sandbox_id, (info, _) in warm_items:
|
||||
try:
|
||||
self._backend.destroy(info)
|
||||
logger.info(f"Destroyed warm-pool sandbox {sandbox_id} during shutdown")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to destroy warm-pool sandbox {sandbox_id} during shutdown: {e}")
|
||||
@@ -0,0 +1,114 @@
|
||||
"""Abstract base class for sandbox provisioning backends."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import requests
|
||||
|
||||
from .sandbox_info import SandboxInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def wait_for_sandbox_ready(sandbox_url: str, timeout: int = 30) -> bool:
|
||||
"""Poll sandbox health endpoint until ready or timeout.
|
||||
|
||||
Args:
|
||||
sandbox_url: URL of the sandbox (e.g. http://k3s:30001).
|
||||
timeout: Maximum time to wait in seconds.
|
||||
|
||||
Returns:
|
||||
True if sandbox is ready, False otherwise.
|
||||
"""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
response = requests.get(f"{sandbox_url}/v1/sandbox", timeout=5)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(1)
|
||||
return False
|
||||
|
||||
|
||||
class SandboxBackend(ABC):
|
||||
"""Abstract base for sandbox provisioning backends.
|
||||
|
||||
Two implementations:
|
||||
- LocalContainerBackend: starts Docker/Apple Container locally, manages ports
|
||||
- RemoteSandboxBackend: connects to a pre-existing URL (K8s service, external)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
||||
"""Create/provision a new sandbox.
|
||||
|
||||
Args:
|
||||
thread_id: Thread ID for which the sandbox is being created. Useful for backends that want to organize sandboxes by thread.
|
||||
sandbox_id: Deterministic sandbox identifier.
|
||||
extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples.
|
||||
Ignored by backends that don't manage containers (e.g., remote).
|
||||
|
||||
Returns:
|
||||
SandboxInfo with connection details.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def destroy(self, info: SandboxInfo) -> None:
|
||||
"""Destroy/cleanup a sandbox and release its resources.
|
||||
|
||||
Args:
|
||||
info: The sandbox metadata to destroy.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def is_alive(self, info: SandboxInfo) -> bool:
|
||||
"""Quick check whether a sandbox is still alive.
|
||||
|
||||
This should be a lightweight check (e.g., container inspect)
|
||||
rather than a full health check.
|
||||
|
||||
Args:
|
||||
info: The sandbox metadata to check.
|
||||
|
||||
Returns:
|
||||
True if the sandbox appears to be alive.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def discover(self, sandbox_id: str) -> SandboxInfo | None:
|
||||
"""Try to discover an existing sandbox by its deterministic ID.
|
||||
|
||||
Used for cross-process recovery: when another process started a sandbox,
|
||||
this process can discover it by the deterministic container name or URL.
|
||||
|
||||
Args:
|
||||
sandbox_id: The deterministic sandbox ID to look for.
|
||||
|
||||
Returns:
|
||||
SandboxInfo if found and healthy, None otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
def list_running(self) -> list[SandboxInfo]:
|
||||
"""Enumerate all running sandboxes managed by this backend.
|
||||
|
||||
Used for startup reconciliation: when the process restarts, it needs
|
||||
to discover containers started by previous processes so they can be
|
||||
adopted into the warm pool or destroyed if idle too long.
|
||||
|
||||
The default implementation returns an empty list, which is correct
|
||||
for backends that don't manage local containers (e.g., RemoteSandboxBackend
|
||||
delegates lifecycle to the provisioner which handles its own cleanup).
|
||||
|
||||
Returns:
|
||||
A list of SandboxInfo for all currently running sandboxes.
|
||||
"""
|
||||
return []
|
||||
@@ -0,0 +1,530 @@
|
||||
"""Local container backend for sandbox provisioning.
|
||||
|
||||
Manages sandbox containers using Docker or Apple Container on the local machine.
|
||||
Handles container lifecycle, port allocation, and cross-process container discovery.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
|
||||
from deerflow.utils.network import get_free_port, release_port
|
||||
|
||||
from .backend import SandboxBackend, wait_for_sandbox_ready
|
||||
from .sandbox_info import SandboxInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_docker_timestamp(raw: str) -> float:
|
||||
"""Parse Docker's ISO 8601 timestamp into a Unix epoch float.
|
||||
|
||||
Docker returns timestamps with nanosecond precision and a trailing ``Z``
|
||||
(e.g. ``2026-04-08T01:22:50.123456789Z``). Python's ``fromisoformat``
|
||||
accepts at most microseconds and (pre-3.11) does not accept ``Z``, so the
|
||||
string is normalized before parsing. Returns ``0.0`` on empty input or
|
||||
parse failure so callers can use ``0.0`` as a sentinel for "unknown age".
|
||||
"""
|
||||
if not raw:
|
||||
return 0.0
|
||||
try:
|
||||
s = raw.strip()
|
||||
if "." in s:
|
||||
dot_pos = s.index(".")
|
||||
tz_start = dot_pos + 1
|
||||
while tz_start < len(s) and s[tz_start].isdigit():
|
||||
tz_start += 1
|
||||
frac = s[dot_pos + 1 : tz_start][:6] # truncate to microseconds
|
||||
tz_suffix = s[tz_start:]
|
||||
s = s[: dot_pos + 1] + frac + tz_suffix
|
||||
if s.endswith("Z"):
|
||||
s = s[:-1] + "+00:00"
|
||||
return datetime.fromisoformat(s).timestamp()
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.debug(f"Could not parse docker timestamp {raw!r}: {e}")
|
||||
return 0.0
|
||||
|
||||
|
||||
def _extract_host_port(inspect_entry: dict, container_port: int) -> int | None:
|
||||
"""Extract the host port mapped to ``container_port/tcp`` from a docker inspect entry.
|
||||
|
||||
Returns None if the container has no port mapping for that port.
|
||||
"""
|
||||
try:
|
||||
ports = (inspect_entry.get("NetworkSettings") or {}).get("Ports") or {}
|
||||
bindings = ports.get(f"{container_port}/tcp") or []
|
||||
if bindings:
|
||||
host_port = bindings[0].get("HostPort")
|
||||
if host_port:
|
||||
return int(host_port)
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _format_container_mount(runtime: str, host_path: str, container_path: str, read_only: bool) -> list[str]:
|
||||
"""Format a bind-mount argument for the selected runtime.
|
||||
|
||||
Docker's ``-v host:container`` syntax is ambiguous for Windows drive-letter
|
||||
paths like ``D:/...`` because ``:`` is both the drive separator and the
|
||||
volume separator. Use ``--mount type=bind,...`` for Docker to avoid that
|
||||
parsing ambiguity. Apple Container keeps using ``-v``.
|
||||
"""
|
||||
if runtime == "docker":
|
||||
mount_spec = f"type=bind,src={host_path},dst={container_path}"
|
||||
if read_only:
|
||||
mount_spec += ",readonly"
|
||||
return ["--mount", mount_spec]
|
||||
|
||||
mount_spec = f"{host_path}:{container_path}"
|
||||
if read_only:
|
||||
mount_spec += ":ro"
|
||||
return ["-v", mount_spec]
|
||||
|
||||
|
||||
class LocalContainerBackend(SandboxBackend):
|
||||
"""Backend that manages sandbox containers locally using Docker or Apple Container.
|
||||
|
||||
On macOS, automatically prefers Apple Container if available, otherwise falls back to Docker.
|
||||
On other platforms, uses Docker.
|
||||
|
||||
Features:
|
||||
- Deterministic container naming for cross-process discovery
|
||||
- Port allocation with thread-safe utilities
|
||||
- Container lifecycle management (start/stop with --rm)
|
||||
- Support for volume mounts and environment variables
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image: str,
|
||||
base_port: int,
|
||||
container_prefix: str,
|
||||
config_mounts: list,
|
||||
environment: dict[str, str],
|
||||
):
|
||||
"""Initialize the local container backend.
|
||||
|
||||
Args:
|
||||
image: Container image to use.
|
||||
base_port: Base port number to start searching for free ports.
|
||||
container_prefix: Prefix for container names (e.g., "deer-flow-sandbox").
|
||||
config_mounts: Volume mount configurations from config (list of VolumeMountConfig).
|
||||
environment: Environment variables to inject into containers.
|
||||
"""
|
||||
self._image = image
|
||||
self._base_port = base_port
|
||||
self._container_prefix = container_prefix
|
||||
self._config_mounts = config_mounts
|
||||
self._environment = environment
|
||||
self._runtime = self._detect_runtime()
|
||||
|
||||
@property
|
||||
def runtime(self) -> str:
|
||||
"""The detected container runtime ("docker" or "container")."""
|
||||
return self._runtime
|
||||
|
||||
def _detect_runtime(self) -> str:
|
||||
"""Detect which container runtime to use.
|
||||
|
||||
On macOS, prefer Apple Container if available, otherwise fall back to Docker.
|
||||
On other platforms, use Docker.
|
||||
|
||||
Returns:
|
||||
"container" for Apple Container, "docker" for Docker.
|
||||
"""
|
||||
import platform
|
||||
|
||||
if platform.system() == "Darwin":
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["container", "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
timeout=5,
|
||||
)
|
||||
logger.info(f"Detected Apple Container: {result.stdout.strip()}")
|
||||
return "container"
|
||||
except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired):
|
||||
logger.info("Apple Container not available, falling back to Docker")
|
||||
|
||||
return "docker"
|
||||
|
||||
# ── SandboxBackend interface ──────────────────────────────────────────
|
||||
|
||||
def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
||||
"""Start a new container and return its connection info.
|
||||
|
||||
Args:
|
||||
thread_id: Thread ID for which the sandbox is being created. Useful for backends that want to organize sandboxes by thread.
|
||||
sandbox_id: Deterministic sandbox identifier (used in container name).
|
||||
extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples.
|
||||
|
||||
Returns:
|
||||
SandboxInfo with container details.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the container fails to start.
|
||||
"""
|
||||
container_name = f"{self._container_prefix}-{sandbox_id}"
|
||||
|
||||
# Retry loop: if Docker rejects the port (e.g. a stale container still
|
||||
# holds the binding after a process restart), skip that port and try the
|
||||
# next one. The socket-bind check in get_free_port mirrors Docker's
|
||||
# 0.0.0.0 bind, but Docker's port-release can be slightly asynchronous,
|
||||
# so a reactive fallback here ensures we always make progress.
|
||||
_next_start = self._base_port
|
||||
container_id: str | None = None
|
||||
port: int = 0
|
||||
for _attempt in range(10):
|
||||
port = get_free_port(start_port=_next_start)
|
||||
try:
|
||||
container_id = self._start_container(container_name, port, extra_mounts)
|
||||
break
|
||||
except RuntimeError as exc:
|
||||
release_port(port)
|
||||
err = str(exc)
|
||||
err_lower = err.lower()
|
||||
# Port already bound: skip this port and retry with the next one.
|
||||
if "port is already allocated" in err or "address already in use" in err_lower:
|
||||
logger.warning(f"Port {port} rejected by Docker (already allocated), retrying with next port")
|
||||
_next_start = port + 1
|
||||
continue
|
||||
# Container-name conflict: another process may have already started
|
||||
# the deterministic sandbox container for this sandbox_id. Try to
|
||||
# discover and adopt the existing container instead of failing.
|
||||
if "is already in use by container" in err_lower or "conflict. the container name" in err_lower:
|
||||
logger.warning(f"Container name {container_name} already in use, attempting to discover existing sandbox instance")
|
||||
existing = self.discover(sandbox_id)
|
||||
if existing is not None:
|
||||
return existing
|
||||
raise
|
||||
else:
|
||||
raise RuntimeError("Could not start sandbox container: all candidate ports are already allocated by Docker")
|
||||
|
||||
# When running inside Docker (DooD), sandbox containers are reachable via
|
||||
# host.docker.internal rather than localhost (they run on the host daemon).
|
||||
sandbox_host = os.environ.get("DEER_FLOW_SANDBOX_HOST", "localhost")
|
||||
return SandboxInfo(
|
||||
sandbox_id=sandbox_id,
|
||||
sandbox_url=f"http://{sandbox_host}:{port}",
|
||||
container_name=container_name,
|
||||
container_id=container_id,
|
||||
)
|
||||
|
||||
def destroy(self, info: SandboxInfo) -> None:
|
||||
"""Stop the container and release its port."""
|
||||
# Prefer container_id, fall back to container_name (both accepted by docker stop).
|
||||
# This ensures containers discovered via list_running() (which only has the name)
|
||||
# can also be stopped.
|
||||
stop_target = info.container_id or info.container_name
|
||||
if stop_target:
|
||||
self._stop_container(stop_target)
|
||||
# Extract port from sandbox_url for release
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
port = urlparse(info.sandbox_url).port
|
||||
if port:
|
||||
release_port(port)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def is_alive(self, info: SandboxInfo) -> bool:
|
||||
"""Check if the container is still running (lightweight, no HTTP)."""
|
||||
if info.container_name:
|
||||
return self._is_container_running(info.container_name)
|
||||
return False
|
||||
|
||||
def discover(self, sandbox_id: str) -> SandboxInfo | None:
|
||||
"""Discover an existing container by its deterministic name.
|
||||
|
||||
Checks if a container with the expected name is running, retrieves its
|
||||
port, and verifies it responds to health checks.
|
||||
|
||||
Args:
|
||||
sandbox_id: The deterministic sandbox ID (determines container name).
|
||||
|
||||
Returns:
|
||||
SandboxInfo if container found and healthy, None otherwise.
|
||||
"""
|
||||
container_name = f"{self._container_prefix}-{sandbox_id}"
|
||||
|
||||
if not self._is_container_running(container_name):
|
||||
return None
|
||||
|
||||
port = self._get_container_port(container_name)
|
||||
if port is None:
|
||||
return None
|
||||
|
||||
sandbox_host = os.environ.get("DEER_FLOW_SANDBOX_HOST", "localhost")
|
||||
sandbox_url = f"http://{sandbox_host}:{port}"
|
||||
if not wait_for_sandbox_ready(sandbox_url, timeout=5):
|
||||
return None
|
||||
|
||||
return SandboxInfo(
|
||||
sandbox_id=sandbox_id,
|
||||
sandbox_url=sandbox_url,
|
||||
container_name=container_name,
|
||||
)
|
||||
|
||||
def list_running(self) -> list[SandboxInfo]:
|
||||
"""Enumerate all running containers matching the configured prefix.
|
||||
|
||||
Uses a single ``docker ps`` call to list container names, then a
|
||||
single batched ``docker inspect`` call to retrieve creation timestamp
|
||||
and port mapping for all containers at once. Total subprocess calls:
|
||||
2 (down from 2N+1 in the naive per-container approach).
|
||||
|
||||
Note: Docker's ``--filter name=`` performs *substring* matching,
|
||||
so a secondary ``startswith`` check is applied to ensure only
|
||||
containers with the exact prefix are included.
|
||||
|
||||
Containers without port mappings are still included (with empty
|
||||
sandbox_url) so that startup reconciliation can adopt orphans
|
||||
regardless of their port state.
|
||||
"""
|
||||
# Step 1: enumerate container names via docker ps
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
self._runtime,
|
||||
"ps",
|
||||
"--filter",
|
||||
f"name={self._container_prefix}-",
|
||||
"--format",
|
||||
"{{.Names}}",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
stderr = (result.stderr or "").strip()
|
||||
logger.warning(
|
||||
"Failed to list running containers with %s ps (returncode=%s, stderr=%s)",
|
||||
self._runtime,
|
||||
result.returncode,
|
||||
stderr or "<empty>",
|
||||
)
|
||||
return []
|
||||
if not result.stdout.strip():
|
||||
return []
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError, OSError) as e:
|
||||
logger.warning(f"Failed to list running containers: {e}")
|
||||
return []
|
||||
|
||||
# Filter to names matching our exact prefix (docker filter is substring-based)
|
||||
container_names = [name.strip() for name in result.stdout.strip().splitlines() if name.strip().startswith(self._container_prefix + "-")]
|
||||
if not container_names:
|
||||
return []
|
||||
|
||||
# Step 2: batched docker inspect — single subprocess call for all containers
|
||||
inspections = self._batch_inspect(container_names)
|
||||
|
||||
infos: list[SandboxInfo] = []
|
||||
sandbox_host = os.environ.get("DEER_FLOW_SANDBOX_HOST", "localhost")
|
||||
for container_name in container_names:
|
||||
data = inspections.get(container_name)
|
||||
if data is None:
|
||||
# Container disappeared between ps and inspect, or inspect failed
|
||||
continue
|
||||
created_at, host_port = data
|
||||
sandbox_id = container_name[len(self._container_prefix) + 1 :]
|
||||
sandbox_url = f"http://{sandbox_host}:{host_port}" if host_port else ""
|
||||
|
||||
infos.append(
|
||||
SandboxInfo(
|
||||
sandbox_id=sandbox_id,
|
||||
sandbox_url=sandbox_url,
|
||||
container_name=container_name,
|
||||
created_at=created_at,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(infos)} running sandbox container(s)")
|
||||
return infos
|
||||
|
||||
def _batch_inspect(self, container_names: list[str]) -> dict[str, tuple[float, int | None]]:
|
||||
"""Batch-inspect containers in a single subprocess call.
|
||||
|
||||
Returns a mapping of ``container_name -> (created_at, host_port)``.
|
||||
Missing containers or parse failures are silently dropped from the result.
|
||||
"""
|
||||
if not container_names:
|
||||
return {}
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[self._runtime, "inspect", *container_names],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=15,
|
||||
)
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError, OSError) as e:
|
||||
logger.warning(f"Failed to batch-inspect containers: {e}")
|
||||
return {}
|
||||
|
||||
if result.returncode != 0:
|
||||
stderr = (result.stderr or "").strip()
|
||||
logger.warning(
|
||||
"Failed to batch-inspect containers with %s inspect (returncode=%s, stderr=%s)",
|
||||
self._runtime,
|
||||
result.returncode,
|
||||
stderr or "<empty>",
|
||||
)
|
||||
return {}
|
||||
|
||||
try:
|
||||
payload = json.loads(result.stdout or "[]")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse docker inspect output as JSON: {e}")
|
||||
return {}
|
||||
|
||||
out: dict[str, tuple[float, int | None]] = {}
|
||||
for entry in payload:
|
||||
# ``Name`` is prefixed with ``/`` in the docker inspect response
|
||||
name = (entry.get("Name") or "").lstrip("/")
|
||||
if not name:
|
||||
continue
|
||||
created_at = _parse_docker_timestamp(entry.get("Created", ""))
|
||||
host_port = _extract_host_port(entry, 8080)
|
||||
out[name] = (created_at, host_port)
|
||||
return out
|
||||
|
||||
# ── Container operations ─────────────────────────────────────────────
|
||||
|
||||
def _start_container(
|
||||
self,
|
||||
container_name: str,
|
||||
port: int,
|
||||
extra_mounts: list[tuple[str, str, bool]] | None = None,
|
||||
) -> str:
|
||||
"""Start a new container.
|
||||
|
||||
Args:
|
||||
container_name: Name for the container.
|
||||
port: Host port to map to container port 8080.
|
||||
extra_mounts: Additional volume mounts.
|
||||
|
||||
Returns:
|
||||
The container ID.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If container fails to start.
|
||||
"""
|
||||
cmd = [self._runtime, "run"]
|
||||
|
||||
# Docker-specific security options
|
||||
if self._runtime == "docker":
|
||||
cmd.extend(["--security-opt", "seccomp=unconfined"])
|
||||
|
||||
cmd.extend(
|
||||
[
|
||||
"--rm",
|
||||
"-d",
|
||||
"-p",
|
||||
f"{port}:8080",
|
||||
"--name",
|
||||
container_name,
|
||||
]
|
||||
)
|
||||
|
||||
# Environment variables
|
||||
for key, value in self._environment.items():
|
||||
cmd.extend(["-e", f"{key}={value}"])
|
||||
|
||||
# Config-level volume mounts
|
||||
for mount in self._config_mounts:
|
||||
cmd.extend(
|
||||
_format_container_mount(
|
||||
self._runtime,
|
||||
mount.host_path,
|
||||
mount.container_path,
|
||||
mount.read_only,
|
||||
)
|
||||
)
|
||||
|
||||
# Extra mounts (thread-specific, skills, etc.)
|
||||
if extra_mounts:
|
||||
for host_path, container_path, read_only in extra_mounts:
|
||||
cmd.extend(
|
||||
_format_container_mount(
|
||||
self._runtime,
|
||||
host_path,
|
||||
container_path,
|
||||
read_only,
|
||||
)
|
||||
)
|
||||
|
||||
cmd.append(self._image)
|
||||
|
||||
logger.info(f"Starting container using {self._runtime}: {' '.join(cmd)}")
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
||||
container_id = result.stdout.strip()
|
||||
logger.info(f"Started container {container_name} (ID: {container_id}) using {self._runtime}")
|
||||
return container_id
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Failed to start container using {self._runtime}: {e.stderr}")
|
||||
raise RuntimeError(f"Failed to start sandbox container: {e.stderr}")
|
||||
|
||||
def _stop_container(self, container_id: str) -> None:
|
||||
"""Stop a container (--rm ensures automatic removal)."""
|
||||
try:
|
||||
subprocess.run(
|
||||
[self._runtime, "stop", container_id],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
logger.info(f"Stopped container {container_id} using {self._runtime}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.warning(f"Failed to stop container {container_id}: {e.stderr}")
|
||||
|
||||
def _is_container_running(self, container_name: str) -> bool:
|
||||
"""Check if a named container is currently running.
|
||||
|
||||
This enables cross-process container discovery — any process can detect
|
||||
containers started by another process via the deterministic container name.
|
||||
"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[self._runtime, "inspect", "-f", "{{.State.Running}}", container_name],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
return result.returncode == 0 and result.stdout.strip().lower() == "true"
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
|
||||
return False
|
||||
|
||||
def _get_container_port(self, container_name: str) -> int | None:
|
||||
"""Get the host port of a running container.
|
||||
|
||||
Args:
|
||||
container_name: The container name to inspect.
|
||||
|
||||
Returns:
|
||||
The host port mapped to container port 8080, or None if not found.
|
||||
"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[self._runtime, "port", container_name, "8080"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
# Output format: "0.0.0.0:PORT" or ":::PORT"
|
||||
port_str = result.stdout.strip().split(":")[-1]
|
||||
return int(port_str)
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, ValueError):
|
||||
pass
|
||||
return None
|
||||
@@ -0,0 +1,156 @@
|
||||
"""Remote sandbox backend — delegates Pod lifecycle to the provisioner service.
|
||||
|
||||
The provisioner dynamically creates per-sandbox-id Pods + NodePort Services
|
||||
in k3s. The backend accesses sandbox pods directly via ``k3s:{NodePort}``.
|
||||
|
||||
Architecture:
|
||||
┌────────────┐ HTTP ┌─────────────┐ K8s API ┌──────────┐
|
||||
│ this file │ ──────▸ │ provisioner │ ────────▸ │ k3s │
|
||||
│ (backend) │ │ :8002 │ │ :6443 │
|
||||
└────────────┘ └─────────────┘ └─────┬────┘
|
||||
│ creates
|
||||
┌─────────────┐ ┌─────▼──────┐
|
||||
│ backend │ ────────▸ │ sandbox │
|
||||
│ │ direct │ Pod(s) │
|
||||
└─────────────┘ k3s:NPort └────────────┘
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import requests
|
||||
|
||||
from .backend import SandboxBackend
|
||||
from .sandbox_info import SandboxInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RemoteSandboxBackend(SandboxBackend):
|
||||
"""Backend that delegates sandbox lifecycle to the provisioner service.
|
||||
|
||||
All Pod creation, destruction, and discovery are handled by the
|
||||
provisioner. This backend is a thin HTTP client.
|
||||
|
||||
Typical config.yaml::
|
||||
|
||||
sandbox:
|
||||
use: deerflow.community.aio_sandbox:AioSandboxProvider
|
||||
provisioner_url: http://provisioner:8002
|
||||
"""
|
||||
|
||||
def __init__(self, provisioner_url: str):
|
||||
"""Initialize with the provisioner service URL.
|
||||
|
||||
Args:
|
||||
provisioner_url: URL of the provisioner service
|
||||
(e.g., ``http://provisioner:8002``).
|
||||
"""
|
||||
self._provisioner_url = provisioner_url.rstrip("/")
|
||||
|
||||
@property
|
||||
def provisioner_url(self) -> str:
|
||||
return self._provisioner_url
|
||||
|
||||
# ── SandboxBackend interface ──────────────────────────────────────────
|
||||
|
||||
def create(
|
||||
self,
|
||||
thread_id: str,
|
||||
sandbox_id: str,
|
||||
extra_mounts: list[tuple[str, str, bool]] | None = None,
|
||||
) -> SandboxInfo:
|
||||
"""Create a sandbox Pod + Service via the provisioner.
|
||||
|
||||
Calls ``POST /api/sandboxes`` which creates a dedicated Pod +
|
||||
NodePort Service in k3s.
|
||||
"""
|
||||
return self._provisioner_create(thread_id, sandbox_id, extra_mounts)
|
||||
|
||||
def destroy(self, info: SandboxInfo) -> None:
|
||||
"""Destroy a sandbox Pod + Service via the provisioner."""
|
||||
self._provisioner_destroy(info.sandbox_id)
|
||||
|
||||
def is_alive(self, info: SandboxInfo) -> bool:
|
||||
"""Check whether the sandbox Pod is running."""
|
||||
return self._provisioner_is_alive(info.sandbox_id)
|
||||
|
||||
def discover(self, sandbox_id: str) -> SandboxInfo | None:
|
||||
"""Discover an existing sandbox via the provisioner.
|
||||
|
||||
Calls ``GET /api/sandboxes/{sandbox_id}`` and returns info if
|
||||
the Pod exists.
|
||||
"""
|
||||
return self._provisioner_discover(sandbox_id)
|
||||
|
||||
# ── Provisioner API calls ─────────────────────────────────────────────
|
||||
|
||||
def _provisioner_create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
||||
"""POST /api/sandboxes → create Pod + Service."""
|
||||
try:
|
||||
resp = requests.post(
|
||||
f"{self._provisioner_url}/api/sandboxes",
|
||||
json={
|
||||
"sandbox_id": sandbox_id,
|
||||
"thread_id": thread_id,
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
logger.info(f"Provisioner created sandbox {sandbox_id}: sandbox_url={data['sandbox_url']}")
|
||||
return SandboxInfo(
|
||||
sandbox_id=sandbox_id,
|
||||
sandbox_url=data["sandbox_url"],
|
||||
)
|
||||
except requests.RequestException as exc:
|
||||
logger.error(f"Provisioner create failed for {sandbox_id}: {exc}")
|
||||
raise RuntimeError(f"Provisioner create failed: {exc}") from exc
|
||||
|
||||
def _provisioner_destroy(self, sandbox_id: str) -> None:
|
||||
"""DELETE /api/sandboxes/{sandbox_id} → destroy Pod + Service."""
|
||||
try:
|
||||
resp = requests.delete(
|
||||
f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
|
||||
timeout=15,
|
||||
)
|
||||
if resp.ok:
|
||||
logger.info(f"Provisioner destroyed sandbox {sandbox_id}")
|
||||
else:
|
||||
logger.warning(f"Provisioner destroy returned {resp.status_code}: {resp.text}")
|
||||
except requests.RequestException as exc:
|
||||
logger.warning(f"Provisioner destroy failed for {sandbox_id}: {exc}")
|
||||
|
||||
def _provisioner_is_alive(self, sandbox_id: str) -> bool:
|
||||
"""GET /api/sandboxes/{sandbox_id} → check Pod phase."""
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
|
||||
timeout=10,
|
||||
)
|
||||
if resp.ok:
|
||||
data = resp.json()
|
||||
return data.get("status") == "Running"
|
||||
return False
|
||||
except requests.RequestException:
|
||||
return False
|
||||
|
||||
def _provisioner_discover(self, sandbox_id: str) -> SandboxInfo | None:
|
||||
"""GET /api/sandboxes/{sandbox_id} → discover existing sandbox."""
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
|
||||
timeout=10,
|
||||
)
|
||||
if resp.status_code == 404:
|
||||
return None
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return SandboxInfo(
|
||||
sandbox_id=sandbox_id,
|
||||
sandbox_url=data["sandbox_url"],
|
||||
)
|
||||
except requests.RequestException as exc:
|
||||
logger.debug(f"Provisioner discover failed for {sandbox_id}: {exc}")
|
||||
return None
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Sandbox metadata for cross-process discovery and state persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class SandboxInfo:
|
||||
"""Persisted sandbox metadata that enables cross-process discovery.
|
||||
|
||||
This dataclass holds all the information needed to reconnect to an
|
||||
existing sandbox from a different process (e.g., gateway vs langgraph,
|
||||
multiple workers, or across K8s pods with shared storage).
|
||||
"""
|
||||
|
||||
sandbox_id: str
|
||||
sandbox_url: str # e.g. http://localhost:8080 or http://k3s:30001
|
||||
container_name: str | None = None # Only for local container backend
|
||||
container_id: str | None = None # Only for local container backend
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"sandbox_id": self.sandbox_id,
|
||||
"sandbox_url": self.sandbox_url,
|
||||
"container_name": self.container_name,
|
||||
"container_id": self.container_id,
|
||||
"created_at": self.created_at,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> SandboxInfo:
|
||||
return cls(
|
||||
sandbox_id=data["sandbox_id"],
|
||||
sandbox_url=data.get("sandbox_url", data.get("base_url", "")),
|
||||
container_name=data.get("container_name"),
|
||||
container_id=data.get("container_id"),
|
||||
created_at=data.get("created_at", time.time()),
|
||||
)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .tools import web_search_tool
|
||||
|
||||
__all__ = ["web_search_tool"]
|
||||
@@ -0,0 +1,5 @@
|
||||
"""DISABLED: native web tool. See deerflow.community._disabled_native."""
|
||||
|
||||
from deerflow.community._disabled_native import reject_native_provider
|
||||
|
||||
reject_native_provider("ddg_search")
|
||||
@@ -0,0 +1,5 @@
|
||||
"""DISABLED: native web tool. See deerflow.community._disabled_native."""
|
||||
|
||||
from deerflow.community._disabled_native import reject_native_provider
|
||||
|
||||
reject_native_provider("exa")
|
||||
@@ -0,0 +1,5 @@
|
||||
"""DISABLED: native web tool. See deerflow.community._disabled_native."""
|
||||
|
||||
from deerflow.community._disabled_native import reject_native_provider
|
||||
|
||||
reject_native_provider("firecrawl")
|
||||
@@ -0,0 +1,3 @@
|
||||
from .tools import image_search_tool
|
||||
|
||||
__all__ = ["image_search_tool"]
|
||||
@@ -0,0 +1,5 @@
|
||||
"""DISABLED: native web tool. See deerflow.community._disabled_native."""
|
||||
|
||||
from deerflow.community._disabled_native import reject_native_provider
|
||||
|
||||
reject_native_provider("image_search")
|
||||
@@ -0,0 +1,5 @@
|
||||
"""DISABLED: native InfoQuest client. See deerflow.community._disabled_native."""
|
||||
|
||||
from deerflow.community._disabled_native import reject_native_provider
|
||||
|
||||
reject_native_provider("infoquest")
|
||||
@@ -0,0 +1,5 @@
|
||||
"""DISABLED: native web tool. See deerflow.community._disabled_native."""
|
||||
|
||||
from deerflow.community._disabled_native import reject_native_provider
|
||||
|
||||
reject_native_provider("infoquest")
|
||||
@@ -0,0 +1,5 @@
|
||||
"""DISABLED: native Jina client. See deerflow.community._disabled_native."""
|
||||
|
||||
from deerflow.community._disabled_native import reject_native_provider
|
||||
|
||||
reject_native_provider("jina_ai")
|
||||
@@ -0,0 +1,5 @@
|
||||
"""DISABLED: native web tool. See deerflow.community._disabled_native."""
|
||||
|
||||
from deerflow.community._disabled_native import reject_native_provider
|
||||
|
||||
reject_native_provider("jina_ai")
|
||||
@@ -0,0 +1,12 @@
|
||||
"""Hardened SearX-backed web tools (search, fetch, image search).
|
||||
|
||||
All results are passed through the deerflow.security pipeline:
|
||||
- HTML cleaning (strip script/style/iframe/etc.)
|
||||
- Unicode sanitizer (zero-width chars, control chars, PUA, tag chars)
|
||||
- Content delimiter wrapping (semantic boundary for the LLM)
|
||||
|
||||
These tools are the ONLY web access surface in this hardened build.
|
||||
The legacy community web providers (ddg_search, tavily, exa, firecrawl,
|
||||
jina_ai, infoquest, image_search) are deliberately disabled — see
|
||||
deerflow/community/_disabled_native.py.
|
||||
"""
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Hardened SearX web search, web fetch, and image search tools.
|
||||
|
||||
Every external response is sanitized and wrapped in security delimiters
|
||||
before being returned to the LLM. See deerflow.security for the pipeline.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
from langchain.tools import tool
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.security.content_delimiter import wrap_untrusted_content
|
||||
from deerflow.security.html_cleaner import extract_secure_text
|
||||
from deerflow.security.sanitizer import sanitizer
|
||||
|
||||
DEFAULT_SEARX_URL = "http://localhost:8888"
|
||||
DEFAULT_TIMEOUT = 30.0
|
||||
DEFAULT_USER_AGENT = "DeerFlow-Hardened/1.0 (+searx)"
|
||||
|
||||
|
||||
def _tool_extra(name: str) -> dict:
|
||||
"""Read the model_extra dict for a tool config entry, defensively."""
|
||||
cfg = get_app_config().get_tool_config(name)
|
||||
if cfg is None:
|
||||
return {}
|
||||
return getattr(cfg, "model_extra", {}) or {}
|
||||
|
||||
|
||||
def _searx_url(tool_name: str = "web_search") -> str:
|
||||
return _tool_extra(tool_name).get("searx_url", DEFAULT_SEARX_URL)
|
||||
|
||||
|
||||
def _http_get(url: str, params: dict, timeout: float = DEFAULT_TIMEOUT) -> dict:
|
||||
"""GET a SearX endpoint and return parsed JSON. Raises on transport/HTTP error."""
|
||||
with httpx.Client(headers={"User-Agent": DEFAULT_USER_AGENT}) as client:
|
||||
response = client.get(url, params=params, timeout=timeout)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
@tool("web_search", parse_docstring=True)
|
||||
def web_search_tool(query: str, max_results: int = 10) -> str:
|
||||
"""Search the web via the private hardened SearX instance.
|
||||
|
||||
All results are sanitized against prompt-injection vectors and
|
||||
wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>> markers.
|
||||
|
||||
Args:
|
||||
query: Search keywords.
|
||||
max_results: Maximum results to return (capped by config).
|
||||
"""
|
||||
extra = _tool_extra("web_search")
|
||||
cap = int(extra.get("max_results", 10))
|
||||
searx_url = extra.get("searx_url", DEFAULT_SEARX_URL)
|
||||
limit = max(1, min(int(max_results), cap))
|
||||
|
||||
try:
|
||||
data = _http_get(
|
||||
f"{searx_url}/search",
|
||||
{"q": quote(query), "format": "json"},
|
||||
)
|
||||
except Exception as exc:
|
||||
return wrap_untrusted_content({"error": f"Search failed: {exc}"})
|
||||
|
||||
results = []
|
||||
for item in data.get("results", [])[:limit]:
|
||||
results.append(
|
||||
{
|
||||
"title": sanitizer.sanitize(item.get("title", ""), max_length=200),
|
||||
"url": item.get("url", ""),
|
||||
"content": sanitizer.sanitize(item.get("content", ""), max_length=500),
|
||||
}
|
||||
)
|
||||
|
||||
return wrap_untrusted_content(
|
||||
{
|
||||
"query": query,
|
||||
"total_results": len(results),
|
||||
"results": results,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@tool("web_fetch", parse_docstring=True)
|
||||
async def web_fetch_tool(url: str, max_chars: int = 10000) -> str:
|
||||
"""Fetch a web page and return sanitized visible text.
|
||||
|
||||
Dangerous HTML elements (script, style, iframe, form, ...) are stripped,
|
||||
invisible Unicode is removed, and the result is wrapped in security markers.
|
||||
Only call this for URLs returned by web_search or supplied directly by the
|
||||
user — do not invent URLs.
|
||||
|
||||
Args:
|
||||
url: Absolute URL to fetch (must include scheme).
|
||||
max_chars: Maximum number of characters to return.
|
||||
"""
|
||||
extra = _tool_extra("web_fetch")
|
||||
cap = int(extra.get("max_chars", max_chars))
|
||||
limit = max(256, min(int(max_chars), cap))
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
headers={"User-Agent": DEFAULT_USER_AGENT},
|
||||
follow_redirects=True,
|
||||
) as client:
|
||||
response = await client.get(url, timeout=DEFAULT_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
html = response.text
|
||||
except Exception as exc:
|
||||
return wrap_untrusted_content({"error": f"Fetch failed: {exc}", "url": url})
|
||||
|
||||
raw_text = extract_secure_text(html)
|
||||
clean_text = sanitizer.sanitize(raw_text, max_length=limit)
|
||||
return wrap_untrusted_content({"url": url, "content": clean_text})
|
||||
|
||||
|
||||
@tool("image_search", parse_docstring=True)
|
||||
def image_search_tool(query: str, max_results: int = 5) -> str:
|
||||
"""Search for images via the private hardened SearX instance.
|
||||
|
||||
Returns sanitized title/url pairs (no inline image data). Wrapped in
|
||||
security delimiters.
|
||||
|
||||
Args:
|
||||
query: Image search keywords.
|
||||
max_results: Maximum number of images to return.
|
||||
"""
|
||||
extra = _tool_extra("image_search")
|
||||
cap = int(extra.get("max_results", 5))
|
||||
searx_url = extra.get("searx_url", _searx_url("web_search"))
|
||||
limit = max(1, min(int(max_results), cap))
|
||||
|
||||
try:
|
||||
data = _http_get(
|
||||
f"{searx_url}/search",
|
||||
{"q": quote(query), "format": "json", "categories": "images"},
|
||||
)
|
||||
except Exception as exc:
|
||||
return wrap_untrusted_content({"error": f"Image search failed: {exc}"})
|
||||
|
||||
results = []
|
||||
for item in data.get("results", [])[:limit]:
|
||||
results.append(
|
||||
{
|
||||
"title": sanitizer.sanitize(item.get("title", ""), max_length=200),
|
||||
"url": item.get("url", ""),
|
||||
"thumbnail": item.get("thumbnail_src") or item.get("img_src", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return wrap_untrusted_content(
|
||||
{
|
||||
"query": query,
|
||||
"total_results": len(results),
|
||||
"results": results,
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,5 @@
|
||||
"""DISABLED: native web tool. See deerflow.community._disabled_native."""
|
||||
|
||||
from deerflow.community._disabled_native import reject_native_provider
|
||||
|
||||
reject_native_provider("tavily")
|
||||
@@ -0,0 +1,30 @@
|
||||
from .app_config import get_app_config
|
||||
from .extensions_config import ExtensionsConfig, get_extensions_config
|
||||
from .memory_config import MemoryConfig, get_memory_config
|
||||
from .paths import Paths, get_paths
|
||||
from .skill_evolution_config import SkillEvolutionConfig
|
||||
from .skills_config import SkillsConfig
|
||||
from .tracing_config import (
|
||||
get_enabled_tracing_providers,
|
||||
get_explicitly_enabled_tracing_providers,
|
||||
get_tracing_config,
|
||||
is_tracing_enabled,
|
||||
validate_enabled_tracing_providers,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_app_config",
|
||||
"SkillEvolutionConfig",
|
||||
"Paths",
|
||||
"get_paths",
|
||||
"SkillsConfig",
|
||||
"ExtensionsConfig",
|
||||
"get_extensions_config",
|
||||
"MemoryConfig",
|
||||
"get_memory_config",
|
||||
"get_tracing_config",
|
||||
"get_explicitly_enabled_tracing_providers",
|
||||
"get_enabled_tracing_providers",
|
||||
"is_tracing_enabled",
|
||||
"validate_enabled_tracing_providers",
|
||||
]
|
||||
@@ -0,0 +1,51 @@
|
||||
"""ACP (Agent Client Protocol) agent configuration loaded from config.yaml."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ACPAgentConfig(BaseModel):
|
||||
"""Configuration for a single ACP-compatible agent."""
|
||||
|
||||
command: str = Field(description="Command to launch the ACP agent subprocess")
|
||||
args: list[str] = Field(default_factory=list, description="Additional command arguments")
|
||||
env: dict[str, str] = Field(default_factory=dict, description="Environment variables to inject into the agent subprocess. Values starting with $ are resolved from host environment variables.")
|
||||
description: str = Field(description="Description of the agent's capabilities (shown in tool description)")
|
||||
model: str | None = Field(default=None, description="Model hint passed to the agent (optional)")
|
||||
auto_approve_permissions: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"When True, DeerFlow automatically approves all ACP permission requests from this agent "
|
||||
"(allow_once preferred over allow_always). When False (default), all permission requests "
|
||||
"are denied — the agent must be configured to operate without requesting permissions."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_acp_agents: dict[str, ACPAgentConfig] = {}
|
||||
|
||||
|
||||
def get_acp_agents() -> dict[str, ACPAgentConfig]:
|
||||
"""Get the currently configured ACP agents.
|
||||
|
||||
Returns:
|
||||
Mapping of agent name -> ACPAgentConfig. Empty dict if no ACP agents are configured.
|
||||
"""
|
||||
return _acp_agents
|
||||
|
||||
|
||||
def load_acp_config_from_dict(config_dict: Mapping[str, Mapping[str, object]] | None) -> None:
|
||||
"""Load ACP agent configuration from a dictionary (typically from config.yaml).
|
||||
|
||||
Args:
|
||||
config_dict: Mapping of agent name -> config fields.
|
||||
"""
|
||||
global _acp_agents
|
||||
if config_dict is None:
|
||||
config_dict = {}
|
||||
_acp_agents = {name: ACPAgentConfig(**cfg) for name, cfg in config_dict.items()}
|
||||
logger.info("ACP config loaded: %d agent(s): %s", len(_acp_agents), list(_acp_agents.keys()))
|
||||
@@ -0,0 +1,125 @@
|
||||
"""Configuration and loaders for custom agents."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from deerflow.config.paths import get_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SOUL_FILENAME = "SOUL.md"
|
||||
AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
|
||||
"""Configuration for a custom agent."""
|
||||
|
||||
name: str
|
||||
description: str = ""
|
||||
model: str | None = None
|
||||
tool_groups: list[str] | None = None
|
||||
# skills controls which skills are loaded into the agent's prompt:
|
||||
# - None (or omitted): load all enabled skills (default fallback behavior)
|
||||
# - [] (explicit empty list): disable all skills
|
||||
# - ["skill1", "skill2"]: load only the specified skills
|
||||
skills: list[str] | None = None
|
||||
|
||||
|
||||
def load_agent_config(name: str | None) -> AgentConfig | None:
|
||||
"""Load the custom or default agent's config from its directory.
|
||||
|
||||
Args:
|
||||
name: The agent name.
|
||||
|
||||
Returns:
|
||||
AgentConfig instance.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the agent directory or config.yaml does not exist.
|
||||
ValueError: If config.yaml cannot be parsed.
|
||||
"""
|
||||
|
||||
if name is None:
|
||||
return None
|
||||
|
||||
if not AGENT_NAME_PATTERN.match(name):
|
||||
raise ValueError(f"Invalid agent name '{name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
|
||||
agent_dir = get_paths().agent_dir(name)
|
||||
config_file = agent_dir / "config.yaml"
|
||||
|
||||
if not agent_dir.exists():
|
||||
raise FileNotFoundError(f"Agent directory not found: {agent_dir}")
|
||||
|
||||
if not config_file.exists():
|
||||
raise FileNotFoundError(f"Agent config not found: {config_file}")
|
||||
|
||||
try:
|
||||
with open(config_file, encoding="utf-8") as f:
|
||||
data: dict[str, Any] = yaml.safe_load(f) or {}
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError(f"Failed to parse agent config {config_file}: {e}") from e
|
||||
|
||||
# Ensure name is set from directory name if not in file
|
||||
if "name" not in data:
|
||||
data["name"] = name
|
||||
|
||||
# Strip unknown fields before passing to Pydantic (e.g. legacy prompt_file)
|
||||
known_fields = set(AgentConfig.model_fields.keys())
|
||||
data = {k: v for k, v in data.items() if k in known_fields}
|
||||
|
||||
return AgentConfig(**data)
|
||||
|
||||
|
||||
def load_agent_soul(agent_name: str | None) -> str | None:
|
||||
"""Read the SOUL.md file for a custom agent, if it exists.
|
||||
|
||||
SOUL.md defines the agent's personality, values, and behavioral guardrails.
|
||||
It is injected into the lead agent's system prompt as additional context.
|
||||
|
||||
Args:
|
||||
agent_name: The name of the agent or None for the default agent.
|
||||
|
||||
Returns:
|
||||
The SOUL.md content as a string, or None if the file does not exist.
|
||||
"""
|
||||
agent_dir = get_paths().agent_dir(agent_name) if agent_name else get_paths().base_dir
|
||||
soul_path = agent_dir / SOUL_FILENAME
|
||||
if not soul_path.exists():
|
||||
return None
|
||||
content = soul_path.read_text(encoding="utf-8").strip()
|
||||
return content or None
|
||||
|
||||
|
||||
def list_custom_agents() -> list[AgentConfig]:
|
||||
"""Scan the agents directory and return all valid custom agents.
|
||||
|
||||
Returns:
|
||||
List of AgentConfig for each valid agent directory found.
|
||||
"""
|
||||
agents_dir = get_paths().agents_dir
|
||||
|
||||
if not agents_dir.exists():
|
||||
return []
|
||||
|
||||
agents: list[AgentConfig] = []
|
||||
|
||||
for entry in sorted(agents_dir.iterdir()):
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
|
||||
config_file = entry / "config.yaml"
|
||||
if not config_file.exists():
|
||||
logger.debug(f"Skipping {entry.name}: no config.yaml")
|
||||
continue
|
||||
|
||||
try:
|
||||
agent_cfg = load_agent_config(entry.name)
|
||||
agents.append(agent_cfg)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping agent '{entry.name}': {e}")
|
||||
|
||||
return agents
|
||||
379
deer-flow/backend/packages/harness/deerflow/config/app_config.py
Normal file
379
deer-flow/backend/packages/harness/deerflow/config/app_config.py
Normal file
@@ -0,0 +1,379 @@
|
||||
import logging
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
from pathlib import Path
|
||||
from typing import Any, Self
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from deerflow.config.acp_config import load_acp_config_from_dict
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
|
||||
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
|
||||
from deerflow.config.model_config import ModelConfig
|
||||
from deerflow.config.sandbox_config import SandboxConfig
|
||||
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
|
||||
from deerflow.config.skills_config import SkillsConfig
|
||||
from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict
|
||||
from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict
|
||||
from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict
|
||||
from deerflow.config.title_config import TitleConfig, load_title_config_from_dict
|
||||
from deerflow.config.token_usage_config import TokenUsageConfig
|
||||
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
|
||||
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _default_config_candidates() -> tuple[Path, ...]:
|
||||
"""Return deterministic config.yaml locations without relying on cwd."""
|
||||
backend_dir = Path(__file__).resolve().parents[4]
|
||||
repo_root = backend_dir.parent
|
||||
return (backend_dir / "config.yaml", repo_root / "config.yaml")
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""Config for the DeerFlow application"""
|
||||
|
||||
log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
|
||||
token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration")
|
||||
models: list[ModelConfig] = Field(default_factory=list, description="Available models")
|
||||
sandbox: SandboxConfig = Field(description="Sandbox configuration")
|
||||
tools: list[ToolConfig] = Field(default_factory=list, description="Available tools")
|
||||
tool_groups: list[ToolGroupConfig] = Field(default_factory=list, description="Available tool groups")
|
||||
skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
|
||||
skill_evolution: SkillEvolutionConfig = Field(default_factory=SkillEvolutionConfig, description="Agent-managed skill evolution configuration")
|
||||
extensions: ExtensionsConfig = Field(default_factory=ExtensionsConfig, description="Extensions configuration (MCP servers and skills state)")
|
||||
tool_search: ToolSearchConfig = Field(default_factory=ToolSearchConfig, description="Tool search / deferred loading configuration")
|
||||
title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration")
|
||||
summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
|
||||
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
|
||||
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||
model_config = ConfigDict(extra="allow", frozen=False)
|
||||
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
|
||||
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
|
||||
|
||||
@classmethod
|
||||
def resolve_config_path(cls, config_path: str | None = None) -> Path:
|
||||
"""Resolve the config file path.
|
||||
|
||||
Priority:
|
||||
1. If provided `config_path` argument, use it.
|
||||
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
|
||||
3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`.
|
||||
"""
|
||||
if config_path:
|
||||
path = Path(config_path)
|
||||
if not Path.exists(path):
|
||||
raise FileNotFoundError(f"Config file specified by param `config_path` not found at {path}")
|
||||
return path
|
||||
elif os.getenv("DEER_FLOW_CONFIG_PATH"):
|
||||
path = Path(os.getenv("DEER_FLOW_CONFIG_PATH"))
|
||||
if not Path.exists(path):
|
||||
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
|
||||
return path
|
||||
else:
|
||||
for path in _default_config_candidates():
|
||||
if path.exists():
|
||||
return path
|
||||
raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_path: str | None = None) -> Self:
|
||||
"""Load config from YAML file.
|
||||
|
||||
See `resolve_config_path` for more details.
|
||||
|
||||
Args:
|
||||
config_path: Path to the config file.
|
||||
|
||||
Returns:
|
||||
AppConfig: The loaded config.
|
||||
"""
|
||||
resolved_path = cls.resolve_config_path(config_path)
|
||||
with open(resolved_path, encoding="utf-8") as f:
|
||||
config_data = yaml.safe_load(f) or {}
|
||||
|
||||
# Check config version before processing
|
||||
cls._check_config_version(config_data, resolved_path)
|
||||
|
||||
config_data = cls.resolve_env_variables(config_data)
|
||||
|
||||
# Load title config if present
|
||||
if "title" in config_data:
|
||||
load_title_config_from_dict(config_data["title"])
|
||||
|
||||
# Load summarization config if present
|
||||
if "summarization" in config_data:
|
||||
load_summarization_config_from_dict(config_data["summarization"])
|
||||
|
||||
# Load memory config if present
|
||||
if "memory" in config_data:
|
||||
load_memory_config_from_dict(config_data["memory"])
|
||||
|
||||
# Load subagents config if present
|
||||
if "subagents" in config_data:
|
||||
load_subagents_config_from_dict(config_data["subagents"])
|
||||
|
||||
# Load tool_search config if present
|
||||
if "tool_search" in config_data:
|
||||
load_tool_search_config_from_dict(config_data["tool_search"])
|
||||
|
||||
# Load guardrails config if present
|
||||
if "guardrails" in config_data:
|
||||
load_guardrails_config_from_dict(config_data["guardrails"])
|
||||
|
||||
# Load checkpointer config if present
|
||||
if "checkpointer" in config_data:
|
||||
load_checkpointer_config_from_dict(config_data["checkpointer"])
|
||||
|
||||
# Load stream bridge config if present
|
||||
if "stream_bridge" in config_data:
|
||||
load_stream_bridge_config_from_dict(config_data["stream_bridge"])
|
||||
|
||||
# Always refresh ACP agent config so removed entries do not linger across reloads.
|
||||
load_acp_config_from_dict(config_data.get("acp_agents", {}))
|
||||
|
||||
# Load extensions config separately (it's in a different file)
|
||||
extensions_config = ExtensionsConfig.from_file()
|
||||
config_data["extensions"] = extensions_config.model_dump()
|
||||
|
||||
result = cls.model_validate(config_data)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _check_config_version(cls, config_data: dict, config_path: Path) -> None:
|
||||
"""Check if the user's config.yaml is outdated compared to config.example.yaml.
|
||||
|
||||
Emits a warning if the user's config_version is lower than the example's.
|
||||
Missing config_version is treated as version 0 (pre-versioning).
|
||||
"""
|
||||
try:
|
||||
user_version = int(config_data.get("config_version", 0))
|
||||
except (TypeError, ValueError):
|
||||
user_version = 0
|
||||
|
||||
# Find config.example.yaml by searching config.yaml's directory and its parents
|
||||
example_path = None
|
||||
search_dir = config_path.parent
|
||||
for _ in range(5): # search up to 5 levels
|
||||
candidate = search_dir / "config.example.yaml"
|
||||
if candidate.exists():
|
||||
example_path = candidate
|
||||
break
|
||||
parent = search_dir.parent
|
||||
if parent == search_dir:
|
||||
break
|
||||
search_dir = parent
|
||||
if example_path is None:
|
||||
return
|
||||
|
||||
try:
|
||||
with open(example_path, encoding="utf-8") as f:
|
||||
example_data = yaml.safe_load(f)
|
||||
raw = example_data.get("config_version", 0) if example_data else 0
|
||||
try:
|
||||
example_version = int(raw)
|
||||
except (TypeError, ValueError):
|
||||
example_version = 0
|
||||
except Exception:
|
||||
return
|
||||
|
||||
if user_version < example_version:
|
||||
logger.warning(
|
||||
"Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to merge new fields into your config.",
|
||||
user_version,
|
||||
example_version,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def resolve_env_variables(cls, config: Any) -> Any:
|
||||
"""Recursively resolve environment variables in the config.
|
||||
|
||||
Environment variables are resolved using the `os.getenv` function. Example: $OPENAI_API_KEY
|
||||
|
||||
Args:
|
||||
config: The config to resolve environment variables in.
|
||||
|
||||
Returns:
|
||||
The config with environment variables resolved.
|
||||
"""
|
||||
if isinstance(config, str):
|
||||
if config.startswith("$"):
|
||||
env_value = os.getenv(config[1:])
|
||||
if env_value is None:
|
||||
raise ValueError(f"Environment variable {config[1:]} not found for config value {config}")
|
||||
return env_value
|
||||
return config
|
||||
elif isinstance(config, dict):
|
||||
return {k: cls.resolve_env_variables(v) for k, v in config.items()}
|
||||
elif isinstance(config, list):
|
||||
return [cls.resolve_env_variables(item) for item in config]
|
||||
return config
|
||||
|
||||
def get_model_config(self, name: str) -> ModelConfig | None:
|
||||
"""Get the model config by name.
|
||||
|
||||
Args:
|
||||
name: The name of the model to get the config for.
|
||||
|
||||
Returns:
|
||||
The model config if found, otherwise None.
|
||||
"""
|
||||
return next((model for model in self.models if model.name == name), None)
|
||||
|
||||
def get_tool_config(self, name: str) -> ToolConfig | None:
|
||||
"""Get the tool config by name.
|
||||
|
||||
Args:
|
||||
name: The name of the tool to get the config for.
|
||||
|
||||
Returns:
|
||||
The tool config if found, otherwise None.
|
||||
"""
|
||||
return next((tool for tool in self.tools if tool.name == name), None)
|
||||
|
||||
def get_tool_group_config(self, name: str) -> ToolGroupConfig | None:
|
||||
"""Get the tool group config by name.
|
||||
|
||||
Args:
|
||||
name: The name of the tool group to get the config for.
|
||||
|
||||
Returns:
|
||||
The tool group config if found, otherwise None.
|
||||
"""
|
||||
return next((group for group in self.tool_groups if group.name == name), None)
|
||||
|
||||
|
||||
_app_config: AppConfig | None = None
|
||||
_app_config_path: Path | None = None
|
||||
_app_config_mtime: float | None = None
|
||||
_app_config_is_custom = False
|
||||
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
|
||||
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
|
||||
|
||||
|
||||
def _get_config_mtime(config_path: Path) -> float | None:
|
||||
"""Get the modification time of a config file if it exists."""
|
||||
try:
|
||||
return config_path.stat().st_mtime
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
|
||||
def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
|
||||
"""Load config from disk and refresh cache metadata."""
|
||||
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
|
||||
|
||||
resolved_path = AppConfig.resolve_config_path(config_path)
|
||||
_app_config = AppConfig.from_file(str(resolved_path))
|
||||
_app_config_path = resolved_path
|
||||
_app_config_mtime = _get_config_mtime(resolved_path)
|
||||
_app_config_is_custom = False
|
||||
return _app_config
|
||||
|
||||
|
||||
def get_app_config() -> AppConfig:
|
||||
"""Get the DeerFlow config instance.
|
||||
|
||||
Returns a cached singleton instance and automatically reloads it when the
|
||||
underlying config file path or modification time changes. Use
|
||||
`reload_app_config()` to force a reload, or `reset_app_config()` to clear
|
||||
the cache.
|
||||
"""
|
||||
global _app_config, _app_config_path, _app_config_mtime
|
||||
|
||||
runtime_override = _current_app_config.get()
|
||||
if runtime_override is not None:
|
||||
return runtime_override
|
||||
|
||||
if _app_config is not None and _app_config_is_custom:
|
||||
return _app_config
|
||||
|
||||
resolved_path = AppConfig.resolve_config_path()
|
||||
current_mtime = _get_config_mtime(resolved_path)
|
||||
|
||||
should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
|
||||
if should_reload:
|
||||
if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
|
||||
logger.info(
|
||||
"Config file has been modified (mtime: %s -> %s), reloading AppConfig",
|
||||
_app_config_mtime,
|
||||
current_mtime,
|
||||
)
|
||||
_load_and_cache_app_config(str(resolved_path))
|
||||
return _app_config
|
||||
|
||||
|
||||
def reload_app_config(config_path: str | None = None) -> AppConfig:
|
||||
"""Reload the config from file and update the cached instance.
|
||||
|
||||
This is useful when the config file has been modified and you want
|
||||
to pick up the changes without restarting the application.
|
||||
|
||||
Args:
|
||||
config_path: Optional path to config file. If not provided,
|
||||
uses the default resolution strategy.
|
||||
|
||||
Returns:
|
||||
The newly loaded AppConfig instance.
|
||||
"""
|
||||
return _load_and_cache_app_config(config_path)
|
||||
|
||||
|
||||
def reset_app_config() -> None:
|
||||
"""Reset the cached config instance.
|
||||
|
||||
This clears the singleton cache, causing the next call to
|
||||
`get_app_config()` to reload from file. Useful for testing
|
||||
or when switching between different configurations.
|
||||
"""
|
||||
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
|
||||
_app_config = None
|
||||
_app_config_path = None
|
||||
_app_config_mtime = None
|
||||
_app_config_is_custom = False
|
||||
|
||||
|
||||
def set_app_config(config: AppConfig) -> None:
|
||||
"""Set a custom config instance.
|
||||
|
||||
This allows injecting a custom or mock config for testing purposes.
|
||||
|
||||
Args:
|
||||
config: The AppConfig instance to use.
|
||||
"""
|
||||
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
|
||||
_app_config = config
|
||||
_app_config_path = None
|
||||
_app_config_mtime = None
|
||||
_app_config_is_custom = True
|
||||
|
||||
|
||||
def peek_current_app_config() -> AppConfig | None:
|
||||
"""Return the runtime-scoped AppConfig override, if one is active."""
|
||||
return _current_app_config.get()
|
||||
|
||||
|
||||
def push_current_app_config(config: AppConfig) -> None:
|
||||
"""Push a runtime-scoped AppConfig override for the current execution context."""
|
||||
stack = _current_app_config_stack.get()
|
||||
_current_app_config_stack.set(stack + (_current_app_config.get(),))
|
||||
_current_app_config.set(config)
|
||||
|
||||
|
||||
def pop_current_app_config() -> None:
|
||||
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
|
||||
stack = _current_app_config_stack.get()
|
||||
if not stack:
|
||||
_current_app_config.set(None)
|
||||
return
|
||||
previous = stack[-1]
|
||||
_current_app_config_stack.set(stack[:-1])
|
||||
_current_app_config.set(previous)
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Configuration for LangGraph checkpointer."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
CheckpointerType = Literal["memory", "sqlite", "postgres"]
|
||||
|
||||
|
||||
class CheckpointerConfig(BaseModel):
|
||||
"""Configuration for LangGraph state persistence checkpointer."""
|
||||
|
||||
type: CheckpointerType = Field(
|
||||
description="Checkpointer backend type. "
|
||||
"'memory' is in-process only (lost on restart). "
|
||||
"'sqlite' persists to a local file (requires langgraph-checkpoint-sqlite). "
|
||||
"'postgres' persists to PostgreSQL (requires langgraph-checkpoint-postgres)."
|
||||
)
|
||||
connection_string: str | None = Field(
|
||||
default=None,
|
||||
description="Connection string for sqlite (file path) or postgres (DSN). "
|
||||
"Required for sqlite and postgres types. "
|
||||
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
|
||||
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
|
||||
)
|
||||
|
||||
|
||||
# Global configuration instance — None means no checkpointer is configured.
|
||||
_checkpointer_config: CheckpointerConfig | None = None
|
||||
|
||||
|
||||
def get_checkpointer_config() -> CheckpointerConfig | None:
|
||||
"""Get the current checkpointer configuration, or None if not configured."""
|
||||
return _checkpointer_config
|
||||
|
||||
|
||||
def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
|
||||
"""Set the checkpointer configuration."""
|
||||
global _checkpointer_config
|
||||
_checkpointer_config = config
|
||||
|
||||
|
||||
def load_checkpointer_config_from_dict(config_dict: dict) -> None:
|
||||
"""Load checkpointer configuration from a dictionary."""
|
||||
global _checkpointer_config
|
||||
_checkpointer_config = CheckpointerConfig(**config_dict)
|
||||
@@ -0,0 +1,256 @@
|
||||
"""Unified extensions configuration for MCP servers and skills."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class McpOAuthConfig(BaseModel):
|
||||
"""OAuth configuration for an MCP server (HTTP/SSE transports)."""
|
||||
|
||||
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
|
||||
token_url: str = Field(description="OAuth token endpoint URL")
|
||||
grant_type: Literal["client_credentials", "refresh_token"] = Field(
|
||||
default="client_credentials",
|
||||
description="OAuth grant type",
|
||||
)
|
||||
client_id: str | None = Field(default=None, description="OAuth client ID")
|
||||
client_secret: str | None = Field(default=None, description="OAuth client secret")
|
||||
refresh_token: str | None = Field(default=None, description="OAuth refresh token (for refresh_token grant)")
|
||||
scope: str | None = Field(default=None, description="OAuth scope")
|
||||
audience: str | None = Field(default=None, description="OAuth audience (provider-specific)")
|
||||
token_field: str = Field(default="access_token", description="Field name containing access token in token response")
|
||||
token_type_field: str = Field(default="token_type", description="Field name containing token type in token response")
|
||||
expires_in_field: str = Field(default="expires_in", description="Field name containing expiry (seconds) in token response")
|
||||
default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response")
|
||||
refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry")
|
||||
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class McpServerConfig(BaseModel):
|
||||
"""Configuration for a single MCP server."""
|
||||
|
||||
enabled: bool = Field(default=True, description="Whether this MCP server is enabled")
|
||||
type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'")
|
||||
command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)")
|
||||
args: list[str] = Field(default_factory=list, description="Arguments to pass to the command (for stdio type)")
|
||||
env: dict[str, str] = Field(default_factory=dict, description="Environment variables for the MCP server")
|
||||
url: str | None = Field(default=None, description="URL of the MCP server (for sse or http type)")
|
||||
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
|
||||
oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)")
|
||||
description: str = Field(default="", description="Human-readable description of what this MCP server provides")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class SkillStateConfig(BaseModel):
|
||||
"""Configuration for a single skill's state."""
|
||||
|
||||
enabled: bool = Field(default=True, description="Whether this skill is enabled")
|
||||
|
||||
|
||||
class ExtensionsConfig(BaseModel):
|
||||
"""Unified configuration for MCP servers and skills."""
|
||||
|
||||
mcp_servers: dict[str, McpServerConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of MCP server name to configuration",
|
||||
alias="mcpServers",
|
||||
)
|
||||
skills: dict[str, SkillStateConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of skill name to state configuration",
|
||||
)
|
||||
model_config = ConfigDict(extra="allow", populate_by_name=True)
|
||||
|
||||
@classmethod
|
||||
def resolve_config_path(cls, config_path: str | None = None) -> Path | None:
|
||||
"""Resolve the extensions config file path.
|
||||
|
||||
Priority:
|
||||
1. If provided `config_path` argument, use it.
|
||||
2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it.
|
||||
3. Otherwise, check for `extensions_config.json` in the current directory, then in the parent directory.
|
||||
4. For backward compatibility, also check for `mcp_config.json` if `extensions_config.json` is not found.
|
||||
5. If not found, return None (extensions are optional).
|
||||
|
||||
Args:
|
||||
config_path: Optional path to extensions config file.
|
||||
|
||||
Resolution order:
|
||||
1. If provided `config_path` argument, use it.
|
||||
2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it.
|
||||
3. Otherwise, search backend/repository-root defaults for
|
||||
`extensions_config.json`, then legacy `mcp_config.json`.
|
||||
|
||||
Returns:
|
||||
Path to the extensions config file if found, otherwise None.
|
||||
"""
|
||||
if config_path:
|
||||
path = Path(config_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Extensions config file specified by param `config_path` not found at {path}")
|
||||
return path
|
||||
elif os.getenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH"):
|
||||
path = Path(os.getenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH"))
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Extensions config file specified by environment variable `DEER_FLOW_EXTENSIONS_CONFIG_PATH` not found at {path}")
|
||||
return path
|
||||
else:
|
||||
backend_dir = Path(__file__).resolve().parents[4]
|
||||
repo_root = backend_dir.parent
|
||||
for path in (
|
||||
backend_dir / "extensions_config.json",
|
||||
repo_root / "extensions_config.json",
|
||||
backend_dir / "mcp_config.json",
|
||||
repo_root / "mcp_config.json",
|
||||
):
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
# Extensions are optional, so return None if not found
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_path: str | None = None) -> "ExtensionsConfig":
|
||||
"""Load extensions config from JSON file.
|
||||
|
||||
See `resolve_config_path` for more details.
|
||||
|
||||
Args:
|
||||
config_path: Path to the extensions config file.
|
||||
|
||||
Returns:
|
||||
ExtensionsConfig: The loaded config, or empty config if file not found.
|
||||
"""
|
||||
resolved_path = cls.resolve_config_path(config_path)
|
||||
if resolved_path is None:
|
||||
# Return empty config if extensions config file is not found
|
||||
return cls(mcp_servers={}, skills={})
|
||||
|
||||
try:
|
||||
with open(resolved_path, encoding="utf-8") as f:
|
||||
config_data = json.load(f)
|
||||
cls.resolve_env_variables(config_data)
|
||||
return cls.model_validate(config_data)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Extensions config file at {resolved_path} is not valid JSON: {e}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load extensions config from {resolved_path}: {e}") from e
|
||||
|
||||
@classmethod
|
||||
def resolve_env_variables(cls, config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Recursively resolve environment variables in the config.
|
||||
|
||||
Environment variables are resolved using the `os.getenv` function. Example: $OPENAI_API_KEY
|
||||
|
||||
Args:
|
||||
config: The config to resolve environment variables in.
|
||||
|
||||
Returns:
|
||||
The config with environment variables resolved.
|
||||
"""
|
||||
for key, value in config.items():
|
||||
if isinstance(value, str):
|
||||
if value.startswith("$"):
|
||||
env_value = os.getenv(value[1:])
|
||||
if env_value is None:
|
||||
# Unresolved placeholder — store empty string so downstream
|
||||
# consumers (e.g. MCP servers) don't receive the literal "$VAR"
|
||||
# token as an actual environment value.
|
||||
config[key] = ""
|
||||
else:
|
||||
config[key] = env_value
|
||||
else:
|
||||
config[key] = value
|
||||
elif isinstance(value, dict):
|
||||
config[key] = cls.resolve_env_variables(value)
|
||||
elif isinstance(value, list):
|
||||
config[key] = [cls.resolve_env_variables(item) if isinstance(item, dict) else item for item in value]
|
||||
return config
|
||||
|
||||
def get_enabled_mcp_servers(self) -> dict[str, McpServerConfig]:
|
||||
"""Get only the enabled MCP servers.
|
||||
|
||||
Returns:
|
||||
Dictionary of enabled MCP servers.
|
||||
"""
|
||||
return {name: config for name, config in self.mcp_servers.items() if config.enabled}
|
||||
|
||||
def is_skill_enabled(self, skill_name: str, skill_category: str) -> bool:
|
||||
"""Check if a skill is enabled.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill
|
||||
skill_category: Category of the skill
|
||||
|
||||
Returns:
|
||||
True if enabled, False otherwise
|
||||
"""
|
||||
skill_config = self.skills.get(skill_name)
|
||||
if skill_config is None:
|
||||
# Default to enable for public & custom skill
|
||||
return skill_category in ("public", "custom")
|
||||
return skill_config.enabled
|
||||
|
||||
|
||||
_extensions_config: ExtensionsConfig | None = None
|
||||
|
||||
|
||||
def get_extensions_config() -> ExtensionsConfig:
|
||||
"""Get the extensions config instance.
|
||||
|
||||
Returns a cached singleton instance. Use `reload_extensions_config()` to reload
|
||||
from file, or `reset_extensions_config()` to clear the cache.
|
||||
|
||||
Returns:
|
||||
The cached ExtensionsConfig instance.
|
||||
"""
|
||||
global _extensions_config
|
||||
if _extensions_config is None:
|
||||
_extensions_config = ExtensionsConfig.from_file()
|
||||
return _extensions_config
|
||||
|
||||
|
||||
def reload_extensions_config(config_path: str | None = None) -> ExtensionsConfig:
|
||||
"""Reload the extensions config from file and update the cached instance.
|
||||
|
||||
This is useful when the config file has been modified and you want
|
||||
to pick up the changes without restarting the application.
|
||||
|
||||
Args:
|
||||
config_path: Optional path to extensions config file. If not provided,
|
||||
uses the default resolution strategy.
|
||||
|
||||
Returns:
|
||||
The newly loaded ExtensionsConfig instance.
|
||||
"""
|
||||
global _extensions_config
|
||||
_extensions_config = ExtensionsConfig.from_file(config_path)
|
||||
return _extensions_config
|
||||
|
||||
|
||||
def reset_extensions_config() -> None:
|
||||
"""Reset the cached extensions config instance.
|
||||
|
||||
This clears the singleton cache, causing the next call to
|
||||
`get_extensions_config()` to reload from file. Useful for testing
|
||||
or when switching between different configurations.
|
||||
"""
|
||||
global _extensions_config
|
||||
_extensions_config = None
|
||||
|
||||
|
||||
def set_extensions_config(config: ExtensionsConfig) -> None:
|
||||
"""Set a custom extensions config instance.
|
||||
|
||||
This allows injecting a custom or mock config for testing purposes.
|
||||
|
||||
Args:
|
||||
config: The ExtensionsConfig instance to use.
|
||||
"""
|
||||
global _extensions_config
|
||||
_extensions_config = config
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Configuration for pre-tool-call authorization."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GuardrailProviderConfig(BaseModel):
|
||||
"""Configuration for a guardrail provider."""
|
||||
|
||||
use: str = Field(description="Class path (e.g. 'deerflow.guardrails.builtin:AllowlistProvider')")
|
||||
config: dict = Field(default_factory=dict, description="Provider-specific settings passed as kwargs")
|
||||
|
||||
|
||||
class GuardrailsConfig(BaseModel):
|
||||
"""Configuration for pre-tool-call authorization.
|
||||
|
||||
When enabled, every tool call passes through the configured provider
|
||||
before execution. The provider receives tool name, arguments, and the
|
||||
agent's passport reference, and returns an allow/deny decision.
|
||||
"""
|
||||
|
||||
enabled: bool = Field(default=False, description="Enable guardrail middleware")
|
||||
fail_closed: bool = Field(default=True, description="Block tool calls if provider errors")
|
||||
passport: str | None = Field(default=None, description="OAP passport path or hosted agent ID")
|
||||
provider: GuardrailProviderConfig | None = Field(default=None, description="Guardrail provider configuration")
|
||||
|
||||
|
||||
_guardrails_config: GuardrailsConfig | None = None
|
||||
|
||||
|
||||
def get_guardrails_config() -> GuardrailsConfig:
|
||||
"""Get the guardrails config, returning defaults if not loaded."""
|
||||
global _guardrails_config
|
||||
if _guardrails_config is None:
|
||||
_guardrails_config = GuardrailsConfig()
|
||||
return _guardrails_config
|
||||
|
||||
|
||||
def load_guardrails_config_from_dict(data: dict) -> GuardrailsConfig:
|
||||
"""Load guardrails config from a dict (called during AppConfig loading)."""
|
||||
global _guardrails_config
|
||||
_guardrails_config = GuardrailsConfig.model_validate(data)
|
||||
return _guardrails_config
|
||||
|
||||
|
||||
def reset_guardrails_config() -> None:
|
||||
"""Reset the cached config instance. Used in tests to prevent singleton leaks."""
|
||||
global _guardrails_config
|
||||
_guardrails_config = None
|
||||
@@ -0,0 +1,82 @@
|
||||
"""Configuration for memory mechanism."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MemoryConfig(BaseModel):
|
||||
"""Configuration for global memory mechanism."""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether to enable memory mechanism",
|
||||
)
|
||||
storage_path: str = Field(
|
||||
default="",
|
||||
description=(
|
||||
"Path to store memory data. "
|
||||
"If empty, defaults to `{base_dir}/memory.json` (see Paths.memory_file). "
|
||||
"Absolute paths are used as-is. "
|
||||
"Relative paths are resolved against `Paths.base_dir` "
|
||||
"(not the backend working directory). "
|
||||
"Note: if you previously set this to `.deer-flow/memory.json`, "
|
||||
"the file will now be resolved as `{base_dir}/.deer-flow/memory.json`; "
|
||||
"migrate existing data or use an absolute path to preserve the old location."
|
||||
),
|
||||
)
|
||||
storage_class: str = Field(
|
||||
default="deerflow.agents.memory.storage.FileMemoryStorage",
|
||||
description="The class path for memory storage provider",
|
||||
)
|
||||
debounce_seconds: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
le=300,
|
||||
description="Seconds to wait before processing queued updates (debounce)",
|
||||
)
|
||||
model_name: str | None = Field(
|
||||
default=None,
|
||||
description="Model name to use for memory updates (None = use default model)",
|
||||
)
|
||||
max_facts: int = Field(
|
||||
default=100,
|
||||
ge=10,
|
||||
le=500,
|
||||
description="Maximum number of facts to store",
|
||||
)
|
||||
fact_confidence_threshold: float = Field(
|
||||
default=0.7,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum confidence threshold for storing facts",
|
||||
)
|
||||
injection_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether to inject memory into system prompt",
|
||||
)
|
||||
max_injection_tokens: int = Field(
|
||||
default=2000,
|
||||
ge=100,
|
||||
le=8000,
|
||||
description="Maximum tokens to use for memory injection",
|
||||
)
|
||||
|
||||
|
||||
# Global configuration instance
|
||||
_memory_config: MemoryConfig = MemoryConfig()
|
||||
|
||||
|
||||
def get_memory_config() -> MemoryConfig:
|
||||
"""Get the current memory configuration."""
|
||||
return _memory_config
|
||||
|
||||
|
||||
def set_memory_config(config: MemoryConfig) -> None:
|
||||
"""Set the memory configuration."""
|
||||
global _memory_config
|
||||
_memory_config = config
|
||||
|
||||
|
||||
def load_memory_config_from_dict(config_dict: dict) -> None:
|
||||
"""Load memory configuration from a dictionary."""
|
||||
global _memory_config
|
||||
_memory_config = MemoryConfig(**config_dict)
|
||||
@@ -0,0 +1,41 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
"""Config section for a model"""
|
||||
|
||||
name: str = Field(..., description="Unique name for the model")
|
||||
display_name: str | None = Field(..., default_factory=lambda: None, description="Display name for the model")
|
||||
description: str | None = Field(..., default_factory=lambda: None, description="Description for the model")
|
||||
use: str = Field(
|
||||
...,
|
||||
description="Class path of the model provider(e.g. langchain_openai.ChatOpenAI)",
|
||||
)
|
||||
model: str = Field(..., description="Model name")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
use_responses_api: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to route OpenAI ChatOpenAI calls through the /v1/responses API",
|
||||
)
|
||||
output_version: str | None = Field(
|
||||
default=None,
|
||||
description="Structured output version for OpenAI responses content, e.g. responses/v1",
|
||||
)
|
||||
supports_thinking: bool = Field(default_factory=lambda: False, description="Whether the model supports thinking")
|
||||
supports_reasoning_effort: bool = Field(default_factory=lambda: False, description="Whether the model supports reasoning effort")
|
||||
when_thinking_enabled: dict | None = Field(
|
||||
default_factory=lambda: None,
|
||||
description="Extra settings to be passed to the model when thinking is enabled",
|
||||
)
|
||||
when_thinking_disabled: dict | None = Field(
|
||||
default_factory=lambda: None,
|
||||
description="Extra settings to be passed to the model when thinking is disabled",
|
||||
)
|
||||
supports_vision: bool = Field(default_factory=lambda: False, description="Whether the model supports vision/image inputs")
|
||||
thinking: dict | None = Field(
|
||||
default_factory=lambda: None,
|
||||
description=(
|
||||
"Thinking settings for the model. If provided, these settings will be passed to the model when thinking is enabled. "
|
||||
"This is a shortcut for `when_thinking_enabled` and will be merged with `when_thinking_enabled` if both are provided."
|
||||
),
|
||||
)
|
||||
306
deer-flow/backend/packages/harness/deerflow/config/paths.py
Normal file
306
deer-flow/backend/packages/harness/deerflow/config/paths.py
Normal file
@@ -0,0 +1,306 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path, PureWindowsPath
|
||||
|
||||
# Virtual path prefix seen by agents inside the sandbox
|
||||
VIRTUAL_PATH_PREFIX = "/mnt/user-data"
|
||||
|
||||
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||
|
||||
|
||||
def _default_local_base_dir() -> Path:
|
||||
"""Return the repo-local DeerFlow state directory without relying on cwd."""
|
||||
backend_dir = Path(__file__).resolve().parents[4]
|
||||
return backend_dir / ".deer-flow"
|
||||
|
||||
|
||||
def _validate_thread_id(thread_id: str) -> str:
|
||||
"""Validate a thread ID before using it in filesystem paths."""
|
||||
if not _SAFE_THREAD_ID_RE.match(thread_id):
|
||||
raise ValueError(f"Invalid thread_id {thread_id!r}: only alphanumeric characters, hyphens, and underscores are allowed.")
|
||||
return thread_id
|
||||
|
||||
|
||||
def _join_host_path(base: str, *parts: str) -> str:
|
||||
"""Join host filesystem path segments while preserving native style.
|
||||
|
||||
Docker Desktop on Windows expects bind mount sources to stay in Windows
|
||||
path form (for example ``C:\\repo\\backend\\.deer-flow``). Using
|
||||
``Path(base) / ...`` on a POSIX host can accidentally rewrite those paths
|
||||
with mixed separators, so this helper preserves the original style.
|
||||
"""
|
||||
if not parts:
|
||||
return base
|
||||
|
||||
if re.match(r"^[A-Za-z]:[\\/]", base) or base.startswith("\\\\") or "\\" in base:
|
||||
result = PureWindowsPath(base)
|
||||
for part in parts:
|
||||
result /= part
|
||||
return str(result)
|
||||
|
||||
result = Path(base)
|
||||
for part in parts:
|
||||
result /= part
|
||||
return str(result)
|
||||
|
||||
|
||||
def join_host_path(base: str, *parts: str) -> str:
|
||||
"""Join host filesystem path segments while preserving native style."""
|
||||
return _join_host_path(base, *parts)
|
||||
|
||||
|
||||
class Paths:
|
||||
"""
|
||||
Centralized path configuration for DeerFlow application data.
|
||||
|
||||
Directory layout (host side):
|
||||
{base_dir}/
|
||||
├── memory.json
|
||||
├── USER.md <-- global user profile (injected into all agents)
|
||||
├── agents/
|
||||
│ └── {agent_name}/
|
||||
│ ├── config.yaml
|
||||
│ ├── SOUL.md <-- agent personality/identity (injected alongside lead prompt)
|
||||
│ └── memory.json
|
||||
└── threads/
|
||||
└── {thread_id}/
|
||||
└── user-data/ <-- mounted as /mnt/user-data/ inside sandbox
|
||||
├── workspace/ <-- /mnt/user-data/workspace/
|
||||
├── uploads/ <-- /mnt/user-data/uploads/
|
||||
└── outputs/ <-- /mnt/user-data/outputs/
|
||||
|
||||
BaseDir resolution (in priority order):
|
||||
1. Constructor argument `base_dir`
|
||||
2. DEER_FLOW_HOME environment variable
|
||||
3. Repo-local fallback derived from this module path: `{backend_dir}/.deer-flow`
|
||||
"""
|
||||
|
||||
def __init__(self, base_dir: str | Path | None = None) -> None:
|
||||
self._base_dir = Path(base_dir).resolve() if base_dir is not None else None
|
||||
|
||||
@property
|
||||
def host_base_dir(self) -> Path:
|
||||
"""Host-visible base dir for Docker volume mount sources.
|
||||
|
||||
When running inside Docker with a mounted Docker socket (DooD), the Docker
|
||||
daemon runs on the host and resolves mount paths against the host filesystem.
|
||||
Set DEER_FLOW_HOST_BASE_DIR to the host-side path that corresponds to this
|
||||
container's base_dir so that sandbox container volume mounts work correctly.
|
||||
|
||||
Falls back to base_dir when the env var is not set (native/local execution).
|
||||
"""
|
||||
if env := os.getenv("DEER_FLOW_HOST_BASE_DIR"):
|
||||
return Path(env)
|
||||
return self.base_dir
|
||||
|
||||
def _host_base_dir_str(self) -> str:
|
||||
"""Return the host base dir as a raw string for bind mounts."""
|
||||
if env := os.getenv("DEER_FLOW_HOST_BASE_DIR"):
|
||||
return env
|
||||
return str(self.base_dir)
|
||||
|
||||
@property
|
||||
def base_dir(self) -> Path:
|
||||
"""Root directory for all application data."""
|
||||
if self._base_dir is not None:
|
||||
return self._base_dir
|
||||
|
||||
if env_home := os.getenv("DEER_FLOW_HOME"):
|
||||
return Path(env_home).resolve()
|
||||
|
||||
return _default_local_base_dir()
|
||||
|
||||
@property
|
||||
def memory_file(self) -> Path:
|
||||
"""Path to the persisted memory file: `{base_dir}/memory.json`."""
|
||||
return self.base_dir / "memory.json"
|
||||
|
||||
@property
|
||||
def user_md_file(self) -> Path:
|
||||
"""Path to the global user profile file: `{base_dir}/USER.md`."""
|
||||
return self.base_dir / "USER.md"
|
||||
|
||||
@property
|
||||
def agents_dir(self) -> Path:
|
||||
"""Root directory for all custom agents: `{base_dir}/agents/`."""
|
||||
return self.base_dir / "agents"
|
||||
|
||||
def agent_dir(self, name: str) -> Path:
|
||||
"""Directory for a specific agent: `{base_dir}/agents/{name}/`."""
|
||||
return self.agents_dir / name.lower()
|
||||
|
||||
def agent_memory_file(self, name: str) -> Path:
|
||||
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
|
||||
return self.agent_dir(name) / "memory.json"
|
||||
|
||||
def thread_dir(self, thread_id: str) -> Path:
|
||||
"""
|
||||
Host path for a thread's data: `{base_dir}/threads/{thread_id}/`
|
||||
|
||||
This directory contains a `user-data/` subdirectory that is mounted
|
||||
as `/mnt/user-data/` inside the sandbox.
|
||||
|
||||
Raises:
|
||||
ValueError: If `thread_id` contains unsafe characters (path separators
|
||||
or `..`) that could cause directory traversal.
|
||||
"""
|
||||
return self.base_dir / "threads" / _validate_thread_id(thread_id)
|
||||
|
||||
def sandbox_work_dir(self, thread_id: str) -> Path:
|
||||
"""
|
||||
Host path for the agent's workspace directory.
|
||||
Host: `{base_dir}/threads/{thread_id}/user-data/workspace/`
|
||||
Sandbox: `/mnt/user-data/workspace/`
|
||||
"""
|
||||
return self.thread_dir(thread_id) / "user-data" / "workspace"
|
||||
|
||||
def sandbox_uploads_dir(self, thread_id: str) -> Path:
|
||||
"""
|
||||
Host path for user-uploaded files.
|
||||
Host: `{base_dir}/threads/{thread_id}/user-data/uploads/`
|
||||
Sandbox: `/mnt/user-data/uploads/`
|
||||
"""
|
||||
return self.thread_dir(thread_id) / "user-data" / "uploads"
|
||||
|
||||
def sandbox_outputs_dir(self, thread_id: str) -> Path:
|
||||
"""
|
||||
Host path for agent-generated artifacts.
|
||||
Host: `{base_dir}/threads/{thread_id}/user-data/outputs/`
|
||||
Sandbox: `/mnt/user-data/outputs/`
|
||||
"""
|
||||
return self.thread_dir(thread_id) / "user-data" / "outputs"
|
||||
|
||||
def acp_workspace_dir(self, thread_id: str) -> Path:
|
||||
"""
|
||||
Host path for the ACP workspace of a specific thread.
|
||||
Host: `{base_dir}/threads/{thread_id}/acp-workspace/`
|
||||
Sandbox: `/mnt/acp-workspace/`
|
||||
|
||||
Each thread gets its own isolated ACP workspace so that concurrent
|
||||
sessions cannot read each other's ACP agent outputs.
|
||||
"""
|
||||
return self.thread_dir(thread_id) / "acp-workspace"
|
||||
|
||||
def sandbox_user_data_dir(self, thread_id: str) -> Path:
|
||||
"""
|
||||
Host path for the user-data root.
|
||||
Host: `{base_dir}/threads/{thread_id}/user-data/`
|
||||
Sandbox: `/mnt/user-data/`
|
||||
"""
|
||||
return self.thread_dir(thread_id) / "user-data"
|
||||
|
||||
def host_thread_dir(self, thread_id: str) -> str:
|
||||
"""Host path for a thread directory, preserving Windows path syntax."""
|
||||
return _join_host_path(self._host_base_dir_str(), "threads", _validate_thread_id(thread_id))
|
||||
|
||||
def host_sandbox_user_data_dir(self, thread_id: str) -> str:
|
||||
"""Host path for a thread's user-data root."""
|
||||
return _join_host_path(self.host_thread_dir(thread_id), "user-data")
|
||||
|
||||
def host_sandbox_work_dir(self, thread_id: str) -> str:
|
||||
"""Host path for the workspace mount source."""
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "workspace")
|
||||
|
||||
def host_sandbox_uploads_dir(self, thread_id: str) -> str:
|
||||
"""Host path for the uploads mount source."""
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "uploads")
|
||||
|
||||
def host_sandbox_outputs_dir(self, thread_id: str) -> str:
|
||||
"""Host path for the outputs mount source."""
|
||||
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "outputs")
|
||||
|
||||
def host_acp_workspace_dir(self, thread_id: str) -> str:
|
||||
"""Host path for the ACP workspace mount source."""
|
||||
return _join_host_path(self.host_thread_dir(thread_id), "acp-workspace")
|
||||
|
||||
def ensure_thread_dirs(self, thread_id: str) -> None:
|
||||
"""Create all standard sandbox directories for a thread.
|
||||
|
||||
Directories are created with mode 0o777 so that sandbox containers
|
||||
(which may run as a different UID than the host backend process) can
|
||||
write to the volume-mounted paths without "Permission denied" errors.
|
||||
The explicit chmod() call is necessary because Path.mkdir(mode=...) is
|
||||
subject to the process umask and may not yield the intended permissions.
|
||||
|
||||
Includes the ACP workspace directory so it can be volume-mounted into
|
||||
the sandbox container at ``/mnt/acp-workspace`` even before the first
|
||||
ACP agent invocation.
|
||||
"""
|
||||
for d in [
|
||||
self.sandbox_work_dir(thread_id),
|
||||
self.sandbox_uploads_dir(thread_id),
|
||||
self.sandbox_outputs_dir(thread_id),
|
||||
self.acp_workspace_dir(thread_id),
|
||||
]:
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
d.chmod(0o777)
|
||||
|
||||
def delete_thread_dir(self, thread_id: str) -> None:
|
||||
"""Delete all persisted data for a thread.
|
||||
|
||||
The operation is idempotent: missing thread directories are ignored.
|
||||
"""
|
||||
thread_dir = self.thread_dir(thread_id)
|
||||
if thread_dir.exists():
|
||||
shutil.rmtree(thread_dir)
|
||||
|
||||
def resolve_virtual_path(self, thread_id: str, virtual_path: str) -> Path:
|
||||
"""Resolve a sandbox virtual path to the actual host filesystem path.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
virtual_path: Virtual path as seen inside the sandbox, e.g.
|
||||
``/mnt/user-data/outputs/report.pdf``.
|
||||
Leading slashes are stripped before matching.
|
||||
|
||||
Returns:
|
||||
The resolved absolute host filesystem path.
|
||||
|
||||
Raises:
|
||||
ValueError: If the path does not start with the expected virtual
|
||||
prefix or a path-traversal attempt is detected.
|
||||
"""
|
||||
stripped = virtual_path.lstrip("/")
|
||||
prefix = VIRTUAL_PATH_PREFIX.lstrip("/")
|
||||
|
||||
# Require an exact segment-boundary match to avoid prefix confusion
|
||||
# (e.g. reject paths like "mnt/user-dataX/...").
|
||||
if stripped != prefix and not stripped.startswith(prefix + "/"):
|
||||
raise ValueError(f"Path must start with /{prefix}")
|
||||
|
||||
relative = stripped[len(prefix) :].lstrip("/")
|
||||
base = self.sandbox_user_data_dir(thread_id).resolve()
|
||||
actual = (base / relative).resolve()
|
||||
|
||||
try:
|
||||
actual.relative_to(base)
|
||||
except ValueError:
|
||||
raise ValueError("Access denied: path traversal detected")
|
||||
|
||||
return actual
|
||||
|
||||
|
||||
# ── Singleton ────────────────────────────────────────────────────────────
|
||||
|
||||
_paths: Paths | None = None
|
||||
|
||||
|
||||
def get_paths() -> Paths:
|
||||
"""Return the global Paths singleton (lazy-initialized)."""
|
||||
global _paths
|
||||
if _paths is None:
|
||||
_paths = Paths()
|
||||
return _paths
|
||||
|
||||
|
||||
def resolve_path(path: str) -> Path:
|
||||
"""Resolve *path* to an absolute ``Path``.
|
||||
|
||||
Relative paths are resolved relative to the application base directory.
|
||||
Absolute paths are returned as-is (after normalisation).
|
||||
"""
|
||||
p = Path(path)
|
||||
if not p.is_absolute():
|
||||
p = get_paths().base_dir / path
|
||||
return p.resolve()
|
||||
@@ -0,0 +1,83 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class VolumeMountConfig(BaseModel):
|
||||
"""Configuration for a volume mount."""
|
||||
|
||||
host_path: str = Field(..., description="Path on the host machine")
|
||||
container_path: str = Field(..., description="Path inside the container")
|
||||
read_only: bool = Field(default=False, description="Whether the mount is read-only")
|
||||
|
||||
|
||||
class SandboxConfig(BaseModel):
|
||||
"""Config section for a sandbox.
|
||||
|
||||
Common options:
|
||||
use: Class path of the sandbox provider (required)
|
||||
allow_host_bash: Enable host-side bash execution for LocalSandboxProvider.
|
||||
Dangerous and intended only for fully trusted local workflows.
|
||||
|
||||
AioSandboxProvider specific options:
|
||||
image: Docker image to use (default: enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest)
|
||||
port: Base port for sandbox containers (default: 8080)
|
||||
replicas: Maximum number of concurrent sandbox containers (default: 3). When the limit is reached the least-recently-used sandbox is evicted to make room.
|
||||
container_prefix: Prefix for container names (default: deer-flow-sandbox)
|
||||
idle_timeout: Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.
|
||||
mounts: List of volume mounts to share directories with the container
|
||||
environment: Environment variables to inject into the container (values starting with $ are resolved from host env)
|
||||
"""
|
||||
|
||||
use: str = Field(
|
||||
...,
|
||||
description="Class path of the sandbox provider (e.g. deerflow.sandbox.local:LocalSandboxProvider)",
|
||||
)
|
||||
allow_host_bash: bool = Field(
|
||||
default=False,
|
||||
description="Allow the bash tool to execute directly on the host when using LocalSandboxProvider. Dangerous; intended only for fully trusted local environments.",
|
||||
)
|
||||
image: str | None = Field(
|
||||
default=None,
|
||||
description="Docker image to use for the sandbox container",
|
||||
)
|
||||
port: int | None = Field(
|
||||
default=None,
|
||||
description="Base port for sandbox containers",
|
||||
)
|
||||
replicas: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum number of concurrent sandbox containers (default: 3). When the limit is reached the least-recently-used sandbox is evicted to make room.",
|
||||
)
|
||||
container_prefix: str | None = Field(
|
||||
default=None,
|
||||
description="Prefix for container names",
|
||||
)
|
||||
idle_timeout: int | None = Field(
|
||||
default=None,
|
||||
description="Idle timeout in seconds before sandbox is released (default: 600 = 10 minutes). Set to 0 to disable.",
|
||||
)
|
||||
mounts: list[VolumeMountConfig] = Field(
|
||||
default_factory=list,
|
||||
description="List of volume mounts to share directories between host and container",
|
||||
)
|
||||
environment: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Environment variables to inject into the sandbox container. Values starting with $ will be resolved from host environment variables.",
|
||||
)
|
||||
|
||||
bash_output_max_chars: int = Field(
|
||||
default=20000,
|
||||
ge=0,
|
||||
description="Maximum characters to keep from bash tool output. Output exceeding this limit is middle-truncated (head + tail), preserving the first and last half. Set to 0 to disable truncation.",
|
||||
)
|
||||
read_file_output_max_chars: int = Field(
|
||||
default=50000,
|
||||
ge=0,
|
||||
description="Maximum characters to keep from read_file tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
|
||||
)
|
||||
ls_output_max_chars: int = Field(
|
||||
default=20000,
|
||||
ge=0,
|
||||
description="Maximum characters to keep from ls tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
|
||||
)
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
@@ -0,0 +1,14 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SkillEvolutionConfig(BaseModel):
|
||||
"""Configuration for agent-managed skill evolution."""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=False,
|
||||
description="Whether the agent can create and modify skills under skills/custom.",
|
||||
)
|
||||
moderation_model_name: str | None = Field(
|
||||
default=None,
|
||||
description="Optional model name for skill security moderation. Defaults to the primary chat model.",
|
||||
)
|
||||
@@ -0,0 +1,54 @@
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
def _default_repo_root() -> Path:
|
||||
"""Resolve the repo root without relying on the current working directory."""
|
||||
return Path(__file__).resolve().parents[5]
|
||||
|
||||
|
||||
class SkillsConfig(BaseModel):
|
||||
"""Configuration for skills system"""
|
||||
|
||||
path: str | None = Field(
|
||||
default=None,
|
||||
description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory",
|
||||
)
|
||||
container_path: str = Field(
|
||||
default="/mnt/skills",
|
||||
description="Path where skills are mounted in the sandbox container",
|
||||
)
|
||||
|
||||
def get_skills_path(self) -> Path:
|
||||
"""
|
||||
Get the resolved skills directory path.
|
||||
|
||||
Returns:
|
||||
Path to the skills directory
|
||||
"""
|
||||
if self.path:
|
||||
# Use configured path (can be absolute or relative)
|
||||
path = Path(self.path)
|
||||
if not path.is_absolute():
|
||||
# If relative, resolve from the repo root for deterministic behavior.
|
||||
path = _default_repo_root() / path
|
||||
return path.resolve()
|
||||
else:
|
||||
# Default: ../skills relative to backend directory
|
||||
from deerflow.skills.loader import get_skills_root_path
|
||||
|
||||
return get_skills_root_path()
|
||||
|
||||
def get_skill_container_path(self, skill_name: str, category: str = "public") -> str:
|
||||
"""
|
||||
Get the full container path for a specific skill.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill (directory name)
|
||||
category: Category of the skill (public or custom)
|
||||
|
||||
Returns:
|
||||
Full path to the skill in the container
|
||||
"""
|
||||
return f"{self.container_path}/{category}/{skill_name}"
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Configuration for stream bridge."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
StreamBridgeType = Literal["memory", "redis"]
|
||||
|
||||
|
||||
class StreamBridgeConfig(BaseModel):
|
||||
"""Configuration for the stream bridge that connects agent workers to SSE endpoints."""
|
||||
|
||||
type: StreamBridgeType = Field(
|
||||
default="memory",
|
||||
description="Stream bridge backend type. 'memory' uses in-process asyncio.Queue (single-process only). 'redis' uses Redis Streams (planned for Phase 2, not yet implemented).",
|
||||
)
|
||||
redis_url: str | None = Field(
|
||||
default=None,
|
||||
description="Redis URL for the redis stream bridge type. Example: 'redis://localhost:6379/0'.",
|
||||
)
|
||||
queue_maxsize: int = Field(
|
||||
default=256,
|
||||
description="Maximum number of events buffered per run in the memory bridge.",
|
||||
)
|
||||
|
||||
|
||||
# Global configuration instance — None means no stream bridge is configured
|
||||
# (falls back to memory with defaults).
|
||||
_stream_bridge_config: StreamBridgeConfig | None = None
|
||||
|
||||
|
||||
def get_stream_bridge_config() -> StreamBridgeConfig | None:
|
||||
"""Get the current stream bridge configuration, or None if not configured."""
|
||||
return _stream_bridge_config
|
||||
|
||||
|
||||
def set_stream_bridge_config(config: StreamBridgeConfig | None) -> None:
|
||||
"""Set the stream bridge configuration."""
|
||||
global _stream_bridge_config
|
||||
_stream_bridge_config = config
|
||||
|
||||
|
||||
def load_stream_bridge_config_from_dict(config_dict: dict) -> None:
|
||||
"""Load stream bridge configuration from a dictionary."""
|
||||
global _stream_bridge_config
|
||||
_stream_bridge_config = StreamBridgeConfig(**config_dict)
|
||||
@@ -0,0 +1,102 @@
|
||||
"""Configuration for the subagent system loaded from config.yaml."""
|
||||
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubagentOverrideConfig(BaseModel):
|
||||
"""Per-agent configuration overrides."""
|
||||
|
||||
timeout_seconds: int | None = Field(
|
||||
default=None,
|
||||
ge=1,
|
||||
description="Timeout in seconds for this subagent (None = use global default)",
|
||||
)
|
||||
max_turns: int | None = Field(
|
||||
default=None,
|
||||
ge=1,
|
||||
description="Maximum turns for this subagent (None = use global or builtin default)",
|
||||
)
|
||||
|
||||
|
||||
class SubagentsAppConfig(BaseModel):
|
||||
"""Configuration for the subagent system."""
|
||||
|
||||
timeout_seconds: int = Field(
|
||||
default=900,
|
||||
ge=1,
|
||||
description="Default timeout in seconds for all subagents (default: 900 = 15 minutes)",
|
||||
)
|
||||
max_turns: int | None = Field(
|
||||
default=None,
|
||||
ge=1,
|
||||
description="Optional default max-turn override for all subagents (None = keep builtin defaults)",
|
||||
)
|
||||
agents: dict[str, SubagentOverrideConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="Per-agent configuration overrides keyed by agent name",
|
||||
)
|
||||
|
||||
def get_timeout_for(self, agent_name: str) -> int:
|
||||
"""Get the effective timeout for a specific agent.
|
||||
|
||||
Args:
|
||||
agent_name: The name of the subagent.
|
||||
|
||||
Returns:
|
||||
The timeout in seconds, using per-agent override if set, otherwise global default.
|
||||
"""
|
||||
override = self.agents.get(agent_name)
|
||||
if override is not None and override.timeout_seconds is not None:
|
||||
return override.timeout_seconds
|
||||
return self.timeout_seconds
|
||||
|
||||
def get_max_turns_for(self, agent_name: str, builtin_default: int) -> int:
|
||||
"""Get the effective max_turns for a specific agent."""
|
||||
override = self.agents.get(agent_name)
|
||||
if override is not None and override.max_turns is not None:
|
||||
return override.max_turns
|
||||
if self.max_turns is not None:
|
||||
return self.max_turns
|
||||
return builtin_default
|
||||
|
||||
|
||||
_subagents_config: SubagentsAppConfig = SubagentsAppConfig()
|
||||
|
||||
|
||||
def get_subagents_app_config() -> SubagentsAppConfig:
|
||||
"""Get the current subagents configuration."""
|
||||
return _subagents_config
|
||||
|
||||
|
||||
def load_subagents_config_from_dict(config_dict: dict) -> None:
|
||||
"""Load subagents configuration from a dictionary."""
|
||||
global _subagents_config
|
||||
_subagents_config = SubagentsAppConfig(**config_dict)
|
||||
|
||||
overrides_summary = {}
|
||||
for name, override in _subagents_config.agents.items():
|
||||
parts = []
|
||||
if override.timeout_seconds is not None:
|
||||
parts.append(f"timeout={override.timeout_seconds}s")
|
||||
if override.max_turns is not None:
|
||||
parts.append(f"max_turns={override.max_turns}")
|
||||
if parts:
|
||||
overrides_summary[name] = ", ".join(parts)
|
||||
|
||||
if overrides_summary:
|
||||
logger.info(
|
||||
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s",
|
||||
_subagents_config.timeout_seconds,
|
||||
_subagents_config.max_turns,
|
||||
overrides_summary,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
|
||||
_subagents_config.timeout_seconds,
|
||||
_subagents_config.max_turns,
|
||||
)
|
||||
@@ -0,0 +1,74 @@
|
||||
"""Configuration for conversation summarization."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
ContextSizeType = Literal["fraction", "tokens", "messages"]
|
||||
|
||||
|
||||
class ContextSize(BaseModel):
|
||||
"""Context size specification for trigger or keep parameters."""
|
||||
|
||||
type: ContextSizeType = Field(description="Type of context size specification")
|
||||
value: int | float = Field(description="Value for the context size specification")
|
||||
|
||||
def to_tuple(self) -> tuple[ContextSizeType, int | float]:
|
||||
"""Convert to tuple format expected by SummarizationMiddleware."""
|
||||
return (self.type, self.value)
|
||||
|
||||
|
||||
class SummarizationConfig(BaseModel):
|
||||
"""Configuration for automatic conversation summarization."""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable automatic conversation summarization",
|
||||
)
|
||||
model_name: str | None = Field(
|
||||
default=None,
|
||||
description="Model name to use for summarization (None = use a lightweight model)",
|
||||
)
|
||||
trigger: ContextSize | list[ContextSize] | None = Field(
|
||||
default=None,
|
||||
description="One or more thresholds that trigger summarization. When any threshold is met, summarization runs. "
|
||||
"Examples: {'type': 'messages', 'value': 50} triggers at 50 messages, "
|
||||
"{'type': 'tokens', 'value': 4000} triggers at 4000 tokens, "
|
||||
"{'type': 'fraction', 'value': 0.8} triggers at 80% of model's max input tokens",
|
||||
)
|
||||
keep: ContextSize = Field(
|
||||
default_factory=lambda: ContextSize(type="messages", value=20),
|
||||
description="Context retention policy after summarization. Specifies how much history to preserve. "
|
||||
"Examples: {'type': 'messages', 'value': 20} keeps 20 messages, "
|
||||
"{'type': 'tokens', 'value': 3000} keeps 3000 tokens, "
|
||||
"{'type': 'fraction', 'value': 0.3} keeps 30% of model's max input tokens",
|
||||
)
|
||||
trim_tokens_to_summarize: int | None = Field(
|
||||
default=4000,
|
||||
description="Maximum tokens to keep when preparing messages for summarization. Pass null to skip trimming.",
|
||||
)
|
||||
summary_prompt: str | None = Field(
|
||||
default=None,
|
||||
description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.",
|
||||
)
|
||||
|
||||
|
||||
# Global configuration instance
|
||||
_summarization_config: SummarizationConfig = SummarizationConfig()
|
||||
|
||||
|
||||
def get_summarization_config() -> SummarizationConfig:
|
||||
"""Get the current summarization configuration."""
|
||||
return _summarization_config
|
||||
|
||||
|
||||
def set_summarization_config(config: SummarizationConfig) -> None:
|
||||
"""Set the summarization configuration."""
|
||||
global _summarization_config
|
||||
_summarization_config = config
|
||||
|
||||
|
||||
def load_summarization_config_from_dict(config_dict: dict) -> None:
|
||||
"""Load summarization configuration from a dictionary."""
|
||||
global _summarization_config
|
||||
_summarization_config = SummarizationConfig(**config_dict)
|
||||
@@ -0,0 +1,53 @@
|
||||
"""Configuration for automatic thread title generation."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TitleConfig(BaseModel):
|
||||
"""Configuration for automatic thread title generation."""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether to enable automatic title generation",
|
||||
)
|
||||
max_words: int = Field(
|
||||
default=6,
|
||||
ge=1,
|
||||
le=20,
|
||||
description="Maximum number of words in the generated title",
|
||||
)
|
||||
max_chars: int = Field(
|
||||
default=60,
|
||||
ge=10,
|
||||
le=200,
|
||||
description="Maximum number of characters in the generated title",
|
||||
)
|
||||
model_name: str | None = Field(
|
||||
default=None,
|
||||
description="Model name to use for title generation (None = use default model)",
|
||||
)
|
||||
prompt_template: str = Field(
|
||||
default=("Generate a concise title (max {max_words} words) for this conversation.\nUser: {user_msg}\nAssistant: {assistant_msg}\n\nReturn ONLY the title, no quotes, no explanation."),
|
||||
description="Prompt template for title generation",
|
||||
)
|
||||
|
||||
|
||||
# Global configuration instance
|
||||
_title_config: TitleConfig = TitleConfig()
|
||||
|
||||
|
||||
def get_title_config() -> TitleConfig:
|
||||
"""Get the current title configuration."""
|
||||
return _title_config
|
||||
|
||||
|
||||
def set_title_config(config: TitleConfig) -> None:
|
||||
"""Set the title configuration."""
|
||||
global _title_config
|
||||
_title_config = config
|
||||
|
||||
|
||||
def load_title_config_from_dict(config_dict: dict) -> None:
|
||||
"""Load title configuration from a dictionary."""
|
||||
global _title_config
|
||||
_title_config = TitleConfig(**config_dict)
|
||||
@@ -0,0 +1,7 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TokenUsageConfig(BaseModel):
|
||||
"""Configuration for token usage tracking."""
|
||||
|
||||
enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
|
||||
@@ -0,0 +1,20 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ToolGroupConfig(BaseModel):
|
||||
"""Config section for a tool group"""
|
||||
|
||||
name: str = Field(..., description="Unique name for the tool group")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
"""Config section for a tool"""
|
||||
|
||||
name: str = Field(..., description="Unique name for the tool")
|
||||
group: str = Field(..., description="Group name for the tool")
|
||||
use: str = Field(
|
||||
...,
|
||||
description="Variable name of the tool provider(e.g. deerflow.sandbox.tools:bash_tool)",
|
||||
)
|
||||
model_config = ConfigDict(extra="allow")
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Configuration for deferred tool loading via tool_search."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ToolSearchConfig(BaseModel):
|
||||
"""Configuration for deferred tool loading via tool_search.
|
||||
|
||||
When enabled, MCP tools are not loaded into the agent's context directly.
|
||||
Instead, they are listed by name in the system prompt and discoverable
|
||||
via the tool_search tool at runtime.
|
||||
"""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=False,
|
||||
description="Defer tools and enable tool_search",
|
||||
)
|
||||
|
||||
|
||||
_tool_search_config: ToolSearchConfig | None = None
|
||||
|
||||
|
||||
def get_tool_search_config() -> ToolSearchConfig:
|
||||
"""Get the tool search config, loading from AppConfig if needed."""
|
||||
global _tool_search_config
|
||||
if _tool_search_config is None:
|
||||
_tool_search_config = ToolSearchConfig()
|
||||
return _tool_search_config
|
||||
|
||||
|
||||
def load_tool_search_config_from_dict(data: dict) -> ToolSearchConfig:
|
||||
"""Load tool search config from a dict (called during AppConfig loading)."""
|
||||
global _tool_search_config
|
||||
_tool_search_config = ToolSearchConfig.model_validate(data)
|
||||
return _tool_search_config
|
||||
@@ -0,0 +1,149 @@
|
||||
import os
|
||||
import threading
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
_config_lock = threading.Lock()
|
||||
|
||||
|
||||
class LangSmithTracingConfig(BaseModel):
|
||||
"""Configuration for LangSmith tracing."""
|
||||
|
||||
enabled: bool = Field(...)
|
||||
api_key: str | None = Field(...)
|
||||
project: str = Field(...)
|
||||
endpoint: str = Field(...)
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return self.enabled and bool(self.api_key)
|
||||
|
||||
def validate(self) -> None:
|
||||
if self.enabled and not self.api_key:
|
||||
raise ValueError("LangSmith tracing is enabled but LANGSMITH_API_KEY (or LANGCHAIN_API_KEY) is not set.")
|
||||
|
||||
|
||||
class LangfuseTracingConfig(BaseModel):
|
||||
"""Configuration for Langfuse tracing."""
|
||||
|
||||
enabled: bool = Field(...)
|
||||
public_key: str | None = Field(...)
|
||||
secret_key: str | None = Field(...)
|
||||
host: str = Field(...)
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return self.enabled and bool(self.public_key) and bool(self.secret_key)
|
||||
|
||||
def validate(self) -> None:
|
||||
if not self.enabled:
|
||||
return
|
||||
missing: list[str] = []
|
||||
if not self.public_key:
|
||||
missing.append("LANGFUSE_PUBLIC_KEY")
|
||||
if not self.secret_key:
|
||||
missing.append("LANGFUSE_SECRET_KEY")
|
||||
if missing:
|
||||
raise ValueError(f"Langfuse tracing is enabled but required settings are missing: {', '.join(missing)}")
|
||||
|
||||
|
||||
class TracingConfig(BaseModel):
|
||||
"""Tracing configuration for supported providers."""
|
||||
|
||||
langsmith: LangSmithTracingConfig = Field(...)
|
||||
langfuse: LangfuseTracingConfig = Field(...)
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return bool(self.enabled_providers)
|
||||
|
||||
@property
|
||||
def explicitly_enabled_providers(self) -> list[str]:
|
||||
enabled: list[str] = []
|
||||
if self.langsmith.enabled:
|
||||
enabled.append("langsmith")
|
||||
if self.langfuse.enabled:
|
||||
enabled.append("langfuse")
|
||||
return enabled
|
||||
|
||||
@property
|
||||
def enabled_providers(self) -> list[str]:
|
||||
enabled: list[str] = []
|
||||
if self.langsmith.is_configured:
|
||||
enabled.append("langsmith")
|
||||
if self.langfuse.is_configured:
|
||||
enabled.append("langfuse")
|
||||
return enabled
|
||||
|
||||
def validate_enabled(self) -> None:
|
||||
self.langsmith.validate()
|
||||
self.langfuse.validate()
|
||||
|
||||
|
||||
_tracing_config: TracingConfig | None = None
|
||||
|
||||
|
||||
_TRUTHY_VALUES = {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _env_flag_preferred(*names: str) -> bool:
|
||||
"""Return the boolean value of the first env var that is present and non-empty."""
|
||||
for name in names:
|
||||
value = os.environ.get(name)
|
||||
if value is not None and value.strip():
|
||||
return value.strip().lower() in _TRUTHY_VALUES
|
||||
return False
|
||||
|
||||
|
||||
def _first_env_value(*names: str) -> str | None:
|
||||
"""Return the first non-empty environment value from candidate names."""
|
||||
for name in names:
|
||||
value = os.environ.get(name)
|
||||
if value and value.strip():
|
||||
return value.strip()
|
||||
return None
|
||||
|
||||
|
||||
def get_tracing_config() -> TracingConfig:
|
||||
"""Get the current tracing configuration from environment variables."""
|
||||
global _tracing_config
|
||||
if _tracing_config is not None:
|
||||
return _tracing_config
|
||||
with _config_lock:
|
||||
if _tracing_config is not None:
|
||||
return _tracing_config
|
||||
_tracing_config = TracingConfig(
|
||||
langsmith=LangSmithTracingConfig(
|
||||
enabled=_env_flag_preferred("LANGSMITH_TRACING", "LANGCHAIN_TRACING_V2", "LANGCHAIN_TRACING"),
|
||||
api_key=_first_env_value("LANGSMITH_API_KEY", "LANGCHAIN_API_KEY"),
|
||||
project=_first_env_value("LANGSMITH_PROJECT", "LANGCHAIN_PROJECT") or "deer-flow",
|
||||
endpoint=_first_env_value("LANGSMITH_ENDPOINT", "LANGCHAIN_ENDPOINT") or "https://api.smith.langchain.com",
|
||||
),
|
||||
langfuse=LangfuseTracingConfig(
|
||||
enabled=_env_flag_preferred("LANGFUSE_TRACING"),
|
||||
public_key=_first_env_value("LANGFUSE_PUBLIC_KEY"),
|
||||
secret_key=_first_env_value("LANGFUSE_SECRET_KEY"),
|
||||
host=_first_env_value("LANGFUSE_BASE_URL") or "https://cloud.langfuse.com",
|
||||
),
|
||||
)
|
||||
return _tracing_config
|
||||
|
||||
|
||||
def get_enabled_tracing_providers() -> list[str]:
|
||||
"""Return the configured tracing providers that are enabled and complete."""
|
||||
return get_tracing_config().enabled_providers
|
||||
|
||||
|
||||
def get_explicitly_enabled_tracing_providers() -> list[str]:
|
||||
"""Return tracing providers explicitly enabled by config, even if incomplete."""
|
||||
return get_tracing_config().explicitly_enabled_providers
|
||||
|
||||
|
||||
def validate_enabled_tracing_providers() -> None:
|
||||
"""Validate that any explicitly enabled providers are fully configured."""
|
||||
get_tracing_config().validate_enabled()
|
||||
|
||||
|
||||
def is_tracing_enabled() -> bool:
|
||||
"""Check if any tracing provider is enabled and fully configured."""
|
||||
return get_tracing_config().is_configured
|
||||
@@ -0,0 +1,14 @@
|
||||
"""Pre-tool-call authorization middleware."""
|
||||
|
||||
from deerflow.guardrails.builtin import AllowlistProvider
|
||||
from deerflow.guardrails.middleware import GuardrailMiddleware
|
||||
from deerflow.guardrails.provider import GuardrailDecision, GuardrailProvider, GuardrailReason, GuardrailRequest
|
||||
|
||||
__all__ = [
|
||||
"AllowlistProvider",
|
||||
"GuardrailDecision",
|
||||
"GuardrailMiddleware",
|
||||
"GuardrailProvider",
|
||||
"GuardrailReason",
|
||||
"GuardrailRequest",
|
||||
]
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Built-in guardrail providers that ship with DeerFlow."""
|
||||
|
||||
from deerflow.guardrails.provider import GuardrailDecision, GuardrailReason, GuardrailRequest
|
||||
|
||||
|
||||
class AllowlistProvider:
|
||||
"""Simple allowlist/denylist provider. No external dependencies."""
|
||||
|
||||
name = "allowlist"
|
||||
|
||||
def __init__(self, *, allowed_tools: list[str] | None = None, denied_tools: list[str] | None = None):
|
||||
self._allowed = set(allowed_tools) if allowed_tools else None
|
||||
self._denied = set(denied_tools) if denied_tools else set()
|
||||
|
||||
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||
if self._allowed is not None and request.tool_name not in self._allowed:
|
||||
return GuardrailDecision(allow=False, reasons=[GuardrailReason(code="oap.tool_not_allowed", message=f"tool '{request.tool_name}' not in allowlist")])
|
||||
if request.tool_name in self._denied:
|
||||
return GuardrailDecision(allow=False, reasons=[GuardrailReason(code="oap.tool_not_allowed", message=f"tool '{request.tool_name}' is denied")])
|
||||
return GuardrailDecision(allow=True, reasons=[GuardrailReason(code="oap.allowed")])
|
||||
|
||||
async def aevaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||
return self.evaluate(request)
|
||||
@@ -0,0 +1,98 @@
|
||||
"""GuardrailMiddleware - evaluates tool calls against a GuardrailProvider before execution."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.guardrails.provider import GuardrailDecision, GuardrailProvider, GuardrailReason, GuardrailRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GuardrailMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Evaluate tool calls against a GuardrailProvider before execution.
|
||||
|
||||
Denied calls return an error ToolMessage so the agent can adapt.
|
||||
If the provider raises, behavior depends on fail_closed:
|
||||
- True (default): block the call
|
||||
- False: allow it through with a warning
|
||||
"""
|
||||
|
||||
def __init__(self, provider: GuardrailProvider, *, fail_closed: bool = True, passport: str | None = None):
|
||||
self.provider = provider
|
||||
self.fail_closed = fail_closed
|
||||
self.passport = passport
|
||||
|
||||
def _build_request(self, request: ToolCallRequest) -> GuardrailRequest:
|
||||
return GuardrailRequest(
|
||||
tool_name=str(request.tool_call.get("name", "")),
|
||||
tool_input=request.tool_call.get("args", {}),
|
||||
agent_id=self.passport,
|
||||
timestamp=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
|
||||
def _build_denied_message(self, request: ToolCallRequest, decision: GuardrailDecision) -> ToolMessage:
|
||||
tool_name = str(request.tool_call.get("name", "unknown_tool"))
|
||||
tool_call_id = str(request.tool_call.get("id", "missing_id"))
|
||||
reason_text = decision.reasons[0].message if decision.reasons else "blocked by guardrail policy"
|
||||
reason_code = decision.reasons[0].code if decision.reasons else "oap.denied"
|
||||
return ToolMessage(
|
||||
content=f"Guardrail denied: tool '{tool_name}' was blocked ({reason_code}). Reason: {reason_text}. Choose an alternative approach.",
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
|
||||
@override
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
gr = self._build_request(request)
|
||||
try:
|
||||
decision = self.provider.evaluate(gr)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Guardrail provider error (sync)")
|
||||
if self.fail_closed:
|
||||
decision = GuardrailDecision(allow=False, reasons=[GuardrailReason(code="oap.evaluator_error", message="guardrail provider error (fail-closed)")])
|
||||
else:
|
||||
return handler(request)
|
||||
if not decision.allow:
|
||||
logger.warning("Guardrail denied: tool=%s policy=%s code=%s", gr.tool_name, decision.policy_id, decision.reasons[0].code if decision.reasons else "unknown")
|
||||
return self._build_denied_message(request, decision)
|
||||
return handler(request)
|
||||
|
||||
@override
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
gr = self._build_request(request)
|
||||
try:
|
||||
decision = await self.provider.aevaluate(gr)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Guardrail provider error (async)")
|
||||
if self.fail_closed:
|
||||
decision = GuardrailDecision(allow=False, reasons=[GuardrailReason(code="oap.evaluator_error", message="guardrail provider error (fail-closed)")])
|
||||
else:
|
||||
return await handler(request)
|
||||
if not decision.allow:
|
||||
logger.warning("Guardrail denied: tool=%s policy=%s code=%s", gr.tool_name, decision.policy_id, decision.reasons[0].code if decision.reasons else "unknown")
|
||||
return self._build_denied_message(request, decision)
|
||||
return await handler(request)
|
||||
@@ -0,0 +1,56 @@
|
||||
"""GuardrailProvider protocol and data structures for pre-tool-call authorization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@dataclass
|
||||
class GuardrailRequest:
|
||||
"""Context passed to the provider for each tool call."""
|
||||
|
||||
tool_name: str
|
||||
tool_input: dict[str, Any]
|
||||
agent_id: str | None = None
|
||||
thread_id: str | None = None
|
||||
is_subagent: bool = False
|
||||
timestamp: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class GuardrailReason:
|
||||
"""Structured reason for an allow/deny decision (OAP reason object)."""
|
||||
|
||||
code: str
|
||||
message: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class GuardrailDecision:
|
||||
"""Provider's allow/deny verdict (aligned with OAP Decision object)."""
|
||||
|
||||
allow: bool
|
||||
reasons: list[GuardrailReason] = field(default_factory=list)
|
||||
policy_id: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class GuardrailProvider(Protocol):
|
||||
"""Contract for pluggable tool-call authorization.
|
||||
|
||||
Any class with these methods works - no base class required.
|
||||
Providers are loaded by class path via resolve_variable(),
|
||||
the same mechanism DeerFlow uses for models, tools, and sandbox.
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
||||
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||
"""Evaluate whether a tool call should proceed."""
|
||||
...
|
||||
|
||||
async def aevaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||
"""Async variant."""
|
||||
...
|
||||
14
deer-flow/backend/packages/harness/deerflow/mcp/__init__.py
Normal file
14
deer-flow/backend/packages/harness/deerflow/mcp/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""MCP (Model Context Protocol) integration using langchain-mcp-adapters."""
|
||||
|
||||
from .cache import get_cached_mcp_tools, initialize_mcp_tools, reset_mcp_tools_cache
|
||||
from .client import build_server_params, build_servers_config
|
||||
from .tools import get_mcp_tools
|
||||
|
||||
__all__ = [
|
||||
"build_server_params",
|
||||
"build_servers_config",
|
||||
"get_mcp_tools",
|
||||
"initialize_mcp_tools",
|
||||
"get_cached_mcp_tools",
|
||||
"reset_mcp_tools_cache",
|
||||
]
|
||||
138
deer-flow/backend/packages/harness/deerflow/mcp/cache.py
Normal file
138
deer-flow/backend/packages/harness/deerflow/mcp/cache.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Cache for MCP tools to avoid repeated loading."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_mcp_tools_cache: list[BaseTool] | None = None
|
||||
_cache_initialized = False
|
||||
_initialization_lock = asyncio.Lock()
|
||||
_config_mtime: float | None = None # Track config file modification time
|
||||
|
||||
|
||||
def _get_config_mtime() -> float | None:
|
||||
"""Get the modification time of the extensions config file.
|
||||
|
||||
Returns:
|
||||
The modification time as a float, or None if the file doesn't exist.
|
||||
"""
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
|
||||
config_path = ExtensionsConfig.resolve_config_path()
|
||||
if config_path and config_path.exists():
|
||||
return os.path.getmtime(config_path)
|
||||
return None
|
||||
|
||||
|
||||
def _is_cache_stale() -> bool:
|
||||
"""Check if the cache is stale due to config file changes.
|
||||
|
||||
Returns:
|
||||
True if the cache should be invalidated, False otherwise.
|
||||
"""
|
||||
global _config_mtime
|
||||
|
||||
if not _cache_initialized:
|
||||
return False # Not initialized yet, not stale
|
||||
|
||||
current_mtime = _get_config_mtime()
|
||||
|
||||
# If we couldn't get mtime before or now, assume not stale
|
||||
if _config_mtime is None or current_mtime is None:
|
||||
return False
|
||||
|
||||
# If the config file has been modified since we cached, it's stale
|
||||
if current_mtime > _config_mtime:
|
||||
logger.info(f"MCP config file has been modified (mtime: {_config_mtime} -> {current_mtime}), cache is stale")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def initialize_mcp_tools() -> list[BaseTool]:
|
||||
"""Initialize and cache MCP tools.
|
||||
|
||||
This should be called once at application startup.
|
||||
|
||||
Returns:
|
||||
List of LangChain tools from all enabled MCP servers.
|
||||
"""
|
||||
global _mcp_tools_cache, _cache_initialized, _config_mtime
|
||||
|
||||
async with _initialization_lock:
|
||||
if _cache_initialized:
|
||||
logger.info("MCP tools already initialized")
|
||||
return _mcp_tools_cache or []
|
||||
|
||||
from deerflow.mcp.tools import get_mcp_tools
|
||||
|
||||
logger.info("Initializing MCP tools...")
|
||||
_mcp_tools_cache = await get_mcp_tools()
|
||||
_cache_initialized = True
|
||||
_config_mtime = _get_config_mtime() # Record config file mtime
|
||||
logger.info(f"MCP tools initialized: {len(_mcp_tools_cache)} tool(s) loaded (config mtime: {_config_mtime})")
|
||||
|
||||
return _mcp_tools_cache
|
||||
|
||||
|
||||
def get_cached_mcp_tools() -> list[BaseTool]:
|
||||
"""Get cached MCP tools with lazy initialization.
|
||||
|
||||
If tools are not initialized, automatically initializes them.
|
||||
This ensures MCP tools work in both FastAPI and LangGraph Studio contexts.
|
||||
|
||||
Also checks if the config file has been modified since last initialization,
|
||||
and re-initializes if needed. This ensures that changes made through the
|
||||
Gateway API (which runs in a separate process) are reflected in the
|
||||
LangGraph Server.
|
||||
|
||||
Returns:
|
||||
List of cached MCP tools.
|
||||
"""
|
||||
global _cache_initialized
|
||||
|
||||
# Check if cache is stale due to config file changes
|
||||
if _is_cache_stale():
|
||||
logger.info("MCP cache is stale, resetting for re-initialization...")
|
||||
reset_mcp_tools_cache()
|
||||
|
||||
if not _cache_initialized:
|
||||
logger.info("MCP tools not initialized, performing lazy initialization...")
|
||||
try:
|
||||
# Try to initialize in the current event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# If loop is already running (e.g., in LangGraph Studio),
|
||||
# we need to create a new loop in a thread
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(asyncio.run, initialize_mcp_tools())
|
||||
future.result()
|
||||
else:
|
||||
# If no loop is running, we can use the current loop
|
||||
loop.run_until_complete(initialize_mcp_tools())
|
||||
except RuntimeError:
|
||||
# No event loop exists, create one
|
||||
asyncio.run(initialize_mcp_tools())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to lazy-initialize MCP tools: {e}")
|
||||
return []
|
||||
|
||||
return _mcp_tools_cache or []
|
||||
|
||||
|
||||
def reset_mcp_tools_cache() -> None:
|
||||
"""Reset the MCP tools cache.
|
||||
|
||||
This is useful for testing or when you want to reload MCP tools.
|
||||
"""
|
||||
global _mcp_tools_cache, _cache_initialized, _config_mtime
|
||||
_mcp_tools_cache = None
|
||||
_cache_initialized = False
|
||||
_config_mtime = None
|
||||
logger.info("MCP tools cache reset")
|
||||
68
deer-flow/backend/packages/harness/deerflow/mcp/client.py
Normal file
68
deer-flow/backend/packages/harness/deerflow/mcp/client.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""MCP client using langchain-mcp-adapters."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_server_params(server_name: str, config: McpServerConfig) -> dict[str, Any]:
|
||||
"""Build server parameters for MultiServerMCPClient.
|
||||
|
||||
Args:
|
||||
server_name: Name of the MCP server.
|
||||
config: Configuration for the MCP server.
|
||||
|
||||
Returns:
|
||||
Dictionary of server parameters for langchain-mcp-adapters.
|
||||
"""
|
||||
transport_type = config.type or "stdio"
|
||||
params: dict[str, Any] = {"transport": transport_type}
|
||||
|
||||
if transport_type == "stdio":
|
||||
if not config.command:
|
||||
raise ValueError(f"MCP server '{server_name}' with stdio transport requires 'command' field")
|
||||
params["command"] = config.command
|
||||
params["args"] = config.args
|
||||
# Add environment variables if present
|
||||
if config.env:
|
||||
params["env"] = config.env
|
||||
elif transport_type in ("sse", "http"):
|
||||
if not config.url:
|
||||
raise ValueError(f"MCP server '{server_name}' with {transport_type} transport requires 'url' field")
|
||||
params["url"] = config.url
|
||||
# Add headers if present
|
||||
if config.headers:
|
||||
params["headers"] = config.headers
|
||||
else:
|
||||
raise ValueError(f"MCP server '{server_name}' has unsupported transport type: {transport_type}")
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def build_servers_config(extensions_config: ExtensionsConfig) -> dict[str, dict[str, Any]]:
|
||||
"""Build servers configuration for MultiServerMCPClient.
|
||||
|
||||
Args:
|
||||
extensions_config: Extensions configuration containing all MCP servers.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping server names to their parameters.
|
||||
"""
|
||||
enabled_servers = extensions_config.get_enabled_mcp_servers()
|
||||
|
||||
if not enabled_servers:
|
||||
logger.info("No enabled MCP servers found")
|
||||
return {}
|
||||
|
||||
servers_config = {}
|
||||
for server_name, server_config in enabled_servers.items():
|
||||
try:
|
||||
servers_config[server_name] = build_server_params(server_name, server_config)
|
||||
logger.info(f"Configured MCP server: {server_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure MCP server '{server_name}': {e}")
|
||||
|
||||
return servers_config
|
||||
150
deer-flow/backend/packages/harness/deerflow/mcp/oauth.py
Normal file
150
deer-flow/backend/packages/harness/deerflow/mcp/oauth.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""OAuth token support for MCP HTTP/SSE servers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpOAuthConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OAuthToken:
|
||||
"""Cached OAuth token."""
|
||||
|
||||
access_token: str
|
||||
token_type: str
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
class OAuthTokenManager:
|
||||
"""Acquire/cache/refresh OAuth tokens for MCP servers."""
|
||||
|
||||
def __init__(self, oauth_by_server: dict[str, McpOAuthConfig]):
|
||||
self._oauth_by_server = oauth_by_server
|
||||
self._tokens: dict[str, _OAuthToken] = {}
|
||||
self._locks: dict[str, asyncio.Lock] = {name: asyncio.Lock() for name in oauth_by_server}
|
||||
|
||||
@classmethod
|
||||
def from_extensions_config(cls, extensions_config: ExtensionsConfig) -> OAuthTokenManager:
|
||||
oauth_by_server: dict[str, McpOAuthConfig] = {}
|
||||
for server_name, server_config in extensions_config.get_enabled_mcp_servers().items():
|
||||
if server_config.oauth and server_config.oauth.enabled:
|
||||
oauth_by_server[server_name] = server_config.oauth
|
||||
return cls(oauth_by_server)
|
||||
|
||||
def has_oauth_servers(self) -> bool:
|
||||
return bool(self._oauth_by_server)
|
||||
|
||||
def oauth_server_names(self) -> list[str]:
|
||||
return list(self._oauth_by_server.keys())
|
||||
|
||||
async def get_authorization_header(self, server_name: str) -> str | None:
|
||||
oauth = self._oauth_by_server.get(server_name)
|
||||
if not oauth:
|
||||
return None
|
||||
|
||||
token = self._tokens.get(server_name)
|
||||
if token and not self._is_expiring(token, oauth):
|
||||
return f"{token.token_type} {token.access_token}"
|
||||
|
||||
lock = self._locks[server_name]
|
||||
async with lock:
|
||||
token = self._tokens.get(server_name)
|
||||
if token and not self._is_expiring(token, oauth):
|
||||
return f"{token.token_type} {token.access_token}"
|
||||
|
||||
fresh = await self._fetch_token(oauth)
|
||||
self._tokens[server_name] = fresh
|
||||
logger.info(f"Refreshed OAuth access token for MCP server: {server_name}")
|
||||
return f"{fresh.token_type} {fresh.access_token}"
|
||||
|
||||
@staticmethod
|
||||
def _is_expiring(token: _OAuthToken, oauth: McpOAuthConfig) -> bool:
|
||||
now = datetime.now(UTC)
|
||||
return token.expires_at <= now + timedelta(seconds=max(oauth.refresh_skew_seconds, 0))
|
||||
|
||||
async def _fetch_token(self, oauth: McpOAuthConfig) -> _OAuthToken:
|
||||
import httpx # pyright: ignore[reportMissingImports]
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": oauth.grant_type,
|
||||
**oauth.extra_token_params,
|
||||
}
|
||||
|
||||
if oauth.scope:
|
||||
data["scope"] = oauth.scope
|
||||
if oauth.audience:
|
||||
data["audience"] = oauth.audience
|
||||
|
||||
if oauth.grant_type == "client_credentials":
|
||||
if not oauth.client_id or not oauth.client_secret:
|
||||
raise ValueError("OAuth client_credentials requires client_id and client_secret")
|
||||
data["client_id"] = oauth.client_id
|
||||
data["client_secret"] = oauth.client_secret
|
||||
elif oauth.grant_type == "refresh_token":
|
||||
if not oauth.refresh_token:
|
||||
raise ValueError("OAuth refresh_token grant requires refresh_token")
|
||||
data["refresh_token"] = oauth.refresh_token
|
||||
if oauth.client_id:
|
||||
data["client_id"] = oauth.client_id
|
||||
if oauth.client_secret:
|
||||
data["client_secret"] = oauth.client_secret
|
||||
else:
|
||||
raise ValueError(f"Unsupported OAuth grant type: {oauth.grant_type}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
response = await client.post(oauth.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
access_token = payload.get(oauth.token_field)
|
||||
if not access_token:
|
||||
raise ValueError(f"OAuth token response missing '{oauth.token_field}'")
|
||||
|
||||
token_type = str(payload.get(oauth.token_type_field, oauth.default_token_type) or oauth.default_token_type)
|
||||
|
||||
expires_in_raw = payload.get(oauth.expires_in_field, 3600)
|
||||
try:
|
||||
expires_in = int(expires_in_raw)
|
||||
except (TypeError, ValueError):
|
||||
expires_in = 3600
|
||||
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=max(expires_in, 1))
|
||||
return _OAuthToken(access_token=access_token, token_type=token_type, expires_at=expires_at)
|
||||
|
||||
|
||||
def build_oauth_tool_interceptor(extensions_config: ExtensionsConfig) -> Any | None:
|
||||
"""Build a tool interceptor that injects OAuth Authorization headers."""
|
||||
token_manager = OAuthTokenManager.from_extensions_config(extensions_config)
|
||||
if not token_manager.has_oauth_servers():
|
||||
return None
|
||||
|
||||
async def oauth_interceptor(request: Any, handler: Any) -> Any:
|
||||
header = await token_manager.get_authorization_header(request.server_name)
|
||||
if not header:
|
||||
return await handler(request)
|
||||
|
||||
updated_headers = dict(request.headers or {})
|
||||
updated_headers["Authorization"] = header
|
||||
return await handler(request.override(headers=updated_headers))
|
||||
|
||||
return oauth_interceptor
|
||||
|
||||
|
||||
async def get_initial_oauth_headers(extensions_config: ExtensionsConfig) -> dict[str, str]:
|
||||
"""Get initial OAuth Authorization headers for MCP server connections."""
|
||||
token_manager = OAuthTokenManager.from_extensions_config(extensions_config)
|
||||
if not token_manager.has_oauth_servers():
|
||||
return {}
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
for server_name in token_manager.oauth_server_names():
|
||||
headers[server_name] = await token_manager.get_authorization_header(server_name) or ""
|
||||
|
||||
return {name: value for name, value in headers.items() if value}
|
||||
113
deer-flow/backend/packages/harness/deerflow/mcp/tools.py
Normal file
113
deer-flow/backend/packages/harness/deerflow/mcp/tools.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Load MCP tools using langchain-mcp-adapters."""
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import concurrent.futures
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.mcp.client import build_servers_config
|
||||
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global thread pool for sync tool invocation in async environments
|
||||
_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="mcp-sync-tool")
|
||||
|
||||
# Register shutdown hook for the global executor
|
||||
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
|
||||
|
||||
|
||||
def _make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
|
||||
"""Build a synchronous wrapper for an asynchronous tool coroutine.
|
||||
|
||||
Args:
|
||||
coro: The tool's asynchronous coroutine.
|
||||
tool_name: Name of the tool (for logging).
|
||||
|
||||
Returns:
|
||||
A synchronous function that correctly handles nested event loops.
|
||||
"""
|
||||
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
try:
|
||||
if loop is not None and loop.is_running():
|
||||
# Use global executor to avoid nested loop issues and improve performance
|
||||
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
|
||||
return future.result()
|
||||
else:
|
||||
return asyncio.run(coro(*args, **kwargs))
|
||||
except Exception as e:
|
||||
logger.error(f"Error invoking MCP tool '{tool_name}' via sync wrapper: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
async def get_mcp_tools() -> list[BaseTool]:
|
||||
"""Get all tools from enabled MCP servers.
|
||||
|
||||
Returns:
|
||||
List of LangChain tools from all enabled MCP servers.
|
||||
"""
|
||||
try:
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
except ImportError:
|
||||
logger.warning("langchain-mcp-adapters not installed. Install it to enable MCP tools: pip install langchain-mcp-adapters")
|
||||
return []
|
||||
|
||||
# NOTE: We use ExtensionsConfig.from_file() instead of get_extensions_config()
|
||||
# to always read the latest configuration from disk. This ensures that changes
|
||||
# made through the Gateway API (which runs in a separate process) are immediately
|
||||
# reflected when initializing MCP tools.
|
||||
extensions_config = ExtensionsConfig.from_file()
|
||||
servers_config = build_servers_config(extensions_config)
|
||||
|
||||
if not servers_config:
|
||||
logger.info("No enabled MCP servers configured")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Create the multi-server MCP client
|
||||
logger.info(f"Initializing MCP client with {len(servers_config)} server(s)")
|
||||
|
||||
# Inject initial OAuth headers for server connections (tool discovery/session init)
|
||||
initial_oauth_headers = await get_initial_oauth_headers(extensions_config)
|
||||
for server_name, auth_header in initial_oauth_headers.items():
|
||||
if server_name not in servers_config:
|
||||
continue
|
||||
if servers_config[server_name].get("transport") in ("sse", "http"):
|
||||
existing_headers = dict(servers_config[server_name].get("headers", {}))
|
||||
existing_headers["Authorization"] = auth_header
|
||||
servers_config[server_name]["headers"] = existing_headers
|
||||
|
||||
tool_interceptors = []
|
||||
oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
|
||||
if oauth_interceptor is not None:
|
||||
tool_interceptors.append(oauth_interceptor)
|
||||
|
||||
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True)
|
||||
|
||||
# Get all tools from all servers
|
||||
tools = await client.get_tools()
|
||||
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
|
||||
|
||||
# Patch tools to support sync invocation, as deerflow client streams synchronously
|
||||
for tool in tools:
|
||||
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
|
||||
tool.func = _make_sync_tool_wrapper(tool.coroutine, tool.name)
|
||||
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load MCP tools: {e}", exc_info=True)
|
||||
return []
|
||||
@@ -0,0 +1,3 @@
|
||||
from .factory import create_chat_model
|
||||
|
||||
__all__ = ["create_chat_model"]
|
||||
@@ -0,0 +1,348 @@
|
||||
"""Custom Claude provider with OAuth Bearer auth, prompt caching, and smart thinking.
|
||||
|
||||
Supports two authentication modes:
|
||||
1. Standard API key (x-api-key header) — default ChatAnthropic behavior
|
||||
2. Claude Code OAuth token (Authorization: Bearer header)
|
||||
- Detected by sk-ant-oat prefix
|
||||
- Requires anthropic-beta: oauth-2025-04-20,claude-code-20250219
|
||||
- Requires billing header in system prompt for all OAuth requests
|
||||
|
||||
Auto-loads credentials from explicit runtime handoff:
|
||||
- $ANTHROPIC_API_KEY environment variable
|
||||
- $CLAUDE_CODE_OAUTH_TOKEN or $ANTHROPIC_AUTH_TOKEN
|
||||
- $CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR
|
||||
- $CLAUDE_CODE_CREDENTIALS_PATH
|
||||
- ~/.claude/.credentials.json
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import anthropic
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_RETRIES = 3
|
||||
THINKING_BUDGET_RATIO = 0.8
|
||||
|
||||
# Billing header required by Anthropic API for OAuth token access.
|
||||
# Must be the first system prompt block. Format mirrors Claude Code CLI.
|
||||
# Override with ANTHROPIC_BILLING_HEADER env var if the hardcoded version drifts.
|
||||
_DEFAULT_BILLING_HEADER = "x-anthropic-billing-header: cc_version=2.1.85.351; cc_entrypoint=cli; cch=6c6d5;"
|
||||
OAUTH_BILLING_HEADER = os.environ.get("ANTHROPIC_BILLING_HEADER", _DEFAULT_BILLING_HEADER)
|
||||
|
||||
|
||||
class ClaudeChatModel(ChatAnthropic):
|
||||
"""ChatAnthropic with OAuth Bearer auth, prompt caching, and smart thinking.
|
||||
|
||||
Config example:
|
||||
- name: claude-sonnet-4.6
|
||||
use: deerflow.models.claude_provider:ClaudeChatModel
|
||||
model: claude-sonnet-4-6
|
||||
max_tokens: 16384
|
||||
enable_prompt_caching: true
|
||||
"""
|
||||
|
||||
# Custom fields
|
||||
enable_prompt_caching: bool = True
|
||||
prompt_cache_size: int = 3
|
||||
auto_thinking_budget: bool = True
|
||||
retry_max_attempts: int = MAX_RETRIES
|
||||
_is_oauth: bool = PrivateAttr(default=False)
|
||||
_oauth_access_token: str = PrivateAttr(default="")
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
def _validate_retry_config(self) -> None:
|
||||
if self.retry_max_attempts < 1:
|
||||
raise ValueError("retry_max_attempts must be >= 1")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Auto-load credentials and configure OAuth if needed."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
from deerflow.models.credential_loader import (
|
||||
OAUTH_ANTHROPIC_BETAS,
|
||||
is_oauth_token,
|
||||
load_claude_code_credential,
|
||||
)
|
||||
|
||||
self._validate_retry_config()
|
||||
|
||||
# Extract actual key value (SecretStr.str() returns '**********')
|
||||
current_key = ""
|
||||
if self.anthropic_api_key:
|
||||
if hasattr(self.anthropic_api_key, "get_secret_value"):
|
||||
current_key = self.anthropic_api_key.get_secret_value()
|
||||
else:
|
||||
current_key = str(self.anthropic_api_key)
|
||||
|
||||
# Try the explicit Claude Code OAuth handoff sources if no valid key.
|
||||
if not current_key or current_key in ("your-anthropic-api-key",):
|
||||
cred = load_claude_code_credential()
|
||||
if cred:
|
||||
current_key = cred.access_token
|
||||
logger.info(f"Using Claude Code CLI credential (source: {cred.source})")
|
||||
else:
|
||||
logger.warning("No Anthropic API key or explicit Claude Code OAuth credential found.")
|
||||
|
||||
# Detect OAuth token and configure Bearer auth
|
||||
if is_oauth_token(current_key):
|
||||
self._is_oauth = True
|
||||
self._oauth_access_token = current_key
|
||||
# Set the token as api_key temporarily (will be swapped to auth_token on client)
|
||||
self.anthropic_api_key = SecretStr(current_key)
|
||||
# Add required beta headers for OAuth
|
||||
self.default_headers = {
|
||||
**(self.default_headers or {}),
|
||||
"anthropic-beta": OAUTH_ANTHROPIC_BETAS,
|
||||
}
|
||||
# OAuth tokens have a limit of 4 cache_control blocks — disable prompt caching
|
||||
self.enable_prompt_caching = False
|
||||
logger.info("OAuth token detected — will use Authorization: Bearer header")
|
||||
else:
|
||||
if current_key:
|
||||
self.anthropic_api_key = SecretStr(current_key)
|
||||
|
||||
# Ensure api_key is SecretStr
|
||||
if isinstance(self.anthropic_api_key, str):
|
||||
self.anthropic_api_key = SecretStr(self.anthropic_api_key)
|
||||
|
||||
super().model_post_init(__context)
|
||||
|
||||
# Patch clients immediately after creation for OAuth Bearer auth.
|
||||
# This must happen after super() because clients are lazily created.
|
||||
if self._is_oauth:
|
||||
self._patch_client_oauth(self._client)
|
||||
self._patch_client_oauth(self._async_client)
|
||||
|
||||
def _patch_client_oauth(self, client: Any) -> None:
|
||||
"""Swap api_key → auth_token on an Anthropic SDK client for OAuth Bearer auth."""
|
||||
if hasattr(client, "api_key") and hasattr(client, "auth_token"):
|
||||
client.api_key = None
|
||||
client.auth_token = self._oauth_access_token
|
||||
|
||||
def _get_request_payload(
|
||||
self,
|
||||
input_: Any,
|
||||
*,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""Override to inject prompt caching, thinking budget, and OAuth billing."""
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
|
||||
if self._is_oauth:
|
||||
self._apply_oauth_billing(payload)
|
||||
|
||||
if self.enable_prompt_caching:
|
||||
self._apply_prompt_caching(payload)
|
||||
|
||||
if self.auto_thinking_budget:
|
||||
self._apply_thinking_budget(payload)
|
||||
|
||||
return payload
|
||||
|
||||
def _apply_oauth_billing(self, payload: dict) -> None:
|
||||
"""Inject the billing header block required for all OAuth requests.
|
||||
|
||||
The billing block is always placed first in the system list, removing any
|
||||
existing occurrence to avoid duplication or out-of-order positioning.
|
||||
"""
|
||||
billing_block = {"type": "text", "text": OAUTH_BILLING_HEADER}
|
||||
|
||||
system = payload.get("system")
|
||||
if isinstance(system, list):
|
||||
# Remove any existing billing blocks, then insert a single one at index 0.
|
||||
filtered = [b for b in system if not (isinstance(b, dict) and OAUTH_BILLING_HEADER in b.get("text", ""))]
|
||||
payload["system"] = [billing_block] + filtered
|
||||
elif isinstance(system, str):
|
||||
if OAUTH_BILLING_HEADER in system:
|
||||
payload["system"] = [billing_block]
|
||||
else:
|
||||
payload["system"] = [billing_block, {"type": "text", "text": system}]
|
||||
else:
|
||||
payload["system"] = [billing_block]
|
||||
|
||||
# Add metadata.user_id required by the API for OAuth billing validation
|
||||
if not isinstance(payload.get("metadata"), dict):
|
||||
payload["metadata"] = {}
|
||||
if "user_id" not in payload["metadata"]:
|
||||
# Generate a stable device_id from the machine's hostname
|
||||
hostname = socket.gethostname()
|
||||
device_id = hashlib.sha256(f"deerflow-{hostname}".encode()).hexdigest()
|
||||
session_id = str(uuid.uuid4())
|
||||
payload["metadata"]["user_id"] = json.dumps(
|
||||
{
|
||||
"device_id": device_id,
|
||||
"account_uuid": "deerflow",
|
||||
"session_id": session_id,
|
||||
}
|
||||
)
|
||||
|
||||
def _apply_prompt_caching(self, payload: dict) -> None:
|
||||
"""Apply ephemeral cache_control to system and recent messages."""
|
||||
# Cache system messages
|
||||
system = payload.get("system")
|
||||
if system and isinstance(system, list):
|
||||
for block in system:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
block["cache_control"] = {"type": "ephemeral"}
|
||||
elif system and isinstance(system, str):
|
||||
payload["system"] = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": system,
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
]
|
||||
|
||||
# Cache recent messages
|
||||
messages = payload.get("messages", [])
|
||||
cache_start = max(0, len(messages) - self.prompt_cache_size)
|
||||
for i in range(cache_start, len(messages)):
|
||||
msg = messages[i]
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict):
|
||||
block["cache_control"] = {"type": "ephemeral"}
|
||||
elif isinstance(content, str) and content:
|
||||
msg["content"] = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": content,
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
]
|
||||
|
||||
# Cache the last tool definition
|
||||
tools = payload.get("tools", [])
|
||||
if tools and isinstance(tools[-1], dict):
|
||||
tools[-1]["cache_control"] = {"type": "ephemeral"}
|
||||
|
||||
def _apply_thinking_budget(self, payload: dict) -> None:
|
||||
"""Auto-allocate thinking budget (80% of max_tokens)."""
|
||||
thinking = payload.get("thinking")
|
||||
if not thinking or not isinstance(thinking, dict):
|
||||
return
|
||||
if thinking.get("type") != "enabled":
|
||||
return
|
||||
if thinking.get("budget_tokens"):
|
||||
return
|
||||
|
||||
max_tokens = payload.get("max_tokens", 8192)
|
||||
thinking["budget_tokens"] = int(max_tokens * THINKING_BUDGET_RATIO)
|
||||
|
||||
@staticmethod
|
||||
def _strip_cache_control(payload: dict) -> None:
|
||||
"""Remove cache_control markers before OAuth requests reach Anthropic."""
|
||||
for section in ("system", "messages"):
|
||||
items = payload.get(section)
|
||||
if not isinstance(items, list):
|
||||
continue
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item.pop("cache_control", None)
|
||||
content = item.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict):
|
||||
block.pop("cache_control", None)
|
||||
|
||||
tools = payload.get("tools")
|
||||
if isinstance(tools, list):
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict):
|
||||
tool.pop("cache_control", None)
|
||||
|
||||
def _create(self, payload: dict) -> Any:
|
||||
if self._is_oauth:
|
||||
self._strip_cache_control(payload)
|
||||
return super()._create(payload)
|
||||
|
||||
async def _acreate(self, payload: dict) -> Any:
|
||||
if self._is_oauth:
|
||||
self._strip_cache_control(payload)
|
||||
return await super()._acreate(payload)
|
||||
|
||||
def _generate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
|
||||
"""Override with OAuth patching and retry logic."""
|
||||
if self._is_oauth:
|
||||
self._patch_client_oauth(self._client)
|
||||
|
||||
last_error = None
|
||||
for attempt in range(1, self.retry_max_attempts + 1):
|
||||
try:
|
||||
return super()._generate(messages, stop=stop, **kwargs)
|
||||
except anthropic.RateLimitError as e:
|
||||
last_error = e
|
||||
if attempt >= self.retry_max_attempts:
|
||||
raise
|
||||
wait_ms = self._calc_backoff_ms(attempt, e)
|
||||
logger.warning(f"Rate limited, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
|
||||
time.sleep(wait_ms / 1000)
|
||||
except anthropic.InternalServerError as e:
|
||||
last_error = e
|
||||
if attempt >= self.retry_max_attempts:
|
||||
raise
|
||||
wait_ms = self._calc_backoff_ms(attempt, e)
|
||||
logger.warning(f"Server error, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
|
||||
time.sleep(wait_ms / 1000)
|
||||
raise last_error
|
||||
|
||||
async def _agenerate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
|
||||
"""Async override with OAuth patching and retry logic."""
|
||||
import asyncio
|
||||
|
||||
if self._is_oauth:
|
||||
self._patch_client_oauth(self._async_client)
|
||||
|
||||
last_error = None
|
||||
for attempt in range(1, self.retry_max_attempts + 1):
|
||||
try:
|
||||
return await super()._agenerate(messages, stop=stop, **kwargs)
|
||||
except anthropic.RateLimitError as e:
|
||||
last_error = e
|
||||
if attempt >= self.retry_max_attempts:
|
||||
raise
|
||||
wait_ms = self._calc_backoff_ms(attempt, e)
|
||||
logger.warning(f"Rate limited, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
|
||||
await asyncio.sleep(wait_ms / 1000)
|
||||
except anthropic.InternalServerError as e:
|
||||
last_error = e
|
||||
if attempt >= self.retry_max_attempts:
|
||||
raise
|
||||
wait_ms = self._calc_backoff_ms(attempt, e)
|
||||
logger.warning(f"Server error, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
|
||||
await asyncio.sleep(wait_ms / 1000)
|
||||
raise last_error
|
||||
|
||||
@staticmethod
|
||||
def _calc_backoff_ms(attempt: int, error: Exception) -> int:
|
||||
"""Exponential backoff with a fixed 20% buffer."""
|
||||
backoff_ms = 2000 * (1 << (attempt - 1))
|
||||
jitter_ms = int(backoff_ms * 0.2)
|
||||
total_ms = backoff_ms + jitter_ms
|
||||
|
||||
if hasattr(error, "response") and error.response is not None:
|
||||
retry_after = error.response.headers.get("Retry-After")
|
||||
if retry_after:
|
||||
try:
|
||||
total_ms = int(retry_after) * 1000
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return total_ms
|
||||
@@ -0,0 +1,219 @@
|
||||
"""Auto-load credentials from Claude Code CLI and Codex CLI.
|
||||
|
||||
Implements two credential strategies:
|
||||
1. Claude Code OAuth token from explicit env vars or an exported credentials file
|
||||
- Uses Authorization: Bearer header (NOT x-api-key)
|
||||
- Requires anthropic-beta: oauth-2025-04-20,claude-code-20250219
|
||||
- Supports $CLAUDE_CODE_OAUTH_TOKEN, $CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR, and $ANTHROPIC_AUTH_TOKEN
|
||||
- Override path with $CLAUDE_CODE_CREDENTIALS_PATH
|
||||
2. Codex CLI token from ~/.codex/auth.json
|
||||
- Uses chatgpt.com/backend-api/codex/responses endpoint
|
||||
- Supports both legacy top-level tokens and current nested tokens shape
|
||||
- Override path with $CODEX_AUTH_PATH
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Required beta headers for Claude Code OAuth tokens
|
||||
OAUTH_ANTHROPIC_BETAS = "oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14"
|
||||
|
||||
|
||||
def is_oauth_token(token: str) -> bool:
|
||||
"""Check if a token is a Claude Code OAuth token (not a standard API key)."""
|
||||
return isinstance(token, str) and "sk-ant-oat" in token
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClaudeCodeCredential:
|
||||
"""Claude Code CLI OAuth credential."""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str = ""
|
||||
expires_at: int = 0
|
||||
source: str = ""
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
if self.expires_at <= 0:
|
||||
return False
|
||||
return time.time() * 1000 > self.expires_at - 60_000 # 1 min buffer
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodexCliCredential:
|
||||
"""Codex CLI credential."""
|
||||
|
||||
access_token: str
|
||||
account_id: str = ""
|
||||
source: str = ""
|
||||
|
||||
|
||||
def _resolve_credential_path(env_var: str, default_relative_path: str) -> Path:
|
||||
configured_path = os.getenv(env_var)
|
||||
if configured_path:
|
||||
return Path(configured_path).expanduser()
|
||||
return _home_dir() / default_relative_path
|
||||
|
||||
|
||||
def _home_dir() -> Path:
|
||||
home = os.getenv("HOME")
|
||||
if home:
|
||||
return Path(home).expanduser()
|
||||
return Path.home()
|
||||
|
||||
|
||||
def _load_json_file(path: Path, label: str) -> dict[str, Any] | None:
|
||||
if not path.exists():
|
||||
logger.debug(f"{label} not found: {path}")
|
||||
return None
|
||||
if path.is_dir():
|
||||
logger.warning(f"{label} path is a directory, expected a file: {path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(path.read_text())
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning(f"Failed to read {label}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _read_secret_from_file_descriptor(env_var: str) -> str | None:
|
||||
fd_value = os.getenv(env_var)
|
||||
if not fd_value:
|
||||
return None
|
||||
|
||||
try:
|
||||
fd = int(fd_value)
|
||||
except ValueError:
|
||||
logger.warning(f"{env_var} must be an integer file descriptor, got: {fd_value}")
|
||||
return None
|
||||
|
||||
try:
|
||||
secret = os.read(fd, 1024 * 1024).decode().strip()
|
||||
except OSError as e:
|
||||
logger.warning(f"Failed to read {env_var}: {e}")
|
||||
return None
|
||||
|
||||
return secret or None
|
||||
|
||||
|
||||
def _credential_from_direct_token(access_token: str, source: str) -> ClaudeCodeCredential | None:
|
||||
token = access_token.strip()
|
||||
if not token:
|
||||
return None
|
||||
return ClaudeCodeCredential(access_token=token, source=source)
|
||||
|
||||
|
||||
def _iter_claude_code_credential_paths() -> list[Path]:
|
||||
paths: list[Path] = []
|
||||
override_path = os.getenv("CLAUDE_CODE_CREDENTIALS_PATH")
|
||||
if override_path:
|
||||
paths.append(Path(override_path).expanduser())
|
||||
|
||||
default_path = _home_dir() / ".claude/.credentials.json"
|
||||
if not paths or paths[-1] != default_path:
|
||||
paths.append(default_path)
|
||||
|
||||
return paths
|
||||
|
||||
|
||||
def _extract_claude_code_credential(data: dict[str, Any], source: str) -> ClaudeCodeCredential | None:
|
||||
oauth = data.get("claudeAiOauth", {})
|
||||
access_token = oauth.get("accessToken", "")
|
||||
if not access_token:
|
||||
logger.debug("Claude Code credentials container exists but no accessToken found")
|
||||
return None
|
||||
|
||||
cred = ClaudeCodeCredential(
|
||||
access_token=access_token,
|
||||
refresh_token=oauth.get("refreshToken", ""),
|
||||
expires_at=oauth.get("expiresAt", 0),
|
||||
source=source,
|
||||
)
|
||||
|
||||
if cred.is_expired:
|
||||
logger.warning("Claude Code OAuth token is expired. Run 'claude' to refresh.")
|
||||
return None
|
||||
|
||||
return cred
|
||||
|
||||
|
||||
def load_claude_code_credential() -> ClaudeCodeCredential | None:
|
||||
"""Load OAuth credential from explicit Claude Code handoff sources.
|
||||
|
||||
Lookup order:
|
||||
1. $CLAUDE_CODE_OAUTH_TOKEN or $ANTHROPIC_AUTH_TOKEN
|
||||
2. $CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR
|
||||
3. $CLAUDE_CODE_CREDENTIALS_PATH
|
||||
4. ~/.claude/.credentials.json
|
||||
|
||||
Exported credentials files contain:
|
||||
{
|
||||
"claudeAiOauth": {
|
||||
"accessToken": "sk-ant-oat01-...",
|
||||
"refreshToken": "sk-ant-ort01-...",
|
||||
"expiresAt": 1773430695128,
|
||||
"scopes": ["user:inference", ...],
|
||||
...
|
||||
}
|
||||
}
|
||||
"""
|
||||
direct_token = os.getenv("CLAUDE_CODE_OAUTH_TOKEN") or os.getenv("ANTHROPIC_AUTH_TOKEN")
|
||||
if direct_token:
|
||||
cred = _credential_from_direct_token(direct_token, "claude-cli-env")
|
||||
if cred:
|
||||
logger.info("Loaded Claude Code OAuth credential from environment")
|
||||
return cred
|
||||
|
||||
fd_token = _read_secret_from_file_descriptor("CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR")
|
||||
if fd_token:
|
||||
cred = _credential_from_direct_token(fd_token, "claude-cli-fd")
|
||||
if cred:
|
||||
logger.info("Loaded Claude Code OAuth credential from file descriptor")
|
||||
return cred
|
||||
|
||||
override_path = os.getenv("CLAUDE_CODE_CREDENTIALS_PATH")
|
||||
override_path_obj = Path(override_path).expanduser() if override_path else None
|
||||
for cred_path in _iter_claude_code_credential_paths():
|
||||
data = _load_json_file(cred_path, "Claude Code credentials")
|
||||
if data is None:
|
||||
continue
|
||||
cred = _extract_claude_code_credential(data, "claude-cli-file")
|
||||
if cred:
|
||||
source_label = "override path" if override_path_obj is not None and cred_path == override_path_obj else "plaintext file"
|
||||
logger.info(f"Loaded Claude Code OAuth credential from {source_label} (expires_at={cred.expires_at})")
|
||||
return cred
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def load_codex_cli_credential() -> CodexCliCredential | None:
|
||||
"""Load credential from Codex CLI (~/.codex/auth.json)."""
|
||||
cred_path = _resolve_credential_path("CODEX_AUTH_PATH", ".codex/auth.json")
|
||||
data = _load_json_file(cred_path, "Codex CLI credentials")
|
||||
if data is None:
|
||||
return None
|
||||
tokens = data.get("tokens", {})
|
||||
if not isinstance(tokens, dict):
|
||||
tokens = {}
|
||||
|
||||
access_token = data.get("access_token") or data.get("token") or tokens.get("access_token", "")
|
||||
account_id = data.get("account_id") or tokens.get("account_id", "")
|
||||
if not access_token:
|
||||
logger.debug("Codex CLI credentials file exists but no token found")
|
||||
return None
|
||||
|
||||
logger.info("Loaded Codex CLI credential")
|
||||
return CodexCliCredential(
|
||||
access_token=access_token,
|
||||
account_id=account_id,
|
||||
source="codex-cli",
|
||||
)
|
||||
123
deer-flow/backend/packages/harness/deerflow/models/factory.py
Normal file
123
deer-flow/backend/packages/harness/deerflow/models/factory.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import logging
|
||||
|
||||
from langchain.chat_models import BaseChatModel
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.reflection import resolve_class
|
||||
from deerflow.tracing import build_tracing_callbacks
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _deep_merge_dicts(base: dict | None, override: dict) -> dict:
|
||||
"""Recursively merge two dictionaries without mutating the inputs."""
|
||||
merged = dict(base or {})
|
||||
for key, value in override.items():
|
||||
if isinstance(value, dict) and isinstance(merged.get(key), dict):
|
||||
merged[key] = _deep_merge_dicts(merged[key], value)
|
||||
else:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
|
||||
def _vllm_disable_chat_template_kwargs(chat_template_kwargs: dict) -> dict:
|
||||
"""Build the disable payload for vLLM/Qwen chat template kwargs."""
|
||||
disable_kwargs: dict[str, bool] = {}
|
||||
if "thinking" in chat_template_kwargs:
|
||||
disable_kwargs["thinking"] = False
|
||||
if "enable_thinking" in chat_template_kwargs:
|
||||
disable_kwargs["enable_thinking"] = False
|
||||
return disable_kwargs
|
||||
|
||||
|
||||
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel:
|
||||
"""Create a chat model instance from the config.
|
||||
|
||||
Args:
|
||||
name: The name of the model to create. If None, the first model in the config will be used.
|
||||
|
||||
Returns:
|
||||
A chat model instance.
|
||||
"""
|
||||
config = get_app_config()
|
||||
if name is None:
|
||||
name = config.models[0].name
|
||||
model_config = config.get_model_config(name)
|
||||
if model_config is None:
|
||||
raise ValueError(f"Model {name} not found in config") from None
|
||||
model_class = resolve_class(model_config.use, BaseChatModel)
|
||||
model_settings_from_config = model_config.model_dump(
|
||||
exclude_none=True,
|
||||
exclude={
|
||||
"use",
|
||||
"name",
|
||||
"display_name",
|
||||
"description",
|
||||
"supports_thinking",
|
||||
"supports_reasoning_effort",
|
||||
"when_thinking_enabled",
|
||||
"when_thinking_disabled",
|
||||
"thinking",
|
||||
"supports_vision",
|
||||
},
|
||||
)
|
||||
# Compute effective when_thinking_enabled by merging in the `thinking` shortcut field.
|
||||
# The `thinking` shortcut is equivalent to setting when_thinking_enabled["thinking"].
|
||||
has_thinking_settings = (model_config.when_thinking_enabled is not None) or (model_config.thinking is not None)
|
||||
effective_wte: dict = dict(model_config.when_thinking_enabled) if model_config.when_thinking_enabled else {}
|
||||
if model_config.thinking is not None:
|
||||
merged_thinking = {**(effective_wte.get("thinking") or {}), **model_config.thinking}
|
||||
effective_wte = {**effective_wte, "thinking": merged_thinking}
|
||||
if thinking_enabled and has_thinking_settings:
|
||||
if not model_config.supports_thinking:
|
||||
raise ValueError(f"Model {name} does not support thinking. Set `supports_thinking` to true in the `config.yaml` to enable thinking.") from None
|
||||
if effective_wte:
|
||||
model_settings_from_config.update(effective_wte)
|
||||
if not thinking_enabled:
|
||||
if model_config.when_thinking_disabled is not None:
|
||||
# User-provided disable settings take full precedence
|
||||
model_settings_from_config.update(model_config.when_thinking_disabled)
|
||||
elif has_thinking_settings and effective_wte.get("extra_body", {}).get("thinking", {}).get("type"):
|
||||
# OpenAI-compatible gateway: thinking is nested under extra_body
|
||||
model_settings_from_config["extra_body"] = _deep_merge_dicts(
|
||||
model_settings_from_config.get("extra_body"),
|
||||
{"thinking": {"type": "disabled"}},
|
||||
)
|
||||
model_settings_from_config["reasoning_effort"] = "minimal"
|
||||
elif has_thinking_settings and (disable_chat_template_kwargs := _vllm_disable_chat_template_kwargs(effective_wte.get("extra_body", {}).get("chat_template_kwargs") or {})):
|
||||
# vLLM uses chat template kwargs to switch thinking on/off.
|
||||
model_settings_from_config["extra_body"] = _deep_merge_dicts(
|
||||
model_settings_from_config.get("extra_body"),
|
||||
{"chat_template_kwargs": disable_chat_template_kwargs},
|
||||
)
|
||||
elif has_thinking_settings and effective_wte.get("thinking", {}).get("type"):
|
||||
# Native langchain_anthropic: thinking is a direct constructor parameter
|
||||
model_settings_from_config["thinking"] = {"type": "disabled"}
|
||||
if not model_config.supports_reasoning_effort:
|
||||
kwargs.pop("reasoning_effort", None)
|
||||
model_settings_from_config.pop("reasoning_effort", None)
|
||||
|
||||
# For Codex Responses API models: map thinking mode to reasoning_effort
|
||||
from deerflow.models.openai_codex_provider import CodexChatModel
|
||||
|
||||
if issubclass(model_class, CodexChatModel):
|
||||
# The ChatGPT Codex endpoint currently rejects max_tokens/max_output_tokens.
|
||||
model_settings_from_config.pop("max_tokens", None)
|
||||
|
||||
# Use explicit reasoning_effort from frontend if provided (low/medium/high)
|
||||
explicit_effort = kwargs.pop("reasoning_effort", None)
|
||||
if not thinking_enabled:
|
||||
model_settings_from_config["reasoning_effort"] = "none"
|
||||
elif explicit_effort and explicit_effort in ("low", "medium", "high", "xhigh"):
|
||||
model_settings_from_config["reasoning_effort"] = explicit_effort
|
||||
elif "reasoning_effort" not in model_settings_from_config:
|
||||
model_settings_from_config["reasoning_effort"] = "medium"
|
||||
|
||||
model_instance = model_class(**{**model_settings_from_config, **kwargs})
|
||||
|
||||
callbacks = build_tracing_callbacks()
|
||||
if callbacks:
|
||||
existing_callbacks = model_instance.callbacks or []
|
||||
model_instance.callbacks = [*existing_callbacks, *callbacks]
|
||||
logger.debug(f"Tracing attached to model '{name}' with providers={len(callbacks)}")
|
||||
return model_instance
|
||||
@@ -0,0 +1,430 @@
|
||||
"""Custom OpenAI Codex provider using ChatGPT Codex Responses API.
|
||||
|
||||
Uses Codex CLI OAuth tokens with chatgpt.com/backend-api/codex/responses endpoint.
|
||||
This is the same endpoint that the Codex CLI uses internally.
|
||||
|
||||
Supports:
|
||||
- Auto-load credentials from ~/.codex/auth.json
|
||||
- Responses API format (not Chat Completions)
|
||||
- Tool calling
|
||||
- Streaming (required by the endpoint)
|
||||
- Retry with exponential backoff
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
|
||||
from deerflow.models.credential_loader import CodexCliCredential, load_codex_cli_credential
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
MAX_RETRIES = 3
|
||||
|
||||
|
||||
class CodexChatModel(BaseChatModel):
|
||||
"""LangChain chat model using ChatGPT Codex Responses API.
|
||||
|
||||
Config example:
|
||||
- name: gpt-5.4
|
||||
use: deerflow.models.openai_codex_provider:CodexChatModel
|
||||
model: gpt-5.4
|
||||
reasoning_effort: medium
|
||||
"""
|
||||
|
||||
model: str = "gpt-5.4"
|
||||
reasoning_effort: str = "medium"
|
||||
retry_max_attempts: int = MAX_RETRIES
|
||||
_access_token: str = ""
|
||||
_account_id: str = ""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "codex-responses"
|
||||
|
||||
def _validate_retry_config(self) -> None:
|
||||
if self.retry_max_attempts < 1:
|
||||
raise ValueError("retry_max_attempts must be >= 1")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Auto-load Codex CLI credentials."""
|
||||
self._validate_retry_config()
|
||||
|
||||
cred = self._load_codex_auth()
|
||||
if cred:
|
||||
self._access_token = cred.access_token
|
||||
self._account_id = cred.account_id
|
||||
logger.info(f"Using Codex CLI credential (account: {self._account_id[:8]}...)")
|
||||
else:
|
||||
raise ValueError("Codex CLI credential not found. Expected ~/.codex/auth.json or CODEX_AUTH_PATH.")
|
||||
|
||||
super().model_post_init(__context)
|
||||
|
||||
def _load_codex_auth(self) -> CodexCliCredential | None:
|
||||
"""Load access_token and account_id from Codex CLI auth."""
|
||||
return load_codex_cli_credential()
|
||||
|
||||
@classmethod
|
||||
def _normalize_content(cls, content: Any) -> str:
|
||||
"""Flatten LangChain content blocks into plain text for Codex."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
if isinstance(content, list):
|
||||
parts = [cls._normalize_content(item) for item in content]
|
||||
return "\n".join(part for part in parts if part)
|
||||
|
||||
if isinstance(content, dict):
|
||||
for key in ("text", "output"):
|
||||
value = content.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
nested_content = content.get("content")
|
||||
if nested_content is not None:
|
||||
return cls._normalize_content(nested_content)
|
||||
try:
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(content)
|
||||
|
||||
try:
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(content)
|
||||
|
||||
def _convert_messages(self, messages: list[BaseMessage]) -> tuple[str, list[dict]]:
|
||||
"""Convert LangChain messages to Responses API format.
|
||||
|
||||
Returns (instructions, input_items).
|
||||
"""
|
||||
instructions_parts: list[str] = []
|
||||
input_items = []
|
||||
|
||||
for msg in messages:
|
||||
if isinstance(msg, SystemMessage):
|
||||
content = self._normalize_content(msg.content)
|
||||
if content:
|
||||
instructions_parts.append(content)
|
||||
elif isinstance(msg, HumanMessage):
|
||||
content = self._normalize_content(msg.content)
|
||||
input_items.append({"role": "user", "content": content})
|
||||
elif isinstance(msg, AIMessage):
|
||||
if msg.content:
|
||||
content = self._normalize_content(msg.content)
|
||||
input_items.append({"role": "assistant", "content": content})
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
input_items.append(
|
||||
{
|
||||
"type": "function_call",
|
||||
"name": tc["name"],
|
||||
"arguments": json.dumps(tc["args"]) if isinstance(tc["args"], dict) else tc["args"],
|
||||
"call_id": tc["id"],
|
||||
}
|
||||
)
|
||||
elif isinstance(msg, ToolMessage):
|
||||
input_items.append(
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"call_id": msg.tool_call_id,
|
||||
"output": self._normalize_content(msg.content),
|
||||
}
|
||||
)
|
||||
|
||||
instructions = "\n\n".join(instructions_parts) or "You are a helpful assistant."
|
||||
|
||||
return instructions, input_items
|
||||
|
||||
def _convert_tools(self, tools: list[dict]) -> list[dict]:
|
||||
"""Convert LangChain tool format to Responses API format."""
|
||||
responses_tools = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function" and "function" in tool:
|
||||
fn = tool["function"]
|
||||
responses_tools.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": fn["name"],
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
}
|
||||
)
|
||||
elif "name" in tool:
|
||||
responses_tools.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": tool["name"],
|
||||
"description": tool.get("description", ""),
|
||||
"parameters": tool.get("parameters", {}),
|
||||
}
|
||||
)
|
||||
return responses_tools
|
||||
|
||||
def _call_codex_api(self, messages: list[BaseMessage], tools: list[dict] | None = None) -> dict:
|
||||
"""Call the Codex Responses API and return the completed response."""
|
||||
instructions, input_items = self._convert_messages(messages)
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"instructions": instructions,
|
||||
"input": input_items,
|
||||
"store": False,
|
||||
"stream": True,
|
||||
"reasoning": {"effort": self.reasoning_effort, "summary": "detailed"} if self.reasoning_effort != "none" else {"effort": "none"},
|
||||
}
|
||||
|
||||
if tools:
|
||||
payload["tools"] = self._convert_tools(tools)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._access_token}",
|
||||
"ChatGPT-Account-ID": self._account_id,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "text/event-stream",
|
||||
"originator": "codex_cli_rs",
|
||||
}
|
||||
|
||||
last_error = None
|
||||
for attempt in range(1, self.retry_max_attempts + 1):
|
||||
try:
|
||||
return self._stream_response(headers, payload)
|
||||
except httpx.HTTPStatusError as e:
|
||||
last_error = e
|
||||
if e.response.status_code in (429, 500, 529):
|
||||
if attempt >= self.retry_max_attempts:
|
||||
raise
|
||||
wait_ms = 2000 * (1 << (attempt - 1))
|
||||
logger.warning(f"Codex API error {e.response.status_code}, retrying {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
|
||||
time.sleep(wait_ms / 1000)
|
||||
else:
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
raise last_error
|
||||
|
||||
def _stream_response(self, headers: dict, payload: dict) -> dict:
|
||||
"""Stream SSE from Codex API and collect the final response."""
|
||||
completed_response = None
|
||||
streamed_output_items: dict[int, dict[str, Any]] = {}
|
||||
|
||||
with httpx.Client(timeout=300) as client:
|
||||
with client.stream("POST", f"{CODEX_BASE_URL}/responses", headers=headers, json=payload) as resp:
|
||||
resp.raise_for_status()
|
||||
for line in resp.iter_lines():
|
||||
data = self._parse_sse_data_line(line)
|
||||
if not data:
|
||||
continue
|
||||
|
||||
event_type = data.get("type")
|
||||
if event_type == "response.output_item.done":
|
||||
output_index = data.get("output_index")
|
||||
output_item = data.get("item")
|
||||
if isinstance(output_index, int) and isinstance(output_item, dict):
|
||||
streamed_output_items[output_index] = output_item
|
||||
elif event_type == "response.completed":
|
||||
completed_response = data["response"]
|
||||
|
||||
if not completed_response:
|
||||
raise RuntimeError("Codex API stream ended without response.completed event")
|
||||
|
||||
# ChatGPT Codex can emit the final assistant content only in stream events.
|
||||
# When response.completed arrives, response.output may still be empty.
|
||||
if streamed_output_items:
|
||||
merged_output = []
|
||||
response_output = completed_response.get("output")
|
||||
if isinstance(response_output, list):
|
||||
merged_output = list(response_output)
|
||||
|
||||
max_index = max(max(streamed_output_items), len(merged_output) - 1)
|
||||
if max_index >= 0 and len(merged_output) <= max_index:
|
||||
merged_output.extend([None] * (max_index + 1 - len(merged_output)))
|
||||
|
||||
for output_index, output_item in streamed_output_items.items():
|
||||
existing_item = merged_output[output_index]
|
||||
if not isinstance(existing_item, dict):
|
||||
merged_output[output_index] = output_item
|
||||
|
||||
completed_response = dict(completed_response)
|
||||
completed_response["output"] = [item for item in merged_output if isinstance(item, dict)]
|
||||
|
||||
return completed_response
|
||||
|
||||
@staticmethod
|
||||
def _parse_sse_data_line(line: str) -> dict[str, Any] | None:
|
||||
"""Parse a data line from the SSE stream, skipping terminal markers."""
|
||||
if not line.startswith("data:"):
|
||||
return None
|
||||
|
||||
raw_data = line[5:].strip()
|
||||
if not raw_data or raw_data == "[DONE]":
|
||||
return None
|
||||
|
||||
try:
|
||||
data = json.loads(raw_data)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug(f"Skipping non-JSON Codex SSE frame: {raw_data}")
|
||||
return None
|
||||
|
||||
return data if isinstance(data, dict) else None
|
||||
|
||||
def _parse_tool_call_arguments(self, output_item: dict[str, Any]) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
|
||||
"""Parse function-call arguments, surfacing malformed payloads safely."""
|
||||
raw_arguments = output_item.get("arguments", "{}")
|
||||
if isinstance(raw_arguments, dict):
|
||||
return raw_arguments, None
|
||||
|
||||
normalized_arguments = raw_arguments or "{}"
|
||||
try:
|
||||
parsed_arguments = json.loads(normalized_arguments)
|
||||
except (TypeError, json.JSONDecodeError) as exc:
|
||||
return None, {
|
||||
"type": "invalid_tool_call",
|
||||
"name": output_item.get("name"),
|
||||
"args": str(raw_arguments),
|
||||
"id": output_item.get("call_id"),
|
||||
"error": f"Failed to parse tool arguments: {exc}",
|
||||
}
|
||||
|
||||
if not isinstance(parsed_arguments, dict):
|
||||
return None, {
|
||||
"type": "invalid_tool_call",
|
||||
"name": output_item.get("name"),
|
||||
"args": str(raw_arguments),
|
||||
"id": output_item.get("call_id"),
|
||||
"error": "Tool arguments must decode to a JSON object.",
|
||||
}
|
||||
|
||||
return parsed_arguments, None
|
||||
|
||||
def _parse_response(self, response: dict) -> ChatResult:
|
||||
"""Parse Codex Responses API response into LangChain ChatResult."""
|
||||
content = ""
|
||||
tool_calls = []
|
||||
invalid_tool_calls = []
|
||||
reasoning_content = ""
|
||||
|
||||
for output_item in response.get("output", []):
|
||||
if output_item.get("type") == "reasoning":
|
||||
# Extract reasoning summary text
|
||||
for summary_item in output_item.get("summary", []):
|
||||
if isinstance(summary_item, dict) and summary_item.get("type") == "summary_text":
|
||||
reasoning_content += summary_item.get("text", "")
|
||||
elif isinstance(summary_item, str):
|
||||
reasoning_content += summary_item
|
||||
elif output_item.get("type") == "message":
|
||||
for part in output_item.get("content", []):
|
||||
if part.get("type") == "output_text":
|
||||
content += part.get("text", "")
|
||||
elif output_item.get("type") == "function_call":
|
||||
parsed_arguments, invalid_tool_call = self._parse_tool_call_arguments(output_item)
|
||||
if invalid_tool_call:
|
||||
invalid_tool_calls.append(invalid_tool_call)
|
||||
continue
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"name": output_item["name"],
|
||||
"args": parsed_arguments or {},
|
||||
"id": output_item.get("call_id", ""),
|
||||
"type": "tool_call",
|
||||
}
|
||||
)
|
||||
|
||||
usage = response.get("usage", {})
|
||||
additional_kwargs = {}
|
||||
if reasoning_content:
|
||||
additional_kwargs["reasoning_content"] = reasoning_content
|
||||
|
||||
message = AIMessage(
|
||||
content=content,
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
additional_kwargs=additional_kwargs,
|
||||
response_metadata={
|
||||
"model": response.get("model", self.model),
|
||||
"usage": usage,
|
||||
},
|
||||
)
|
||||
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=message)],
|
||||
llm_output={
|
||||
"token_usage": {
|
||||
"prompt_tokens": usage.get("input_tokens", 0),
|
||||
"completion_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
},
|
||||
"model_name": response.get("model", self.model),
|
||||
},
|
||||
)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Generate a response using Codex Responses API."""
|
||||
tools = kwargs.get("tools", None)
|
||||
response = self._call_codex_api(messages, tools=tools)
|
||||
return self._parse_response(response)
|
||||
|
||||
def bind_tools(self, tools: list, **kwargs: Any) -> Any:
|
||||
"""Bind tools for function calling."""
|
||||
from langchain_core.runnables import RunnableBinding
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
|
||||
formatted_tools = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, BaseTool):
|
||||
try:
|
||||
fn = convert_to_openai_function(tool)
|
||||
formatted_tools.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": fn["name"],
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
formatted_tools.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
}
|
||||
)
|
||||
elif isinstance(tool, dict):
|
||||
if "function" in tool:
|
||||
fn = tool["function"]
|
||||
formatted_tools.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": fn["name"],
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
}
|
||||
)
|
||||
else:
|
||||
formatted_tools.append(tool)
|
||||
|
||||
return RunnableBinding(bound=self, kwargs={"tools": formatted_tools}, **kwargs)
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Patched ChatDeepSeek that preserves reasoning_content in multi-turn conversations.
|
||||
|
||||
This module provides a patched version of ChatDeepSeek that properly handles
|
||||
reasoning_content when sending messages back to the API. The original implementation
|
||||
stores reasoning_content in additional_kwargs but doesn't include it when making
|
||||
subsequent API calls, which causes errors with APIs that require reasoning_content
|
||||
on all assistant messages when thinking mode is enabled.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
|
||||
class PatchedChatDeepSeek(ChatDeepSeek):
|
||||
"""ChatDeepSeek with proper reasoning_content preservation.
|
||||
|
||||
When using thinking/reasoning enabled models, the API expects reasoning_content
|
||||
to be present on ALL assistant messages in multi-turn conversations. This patched
|
||||
version ensures reasoning_content from additional_kwargs is included in the
|
||||
request payload.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> dict[str, str]:
|
||||
return {"api_key": "DEEPSEEK_API_KEY", "openai_api_key": "DEEPSEEK_API_KEY"}
|
||||
|
||||
def _get_request_payload(
|
||||
self,
|
||||
input_: LanguageModelInput,
|
||||
*,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""Get request payload with reasoning_content preserved.
|
||||
|
||||
Overrides the parent method to inject reasoning_content from
|
||||
additional_kwargs into assistant messages in the payload.
|
||||
"""
|
||||
# Get the original messages before conversion
|
||||
original_messages = self._convert_input(input_).to_messages()
|
||||
|
||||
# Call parent to get the base payload
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
|
||||
# Match payload messages with original messages to restore reasoning_content
|
||||
payload_messages = payload.get("messages", [])
|
||||
|
||||
# The payload messages and original messages should be in the same order
|
||||
# Iterate through both and match by position
|
||||
if len(payload_messages) == len(original_messages):
|
||||
for payload_msg, orig_msg in zip(payload_messages, original_messages):
|
||||
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
|
||||
reasoning_content = orig_msg.additional_kwargs.get("reasoning_content")
|
||||
if reasoning_content is not None:
|
||||
payload_msg["reasoning_content"] = reasoning_content
|
||||
else:
|
||||
# Fallback: match by counting assistant messages
|
||||
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
|
||||
assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
|
||||
|
||||
for (idx, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
|
||||
reasoning_content = ai_msg.additional_kwargs.get("reasoning_content")
|
||||
if reasoning_content is not None:
|
||||
payload_messages[idx]["reasoning_content"] = reasoning_content
|
||||
|
||||
return payload
|
||||
@@ -0,0 +1,220 @@
|
||||
"""Patched ChatOpenAI adapter for MiniMax reasoning output.
|
||||
|
||||
MiniMax's OpenAI-compatible chat completions API can return structured
|
||||
``reasoning_details`` when ``extra_body.reasoning_split=true`` is enabled.
|
||||
``langchain_openai.ChatOpenAI`` currently ignores that field, so DeerFlow's
|
||||
frontend never receives reasoning content in the shape it expects.
|
||||
|
||||
This adapter preserves ``reasoning_split`` in the request payload and maps the
|
||||
provider-specific reasoning field into ``additional_kwargs.reasoning_content``,
|
||||
which DeerFlow already understands.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai.chat_models.base import (
|
||||
_convert_delta_to_message_chunk,
|
||||
_create_usage_metadata,
|
||||
)
|
||||
|
||||
_THINK_TAG_RE = re.compile(r"<think>\s*(.*?)\s*</think>", re.DOTALL)
|
||||
|
||||
|
||||
def _extract_reasoning_text(
|
||||
reasoning_details: Any,
|
||||
*,
|
||||
strip_parts: bool = True,
|
||||
) -> str | None:
|
||||
if not isinstance(reasoning_details, list):
|
||||
return None
|
||||
|
||||
parts: list[str] = []
|
||||
for item in reasoning_details:
|
||||
if not isinstance(item, Mapping):
|
||||
continue
|
||||
text = item.get("text")
|
||||
if isinstance(text, str):
|
||||
normalized = text.strip() if strip_parts else text
|
||||
if normalized.strip():
|
||||
parts.append(normalized)
|
||||
|
||||
return "\n\n".join(parts) if parts else None
|
||||
|
||||
|
||||
def _strip_inline_think_tags(content: str) -> tuple[str, str | None]:
|
||||
reasoning_parts: list[str] = []
|
||||
|
||||
def _replace(match: re.Match[str]) -> str:
|
||||
reasoning = match.group(1).strip()
|
||||
if reasoning:
|
||||
reasoning_parts.append(reasoning)
|
||||
return ""
|
||||
|
||||
cleaned = _THINK_TAG_RE.sub(_replace, content).strip()
|
||||
reasoning = "\n\n".join(reasoning_parts) if reasoning_parts else None
|
||||
return cleaned, reasoning
|
||||
|
||||
|
||||
def _merge_reasoning(*values: str | None) -> str | None:
|
||||
merged: list[str] = []
|
||||
for value in values:
|
||||
if not value:
|
||||
continue
|
||||
normalized = value.strip()
|
||||
if normalized and normalized not in merged:
|
||||
merged.append(normalized)
|
||||
return "\n\n".join(merged) if merged else None
|
||||
|
||||
|
||||
def _with_reasoning_content(
|
||||
message: AIMessage | AIMessageChunk,
|
||||
reasoning: str | None,
|
||||
*,
|
||||
preserve_whitespace: bool = False,
|
||||
):
|
||||
if not reasoning:
|
||||
return message
|
||||
|
||||
additional_kwargs = dict(message.additional_kwargs)
|
||||
if preserve_whitespace:
|
||||
existing = additional_kwargs.get("reasoning_content")
|
||||
additional_kwargs["reasoning_content"] = f"{existing}{reasoning}" if isinstance(existing, str) else reasoning
|
||||
else:
|
||||
additional_kwargs["reasoning_content"] = _merge_reasoning(
|
||||
additional_kwargs.get("reasoning_content"),
|
||||
reasoning,
|
||||
)
|
||||
return message.model_copy(update={"additional_kwargs": additional_kwargs})
|
||||
|
||||
|
||||
class PatchedChatMiniMax(ChatOpenAI):
|
||||
"""ChatOpenAI adapter that preserves MiniMax reasoning output."""
|
||||
|
||||
def _get_request_payload(
|
||||
self,
|
||||
input_: LanguageModelInput,
|
||||
*,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
extra_body = payload.get("extra_body")
|
||||
if isinstance(extra_body, dict):
|
||||
payload["extra_body"] = {
|
||||
**extra_body,
|
||||
"reasoning_split": True,
|
||||
}
|
||||
else:
|
||||
payload["extra_body"] = {"reasoning_split": True}
|
||||
return payload
|
||||
|
||||
def _convert_chunk_to_generation_chunk(
|
||||
self,
|
||||
chunk: dict,
|
||||
default_chunk_class: type,
|
||||
base_generation_info: dict | None,
|
||||
) -> ChatGenerationChunk | None:
|
||||
if chunk.get("type") == "content.delta":
|
||||
return None
|
||||
|
||||
token_usage = chunk.get("usage")
|
||||
choices = chunk.get("choices", []) or chunk.get("chunk", {}).get("choices", [])
|
||||
usage_metadata = _create_usage_metadata(token_usage, chunk.get("service_tier")) if token_usage else None
|
||||
|
||||
if len(choices) == 0:
|
||||
generation_chunk = ChatGenerationChunk(
|
||||
message=default_chunk_class(content="", usage_metadata=usage_metadata),
|
||||
generation_info=base_generation_info,
|
||||
)
|
||||
if self.output_version == "v1":
|
||||
generation_chunk.message.content = []
|
||||
generation_chunk.message.response_metadata["output_version"] = "v1"
|
||||
return generation_chunk
|
||||
|
||||
choice = choices[0]
|
||||
delta = choice.get("delta")
|
||||
if delta is None:
|
||||
return None
|
||||
|
||||
message_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
generation_info = {**base_generation_info} if base_generation_info else {}
|
||||
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
if model_name := chunk.get("model"):
|
||||
generation_info["model_name"] = model_name
|
||||
if system_fingerprint := chunk.get("system_fingerprint"):
|
||||
generation_info["system_fingerprint"] = system_fingerprint
|
||||
if service_tier := chunk.get("service_tier"):
|
||||
generation_info["service_tier"] = service_tier
|
||||
|
||||
logprobs = choice.get("logprobs")
|
||||
if logprobs:
|
||||
generation_info["logprobs"] = logprobs
|
||||
|
||||
reasoning = _extract_reasoning_text(
|
||||
delta.get("reasoning_details"),
|
||||
strip_parts=False,
|
||||
)
|
||||
if isinstance(message_chunk, AIMessageChunk):
|
||||
if usage_metadata:
|
||||
message_chunk.usage_metadata = usage_metadata
|
||||
if reasoning:
|
||||
message_chunk = _with_reasoning_content(
|
||||
message_chunk,
|
||||
reasoning,
|
||||
preserve_whitespace=True,
|
||||
)
|
||||
|
||||
message_chunk.response_metadata["model_provider"] = "openai"
|
||||
return ChatGenerationChunk(
|
||||
message=message_chunk,
|
||||
generation_info=generation_info or None,
|
||||
)
|
||||
|
||||
def _create_chat_result(
|
||||
self,
|
||||
response: dict | Any,
|
||||
generation_info: dict | None = None,
|
||||
) -> ChatResult:
|
||||
result = super()._create_chat_result(response, generation_info)
|
||||
response_dict = response if isinstance(response, dict) else response.model_dump()
|
||||
choices = response_dict.get("choices", [])
|
||||
|
||||
generations: list[ChatGeneration] = []
|
||||
for index, generation in enumerate(result.generations):
|
||||
choice = choices[index] if index < len(choices) else {}
|
||||
message = generation.message
|
||||
if isinstance(message, AIMessage):
|
||||
content = message.content if isinstance(message.content, str) else None
|
||||
cleaned_content = content
|
||||
inline_reasoning = None
|
||||
if isinstance(content, str):
|
||||
cleaned_content, inline_reasoning = _strip_inline_think_tags(content)
|
||||
|
||||
choice_message = choice.get("message", {}) if isinstance(choice, Mapping) else {}
|
||||
split_reasoning = _extract_reasoning_text(choice_message.get("reasoning_details"))
|
||||
merged_reasoning = _merge_reasoning(split_reasoning, inline_reasoning)
|
||||
|
||||
updated_message = message
|
||||
if cleaned_content is not None and cleaned_content != message.content:
|
||||
updated_message = updated_message.model_copy(update={"content": cleaned_content})
|
||||
if merged_reasoning:
|
||||
updated_message = _with_reasoning_content(updated_message, merged_reasoning)
|
||||
|
||||
generation = ChatGeneration(
|
||||
message=updated_message,
|
||||
generation_info=generation.generation_info,
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
|
||||
return ChatResult(generations=generations, llm_output=result.llm_output)
|
||||
@@ -0,0 +1,132 @@
|
||||
"""Patched ChatOpenAI that preserves thought_signature for Gemini thinking models.
|
||||
|
||||
When using Gemini with thinking enabled via an OpenAI-compatible gateway (e.g.
|
||||
Vertex AI, Google AI Studio, or any proxy), the API requires that the
|
||||
``thought_signature`` field on tool-call objects is echoed back verbatim in
|
||||
every subsequent request.
|
||||
|
||||
The OpenAI-compatible gateway stores the raw tool-call dicts (including
|
||||
``thought_signature``) in ``additional_kwargs["tool_calls"]``, but standard
|
||||
``langchain_openai.ChatOpenAI`` only serialises the standard fields (``id``,
|
||||
``type``, ``function``) into the outgoing payload, silently dropping the
|
||||
signature. That causes an HTTP 400 ``INVALID_ARGUMENT`` error:
|
||||
|
||||
Unable to submit request because function call `<tool>` in the N. content
|
||||
block is missing a `thought_signature`.
|
||||
|
||||
This module fixes the problem by overriding ``_get_request_payload`` to
|
||||
re-inject tool-call signatures back into the outgoing payload for any assistant
|
||||
message that originally carried them.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class PatchedChatOpenAI(ChatOpenAI):
|
||||
"""ChatOpenAI with ``thought_signature`` preservation for Gemini thinking via OpenAI gateway.
|
||||
|
||||
When using Gemini with thinking enabled via an OpenAI-compatible gateway,
|
||||
the API expects ``thought_signature`` to be present on tool-call objects in
|
||||
multi-turn conversations. This patched version restores those signatures
|
||||
from ``AIMessage.additional_kwargs["tool_calls"]`` into the serialised
|
||||
request payload before it is sent to the API.
|
||||
|
||||
Usage in ``config.yaml``::
|
||||
|
||||
- name: gemini-2.5-pro-thinking
|
||||
display_name: Gemini 2.5 Pro (Thinking)
|
||||
use: deerflow.models.patched_openai:PatchedChatOpenAI
|
||||
model: google/gemini-2.5-pro-preview
|
||||
api_key: $GEMINI_API_KEY
|
||||
base_url: https://<your-openai-compat-gateway>/v1
|
||||
max_tokens: 16384
|
||||
supports_thinking: true
|
||||
supports_vision: true
|
||||
when_thinking_enabled:
|
||||
extra_body:
|
||||
thinking:
|
||||
type: enabled
|
||||
"""
|
||||
|
||||
def _get_request_payload(
|
||||
self,
|
||||
input_: LanguageModelInput,
|
||||
*,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""Get request payload with ``thought_signature`` preserved on tool-call objects.
|
||||
|
||||
Overrides the parent method to re-inject ``thought_signature`` fields
|
||||
on tool-call objects that were stored in
|
||||
``additional_kwargs["tool_calls"]`` by LangChain but dropped during
|
||||
serialisation.
|
||||
"""
|
||||
# Capture the original LangChain messages *before* conversion so we can
|
||||
# access fields that the serialiser might drop.
|
||||
original_messages = self._convert_input(input_).to_messages()
|
||||
|
||||
# Obtain the base payload from the parent implementation.
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
|
||||
payload_messages = payload.get("messages", [])
|
||||
|
||||
if len(payload_messages) == len(original_messages):
|
||||
for payload_msg, orig_msg in zip(payload_messages, original_messages):
|
||||
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
|
||||
_restore_tool_call_signatures(payload_msg, orig_msg)
|
||||
else:
|
||||
# Fallback: match assistant-role entries positionally against AIMessages.
|
||||
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
|
||||
assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
|
||||
for (_, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
|
||||
_restore_tool_call_signatures(payload_msg, ai_msg)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _restore_tool_call_signatures(payload_msg: dict, orig_msg: AIMessage) -> None:
|
||||
"""Re-inject ``thought_signature`` onto tool-call objects in *payload_msg*.
|
||||
|
||||
When the Gemini OpenAI-compatible gateway returns a response with function
|
||||
calls, each tool-call object may carry a ``thought_signature``. LangChain
|
||||
stores the raw tool-call dicts in ``additional_kwargs["tool_calls"]`` but
|
||||
only serialises the standard fields (``id``, ``type``, ``function``) into
|
||||
the outgoing payload, silently dropping the signature.
|
||||
|
||||
This function matches raw tool-call entries (by ``id``, falling back to
|
||||
positional order) and copies the signature back onto the serialised
|
||||
payload entries.
|
||||
"""
|
||||
raw_tool_calls: list[dict] = orig_msg.additional_kwargs.get("tool_calls") or []
|
||||
payload_tool_calls: list[dict] = payload_msg.get("tool_calls") or []
|
||||
|
||||
if not raw_tool_calls or not payload_tool_calls:
|
||||
return
|
||||
|
||||
# Build an id → raw_tc lookup for efficient matching.
|
||||
raw_by_id: dict[str, dict] = {}
|
||||
for raw_tc in raw_tool_calls:
|
||||
tc_id = raw_tc.get("id")
|
||||
if tc_id:
|
||||
raw_by_id[tc_id] = raw_tc
|
||||
|
||||
for idx, payload_tc in enumerate(payload_tool_calls):
|
||||
# Try matching by id first, then fall back to positional.
|
||||
raw_tc = raw_by_id.get(payload_tc.get("id", ""))
|
||||
if raw_tc is None and idx < len(raw_tool_calls):
|
||||
raw_tc = raw_tool_calls[idx]
|
||||
|
||||
if raw_tc is None:
|
||||
continue
|
||||
|
||||
# The gateway may use either snake_case or camelCase.
|
||||
sig = raw_tc.get("thought_signature") or raw_tc.get("thoughtSignature")
|
||||
if sig:
|
||||
payload_tc["thought_signature"] = sig
|
||||
@@ -0,0 +1,258 @@
|
||||
"""Custom vLLM provider built on top of LangChain ChatOpenAI.
|
||||
|
||||
vLLM 0.19.0 exposes reasoning models through an OpenAI-compatible API, but
|
||||
LangChain's default OpenAI adapter drops the non-standard ``reasoning`` field
|
||||
from assistant messages and streaming deltas. That breaks interleaved
|
||||
thinking/tool-call flows because vLLM expects the assistant's prior reasoning to
|
||||
be echoed back on subsequent turns.
|
||||
|
||||
This provider preserves ``reasoning`` on:
|
||||
- non-streaming responses
|
||||
- streaming deltas
|
||||
- multi-turn request payloads
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
import openai
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessageChunk,
|
||||
ChatMessageChunk,
|
||||
FunctionMessageChunk,
|
||||
HumanMessageChunk,
|
||||
SystemMessageChunk,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.messages.tool import tool_call_chunk
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai.chat_models.base import _create_usage_metadata
|
||||
|
||||
|
||||
def _normalize_vllm_chat_template_kwargs(payload: dict[str, Any]) -> None:
|
||||
"""Map DeerFlow's legacy ``thinking`` toggle to vLLM/Qwen's ``enable_thinking``.
|
||||
|
||||
DeerFlow originally documented ``extra_body.chat_template_kwargs.thinking``
|
||||
for vLLM, but vLLM 0.19.0's Qwen reasoning parser reads
|
||||
``chat_template_kwargs.enable_thinking``. Normalize the payload just before
|
||||
it is sent so existing configs keep working and flash mode can truly
|
||||
disable reasoning.
|
||||
"""
|
||||
extra_body = payload.get("extra_body")
|
||||
if not isinstance(extra_body, dict):
|
||||
return
|
||||
|
||||
chat_template_kwargs = extra_body.get("chat_template_kwargs")
|
||||
if not isinstance(chat_template_kwargs, dict):
|
||||
return
|
||||
|
||||
if "thinking" not in chat_template_kwargs:
|
||||
return
|
||||
|
||||
normalized_chat_template_kwargs = dict(chat_template_kwargs)
|
||||
normalized_chat_template_kwargs.setdefault("enable_thinking", normalized_chat_template_kwargs["thinking"])
|
||||
normalized_chat_template_kwargs.pop("thinking", None)
|
||||
extra_body["chat_template_kwargs"] = normalized_chat_template_kwargs
|
||||
|
||||
|
||||
def _reasoning_to_text(reasoning: Any) -> str:
|
||||
"""Best-effort extraction of readable reasoning text from vLLM payloads."""
|
||||
if isinstance(reasoning, str):
|
||||
return reasoning
|
||||
|
||||
if isinstance(reasoning, list):
|
||||
parts = [_reasoning_to_text(item) for item in reasoning]
|
||||
return "".join(part for part in parts if part)
|
||||
|
||||
if isinstance(reasoning, dict):
|
||||
for key in ("text", "content", "reasoning"):
|
||||
value = reasoning.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if value is not None:
|
||||
text = _reasoning_to_text(value)
|
||||
if text:
|
||||
return text
|
||||
try:
|
||||
return json.dumps(reasoning, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(reasoning)
|
||||
|
||||
try:
|
||||
return json.dumps(reasoning, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(reasoning)
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk_with_reasoning(_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]) -> BaseMessageChunk:
|
||||
"""Convert a streaming delta to a LangChain message chunk while preserving reasoning."""
|
||||
id_ = _dict.get("id")
|
||||
role = cast(str, _dict.get("role"))
|
||||
content = cast(str, _dict.get("content") or "")
|
||||
additional_kwargs: dict[str, Any] = {}
|
||||
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
function_call["name"] = ""
|
||||
additional_kwargs["function_call"] = function_call
|
||||
|
||||
reasoning = _dict.get("reasoning")
|
||||
if reasoning is not None:
|
||||
additional_kwargs["reasoning"] = reasoning
|
||||
reasoning_text = _reasoning_to_text(reasoning)
|
||||
if reasoning_text:
|
||||
additional_kwargs["reasoning_content"] = reasoning_text
|
||||
|
||||
tool_call_chunks = []
|
||||
if raw_tool_calls := _dict.get("tool_calls"):
|
||||
try:
|
||||
tool_call_chunks = [
|
||||
tool_call_chunk(
|
||||
name=rtc["function"].get("name"),
|
||||
args=rtc["function"].get("arguments"),
|
||||
id=rtc.get("id"),
|
||||
index=rtc["index"],
|
||||
)
|
||||
for rtc in raw_tool_calls
|
||||
]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content, id=id_)
|
||||
if role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
id=id_,
|
||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
||||
)
|
||||
if role in ("system", "developer") or default_class == SystemMessageChunk:
|
||||
role_kwargs = {"__openai_role__": "developer"} if role == "developer" else {}
|
||||
return SystemMessageChunk(content=content, id=id_, additional_kwargs=role_kwargs)
|
||||
if role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
|
||||
if role == "tool" or default_class == ToolMessageChunk:
|
||||
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"], id=id_)
|
||||
if role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role, id=id_) # type: ignore[arg-type]
|
||||
return default_class(content=content, id=id_) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def _restore_reasoning_field(payload_msg: dict[str, Any], orig_msg: AIMessage) -> None:
|
||||
"""Re-inject vLLM reasoning onto outgoing assistant messages."""
|
||||
reasoning = orig_msg.additional_kwargs.get("reasoning")
|
||||
if reasoning is None:
|
||||
reasoning = orig_msg.additional_kwargs.get("reasoning_content")
|
||||
if reasoning is not None:
|
||||
payload_msg["reasoning"] = reasoning
|
||||
|
||||
|
||||
class VllmChatModel(ChatOpenAI):
|
||||
"""ChatOpenAI variant that preserves vLLM reasoning fields across turns."""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "vllm-openai-compatible"
|
||||
|
||||
def _get_request_payload(
|
||||
self,
|
||||
input_: LanguageModelInput,
|
||||
*,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Restore assistant reasoning in request payloads for interleaved thinking."""
|
||||
original_messages = self._convert_input(input_).to_messages()
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
_normalize_vllm_chat_template_kwargs(payload)
|
||||
payload_messages = payload.get("messages", [])
|
||||
|
||||
if len(payload_messages) == len(original_messages):
|
||||
for payload_msg, orig_msg in zip(payload_messages, original_messages):
|
||||
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
|
||||
_restore_reasoning_field(payload_msg, orig_msg)
|
||||
else:
|
||||
ai_messages = [message for message in original_messages if isinstance(message, AIMessage)]
|
||||
assistant_payloads = [message for message in payload_messages if message.get("role") == "assistant"]
|
||||
for payload_msg, ai_msg in zip(assistant_payloads, ai_messages):
|
||||
_restore_reasoning_field(payload_msg, ai_msg)
|
||||
|
||||
return payload
|
||||
|
||||
def _create_chat_result(self, response: dict | openai.BaseModel, generation_info: dict | None = None) -> ChatResult:
|
||||
"""Preserve vLLM reasoning on non-streaming responses."""
|
||||
result = super()._create_chat_result(response, generation_info=generation_info)
|
||||
response_dict = response if isinstance(response, dict) else response.model_dump()
|
||||
|
||||
for generation, choice in zip(result.generations, response_dict.get("choices", [])):
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
continue
|
||||
message = generation.message
|
||||
if not isinstance(message, AIMessage):
|
||||
continue
|
||||
reasoning = choice.get("message", {}).get("reasoning")
|
||||
if reasoning is None:
|
||||
continue
|
||||
message.additional_kwargs["reasoning"] = reasoning
|
||||
reasoning_text = _reasoning_to_text(reasoning)
|
||||
if reasoning_text:
|
||||
message.additional_kwargs["reasoning_content"] = reasoning_text
|
||||
|
||||
return result
|
||||
|
||||
def _convert_chunk_to_generation_chunk(
|
||||
self,
|
||||
chunk: dict,
|
||||
default_chunk_class: type,
|
||||
base_generation_info: dict | None,
|
||||
) -> ChatGenerationChunk | None:
|
||||
"""Preserve vLLM reasoning on streaming deltas."""
|
||||
if chunk.get("type") == "content.delta":
|
||||
return None
|
||||
|
||||
token_usage = chunk.get("usage")
|
||||
choices = chunk.get("choices", []) or chunk.get("chunk", {}).get("choices", [])
|
||||
usage_metadata = _create_usage_metadata(token_usage, chunk.get("service_tier")) if token_usage else None
|
||||
|
||||
if len(choices) == 0:
|
||||
generation_chunk = ChatGenerationChunk(message=default_chunk_class(content="", usage_metadata=usage_metadata), generation_info=base_generation_info)
|
||||
if self.output_version == "v1":
|
||||
generation_chunk.message.content = []
|
||||
generation_chunk.message.response_metadata["output_version"] = "v1"
|
||||
return generation_chunk
|
||||
|
||||
choice = choices[0]
|
||||
if choice["delta"] is None:
|
||||
return None
|
||||
|
||||
message_chunk = _convert_delta_to_message_chunk_with_reasoning(choice["delta"], default_chunk_class)
|
||||
generation_info = {**base_generation_info} if base_generation_info else {}
|
||||
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
if model_name := chunk.get("model"):
|
||||
generation_info["model_name"] = model_name
|
||||
if system_fingerprint := chunk.get("system_fingerprint"):
|
||||
generation_info["system_fingerprint"] = system_fingerprint
|
||||
if service_tier := chunk.get("service_tier"):
|
||||
generation_info["service_tier"] = service_tier
|
||||
|
||||
if logprobs := choice.get("logprobs"):
|
||||
generation_info["logprobs"] = logprobs
|
||||
|
||||
if usage_metadata and isinstance(message_chunk, AIMessageChunk):
|
||||
message_chunk.usage_metadata = usage_metadata
|
||||
|
||||
message_chunk.response_metadata["model_provider"] = "openai"
|
||||
return ChatGenerationChunk(message=message_chunk, generation_info=generation_info or None)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .resolvers import resolve_class, resolve_variable
|
||||
|
||||
__all__ = ["resolve_class", "resolve_variable"]
|
||||
@@ -0,0 +1,95 @@
|
||||
from importlib import import_module
|
||||
|
||||
MODULE_TO_PACKAGE_HINTS = {
|
||||
"langchain_google_genai": "langchain-google-genai",
|
||||
"langchain_anthropic": "langchain-anthropic",
|
||||
"langchain_openai": "langchain-openai",
|
||||
"langchain_deepseek": "langchain-deepseek",
|
||||
}
|
||||
|
||||
|
||||
def _build_missing_dependency_hint(module_path: str, err: ImportError) -> str:
|
||||
"""Build an actionable hint when module import fails."""
|
||||
module_root = module_path.split(".", 1)[0]
|
||||
missing_module = getattr(err, "name", None) or module_root
|
||||
|
||||
# Prefer provider package hints for known integrations, even when the import
|
||||
# error is triggered by a transitive dependency (e.g. `google`).
|
||||
package_name = MODULE_TO_PACKAGE_HINTS.get(module_root)
|
||||
if package_name is None:
|
||||
package_name = MODULE_TO_PACKAGE_HINTS.get(missing_module, missing_module.replace("_", "-"))
|
||||
|
||||
return f"Missing dependency '{missing_module}'. Install it with `uv add {package_name}` (or `pip install {package_name}`), then restart DeerFlow."
|
||||
|
||||
|
||||
def resolve_variable[T](
|
||||
variable_path: str,
|
||||
expected_type: type[T] | tuple[type, ...] | None = None,
|
||||
) -> T:
|
||||
"""Resolve a variable from a path.
|
||||
|
||||
Args:
|
||||
variable_path: The path to the variable (e.g. "parent_package_name.sub_package_name.module_name:variable_name").
|
||||
expected_type: Optional type or tuple of types to validate the resolved variable against.
|
||||
If provided, uses isinstance() to check if the variable is an instance of the expected type(s).
|
||||
|
||||
Returns:
|
||||
The resolved variable.
|
||||
|
||||
Raises:
|
||||
ImportError: If the module path is invalid or the attribute doesn't exist.
|
||||
ValueError: If the resolved variable doesn't pass the validation checks.
|
||||
"""
|
||||
try:
|
||||
module_path, variable_name = variable_path.rsplit(":", 1)
|
||||
except ValueError as err:
|
||||
raise ImportError(f"{variable_path} doesn't look like a variable path. Example: parent_package_name.sub_package_name.module_name:variable_name") from err
|
||||
|
||||
try:
|
||||
module = import_module(module_path)
|
||||
except ImportError as err:
|
||||
module_root = module_path.split(".", 1)[0]
|
||||
err_name = getattr(err, "name", None)
|
||||
if isinstance(err, ModuleNotFoundError) or err_name == module_root:
|
||||
hint = _build_missing_dependency_hint(module_path, err)
|
||||
raise ImportError(f"Could not import module {module_path}. {hint}") from err
|
||||
# Preserve the original ImportError message for non-missing-module failures.
|
||||
raise ImportError(f"Error importing module {module_path}: {err}") from err
|
||||
|
||||
try:
|
||||
variable = getattr(module, variable_name)
|
||||
except AttributeError as err:
|
||||
raise ImportError(f"Module {module_path} does not define a {variable_name} attribute/class") from err
|
||||
|
||||
# Type validation
|
||||
if expected_type is not None:
|
||||
if not isinstance(variable, expected_type):
|
||||
type_name = expected_type.__name__ if isinstance(expected_type, type) else " or ".join(t.__name__ for t in expected_type)
|
||||
raise ValueError(f"{variable_path} is not an instance of {type_name}, got {type(variable).__name__}")
|
||||
|
||||
return variable
|
||||
|
||||
|
||||
def resolve_class[T](class_path: str, base_class: type[T] | None = None) -> type[T]:
|
||||
"""Resolve a class from a module path and class name.
|
||||
|
||||
Args:
|
||||
class_path: The path to the class (e.g. "langchain_openai:ChatOpenAI").
|
||||
base_class: The base class to check if the resolved class is a subclass of.
|
||||
|
||||
Returns:
|
||||
The resolved class.
|
||||
|
||||
Raises:
|
||||
ImportError: If the module path is invalid or the attribute doesn't exist.
|
||||
ValueError: If the resolved object is not a class or not a subclass of base_class.
|
||||
"""
|
||||
model_class = resolve_variable(class_path, expected_type=type)
|
||||
|
||||
if not isinstance(model_class, type):
|
||||
raise ValueError(f"{class_path} is not a valid class")
|
||||
|
||||
if base_class is not None and not issubclass(model_class, base_class):
|
||||
raise ValueError(f"{class_path} is not a subclass of {base_class.__name__}")
|
||||
|
||||
return model_class
|
||||
@@ -0,0 +1,39 @@
|
||||
"""LangGraph-compatible runtime — runs, streaming, and lifecycle management.
|
||||
|
||||
Re-exports the public API of :mod:`~deerflow.runtime.runs` and
|
||||
:mod:`~deerflow.runtime.stream_bridge` so that consumers can import
|
||||
directly from ``deerflow.runtime``.
|
||||
"""
|
||||
|
||||
from .runs import ConflictError, DisconnectMode, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
|
||||
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
|
||||
from .store import get_store, make_store, reset_store, store_context
|
||||
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
|
||||
|
||||
__all__ = [
|
||||
# runs
|
||||
"ConflictError",
|
||||
"DisconnectMode",
|
||||
"RunManager",
|
||||
"RunRecord",
|
||||
"RunStatus",
|
||||
"UnsupportedStrategyError",
|
||||
"run_agent",
|
||||
# serialization
|
||||
"serialize",
|
||||
"serialize_channel_values",
|
||||
"serialize_lc_object",
|
||||
"serialize_messages_tuple",
|
||||
# store
|
||||
"get_store",
|
||||
"make_store",
|
||||
"reset_store",
|
||||
"store_context",
|
||||
# stream_bridge
|
||||
"END_SENTINEL",
|
||||
"HEARTBEAT_SENTINEL",
|
||||
"MemoryStreamBridge",
|
||||
"StreamBridge",
|
||||
"StreamEvent",
|
||||
"make_stream_bridge",
|
||||
]
|
||||
@@ -0,0 +1,15 @@
|
||||
"""Run lifecycle management for LangGraph Platform API compatibility."""
|
||||
|
||||
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
|
||||
from .schemas import DisconnectMode, RunStatus
|
||||
from .worker import run_agent
|
||||
|
||||
__all__ = [
|
||||
"ConflictError",
|
||||
"DisconnectMode",
|
||||
"RunManager",
|
||||
"RunRecord",
|
||||
"RunStatus",
|
||||
"UnsupportedStrategyError",
|
||||
"run_agent",
|
||||
]
|
||||
@@ -0,0 +1,210 @@
|
||||
"""In-memory run registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from .schemas import DisconnectMode, RunStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(UTC).isoformat()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunRecord:
|
||||
"""Mutable record for a single run."""
|
||||
|
||||
run_id: str
|
||||
thread_id: str
|
||||
assistant_id: str | None
|
||||
status: RunStatus
|
||||
on_disconnect: DisconnectMode
|
||||
multitask_strategy: str = "reject"
|
||||
metadata: dict = field(default_factory=dict)
|
||||
kwargs: dict = field(default_factory=dict)
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
task: asyncio.Task | None = field(default=None, repr=False)
|
||||
abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
|
||||
abort_action: str = "interrupt"
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class RunManager:
|
||||
"""In-memory run registry. All mutations are protected by an asyncio lock."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._runs: dict[str, RunRecord] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def create(
|
||||
self,
|
||||
thread_id: str,
|
||||
assistant_id: str | None = None,
|
||||
*,
|
||||
on_disconnect: DisconnectMode = DisconnectMode.cancel,
|
||||
metadata: dict | None = None,
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
) -> RunRecord:
|
||||
"""Create a new pending run and register it."""
|
||||
run_id = str(uuid.uuid4())
|
||||
now = _now_iso()
|
||||
record = RunRecord(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
status=RunStatus.pending,
|
||||
on_disconnect=on_disconnect,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata=metadata or {},
|
||||
kwargs=kwargs or {},
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
async with self._lock:
|
||||
self._runs[run_id] = record
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
def get(self, run_id: str) -> RunRecord | None:
|
||||
"""Return a run record by ID, or ``None``."""
|
||||
return self._runs.get(run_id)
|
||||
|
||||
async def list_by_thread(self, thread_id: str) -> list[RunRecord]:
|
||||
"""Return all runs for a given thread, newest first."""
|
||||
async with self._lock:
|
||||
# Dict insertion order matches creation order, so reversing it gives
|
||||
# us deterministic newest-first results even when timestamps tie.
|
||||
return [r for r in reversed(self._runs.values()) if r.thread_id == thread_id]
|
||||
|
||||
async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
|
||||
"""Transition a run to a new status."""
|
||||
async with self._lock:
|
||||
record = self._runs.get(run_id)
|
||||
if record is None:
|
||||
logger.warning("set_status called for unknown run %s", run_id)
|
||||
return
|
||||
record.status = status
|
||||
record.updated_at = _now_iso()
|
||||
if error is not None:
|
||||
record.error = error
|
||||
logger.info("Run %s -> %s", run_id, status.value)
|
||||
|
||||
async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
|
||||
"""Request cancellation of a run.
|
||||
|
||||
Args:
|
||||
run_id: The run ID to cancel.
|
||||
action: "interrupt" keeps checkpoint, "rollback" reverts to pre-run state.
|
||||
|
||||
Sets the abort event with the action reason and cancels the asyncio task.
|
||||
Returns ``True`` if the run was in-flight and cancellation was initiated.
|
||||
"""
|
||||
async with self._lock:
|
||||
record = self._runs.get(run_id)
|
||||
if record is None:
|
||||
return False
|
||||
if record.status not in (RunStatus.pending, RunStatus.running):
|
||||
return False
|
||||
record.abort_action = action
|
||||
record.abort_event.set()
|
||||
if record.task is not None and not record.task.done():
|
||||
record.task.cancel()
|
||||
record.status = RunStatus.interrupted
|
||||
record.updated_at = _now_iso()
|
||||
logger.info("Run %s cancelled (action=%s)", run_id, action)
|
||||
return True
|
||||
|
||||
async def create_or_reject(
|
||||
self,
|
||||
thread_id: str,
|
||||
assistant_id: str | None = None,
|
||||
*,
|
||||
on_disconnect: DisconnectMode = DisconnectMode.cancel,
|
||||
metadata: dict | None = None,
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
) -> RunRecord:
|
||||
"""Atomically check for inflight runs and create a new one.
|
||||
|
||||
For ``reject`` strategy, raises ``ConflictError`` if thread
|
||||
already has a pending/running run. For ``interrupt``/``rollback``,
|
||||
cancels inflight runs before creating.
|
||||
|
||||
This method holds the lock across both the check and the insert,
|
||||
eliminating the TOCTOU race in separate ``has_inflight`` + ``create``.
|
||||
"""
|
||||
run_id = str(uuid.uuid4())
|
||||
now = _now_iso()
|
||||
|
||||
_supported_strategies = ("reject", "interrupt", "rollback")
|
||||
|
||||
async with self._lock:
|
||||
if multitask_strategy not in _supported_strategies:
|
||||
raise UnsupportedStrategyError(f"Multitask strategy '{multitask_strategy}' is not yet supported. Supported strategies: {', '.join(_supported_strategies)}")
|
||||
|
||||
inflight = [r for r in self._runs.values() if r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running)]
|
||||
|
||||
if multitask_strategy == "reject" and inflight:
|
||||
raise ConflictError(f"Thread {thread_id} already has an active run")
|
||||
|
||||
if multitask_strategy in ("interrupt", "rollback") and inflight:
|
||||
for r in inflight:
|
||||
r.abort_action = multitask_strategy
|
||||
r.abort_event.set()
|
||||
if r.task is not None and not r.task.done():
|
||||
r.task.cancel()
|
||||
r.status = RunStatus.interrupted
|
||||
r.updated_at = now
|
||||
logger.info(
|
||||
"Cancelled %d inflight run(s) on thread %s (strategy=%s)",
|
||||
len(inflight),
|
||||
thread_id,
|
||||
multitask_strategy,
|
||||
)
|
||||
|
||||
record = RunRecord(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
status=RunStatus.pending,
|
||||
on_disconnect=on_disconnect,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata=metadata or {},
|
||||
kwargs=kwargs or {},
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
self._runs[run_id] = record
|
||||
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
async def has_inflight(self, thread_id: str) -> bool:
|
||||
"""Return ``True`` if *thread_id* has a pending or running run."""
|
||||
async with self._lock:
|
||||
return any(r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running) for r in self._runs.values())
|
||||
|
||||
async def cleanup(self, run_id: str, *, delay: float = 300) -> None:
|
||||
"""Remove a run record after an optional delay."""
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
async with self._lock:
|
||||
self._runs.pop(run_id, None)
|
||||
logger.debug("Run record %s cleaned up", run_id)
|
||||
|
||||
|
||||
class ConflictError(Exception):
|
||||
"""Raised when multitask_strategy=reject and thread has inflight runs."""
|
||||
|
||||
|
||||
class UnsupportedStrategyError(Exception):
|
||||
"""Raised when a multitask_strategy value is not yet implemented."""
|
||||
@@ -0,0 +1,21 @@
|
||||
"""Run status and disconnect mode enums."""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class RunStatus(StrEnum):
|
||||
"""Lifecycle status of a single run."""
|
||||
|
||||
pending = "pending"
|
||||
running = "running"
|
||||
success = "success"
|
||||
error = "error"
|
||||
timeout = "timeout"
|
||||
interrupted = "interrupted"
|
||||
|
||||
|
||||
class DisconnectMode(StrEnum):
|
||||
"""Behaviour when the SSE consumer disconnects."""
|
||||
|
||||
cancel = "cancel"
|
||||
continue_ = "continue"
|
||||
@@ -0,0 +1,381 @@
|
||||
"""Background agent execution.
|
||||
|
||||
Runs an agent graph inside an ``asyncio.Task``, publishing events to
|
||||
a :class:`StreamBridge` as they are produced.
|
||||
|
||||
Uses ``graph.astream(stream_mode=[...])`` which gives correct full-state
|
||||
snapshots for ``values`` mode, proper ``{node: writes}`` for ``updates``,
|
||||
and ``(chunk, metadata)`` tuples for ``messages`` mode.
|
||||
|
||||
Note: ``events`` mode is not supported through the gateway — it requires
|
||||
``graph.astream_events()`` which cannot simultaneously produce ``values``
|
||||
snapshots. The JS open-source LangGraph API server works around this via
|
||||
internal checkpoint callbacks that are not exposed in the Python public API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from deerflow.runtime.serialization import serialize
|
||||
from deerflow.runtime.stream_bridge import StreamBridge
|
||||
|
||||
from .manager import RunManager, RunRecord
|
||||
from .schemas import RunStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Valid stream_mode values for LangGraph's graph.astream()
|
||||
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
|
||||
|
||||
|
||||
async def run_agent(
|
||||
bridge: StreamBridge,
|
||||
run_manager: RunManager,
|
||||
record: RunRecord,
|
||||
*,
|
||||
checkpointer: Any,
|
||||
store: Any | None = None,
|
||||
agent_factory: Any,
|
||||
graph_input: dict,
|
||||
config: dict,
|
||||
stream_modes: list[str] | None = None,
|
||||
stream_subgraphs: bool = False,
|
||||
interrupt_before: list[str] | Literal["*"] | None = None,
|
||||
interrupt_after: list[str] | Literal["*"] | None = None,
|
||||
) -> None:
|
||||
"""Execute an agent in the background, publishing events to *bridge*."""
|
||||
|
||||
run_id = record.run_id
|
||||
thread_id = record.thread_id
|
||||
requested_modes: set[str] = set(stream_modes or ["values"])
|
||||
pre_run_checkpoint_id: str | None = None
|
||||
pre_run_snapshot: dict[str, Any] | None = None
|
||||
snapshot_capture_failed = False
|
||||
|
||||
# Track whether "events" was requested but skipped
|
||||
if "events" in requested_modes:
|
||||
logger.info(
|
||||
"Run %s: 'events' stream_mode not supported in gateway (requires astream_events + checkpoint callbacks). Skipping.",
|
||||
run_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# 1. Mark running
|
||||
await run_manager.set_status(run_id, RunStatus.running)
|
||||
|
||||
# Snapshot the latest pre-run checkpoint so rollback can restore it.
|
||||
if checkpointer is not None:
|
||||
try:
|
||||
config_for_check = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
ckpt_tuple = await checkpointer.aget_tuple(config_for_check)
|
||||
if ckpt_tuple is not None:
|
||||
ckpt_config = getattr(ckpt_tuple, "config", {}).get("configurable", {})
|
||||
pre_run_checkpoint_id = ckpt_config.get("checkpoint_id")
|
||||
pre_run_snapshot = {
|
||||
"checkpoint_ns": ckpt_config.get("checkpoint_ns", ""),
|
||||
"checkpoint": copy.deepcopy(getattr(ckpt_tuple, "checkpoint", {})),
|
||||
"metadata": copy.deepcopy(getattr(ckpt_tuple, "metadata", {})),
|
||||
"pending_writes": copy.deepcopy(getattr(ckpt_tuple, "pending_writes", []) or []),
|
||||
}
|
||||
except Exception:
|
||||
snapshot_capture_failed = True
|
||||
logger.warning("Could not capture pre-run checkpoint snapshot for run %s", run_id, exc_info=True)
|
||||
|
||||
# 2. Publish metadata — useStream needs both run_id AND thread_id
|
||||
await bridge.publish(
|
||||
run_id,
|
||||
"metadata",
|
||||
{
|
||||
"run_id": run_id,
|
||||
"thread_id": thread_id,
|
||||
},
|
||||
)
|
||||
|
||||
# 3. Build the agent
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
# Inject runtime context so middlewares can access thread_id
|
||||
# (langgraph-cli does this automatically; we must do it manually)
|
||||
runtime = Runtime(context={"thread_id": thread_id}, store=store)
|
||||
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
|
||||
# prefers it over ``configurable`` for thread-level data), make
|
||||
# sure ``thread_id`` is available there too.
|
||||
if "context" in config and isinstance(config["context"], dict):
|
||||
config["context"].setdefault("thread_id", thread_id)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
|
||||
runnable_config = RunnableConfig(**config)
|
||||
agent = agent_factory(config=runnable_config)
|
||||
|
||||
# 4. Attach checkpointer and store
|
||||
if checkpointer is not None:
|
||||
agent.checkpointer = checkpointer
|
||||
if store is not None:
|
||||
agent.store = store
|
||||
|
||||
# 5. Set interrupt nodes
|
||||
if interrupt_before:
|
||||
agent.interrupt_before_nodes = interrupt_before
|
||||
if interrupt_after:
|
||||
agent.interrupt_after_nodes = interrupt_after
|
||||
|
||||
# 6. Build LangGraph stream_mode list
|
||||
# "events" is NOT a valid astream mode — skip it
|
||||
# "messages-tuple" maps to LangGraph's "messages" mode
|
||||
lg_modes: list[str] = []
|
||||
for m in requested_modes:
|
||||
if m == "messages-tuple":
|
||||
lg_modes.append("messages")
|
||||
elif m == "events":
|
||||
# Skipped — see log above
|
||||
continue
|
||||
elif m in _VALID_LG_MODES:
|
||||
lg_modes.append(m)
|
||||
if not lg_modes:
|
||||
lg_modes = ["values"]
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen: set[str] = set()
|
||||
deduped: list[str] = []
|
||||
for m in lg_modes:
|
||||
if m not in seen:
|
||||
seen.add(m)
|
||||
deduped.append(m)
|
||||
lg_modes = deduped
|
||||
|
||||
logger.info("Run %s: streaming with modes %s (requested: %s)", run_id, lg_modes, requested_modes)
|
||||
|
||||
# 7. Stream using graph.astream
|
||||
if len(lg_modes) == 1 and not stream_subgraphs:
|
||||
# Single mode, no subgraphs: astream yields raw chunks
|
||||
single_mode = lg_modes[0]
|
||||
async for chunk in agent.astream(graph_input, config=runnable_config, stream_mode=single_mode):
|
||||
if record.abort_event.is_set():
|
||||
logger.info("Run %s abort requested — stopping", run_id)
|
||||
break
|
||||
sse_event = _lg_mode_to_sse_event(single_mode)
|
||||
await bridge.publish(run_id, sse_event, serialize(chunk, mode=single_mode))
|
||||
else:
|
||||
# Multiple modes or subgraphs: astream yields tuples
|
||||
async for item in agent.astream(
|
||||
graph_input,
|
||||
config=runnable_config,
|
||||
stream_mode=lg_modes,
|
||||
subgraphs=stream_subgraphs,
|
||||
):
|
||||
if record.abort_event.is_set():
|
||||
logger.info("Run %s abort requested — stopping", run_id)
|
||||
break
|
||||
|
||||
mode, chunk = _unpack_stream_item(item, lg_modes, stream_subgraphs)
|
||||
if mode is None:
|
||||
continue
|
||||
|
||||
sse_event = _lg_mode_to_sse_event(mode)
|
||||
await bridge.publish(run_id, sse_event, serialize(chunk, mode=mode))
|
||||
|
||||
# 8. Final status
|
||||
if record.abort_event.is_set():
|
||||
action = record.abort_action
|
||||
if action == "rollback":
|
||||
await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user")
|
||||
try:
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
pre_run_checkpoint_id=pre_run_checkpoint_id,
|
||||
pre_run_snapshot=pre_run_snapshot,
|
||||
snapshot_capture_failed=snapshot_capture_failed,
|
||||
)
|
||||
logger.info("Run %s rolled back to pre-run checkpoint %s", run_id, pre_run_checkpoint_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to rollback checkpoint for run %s", run_id, exc_info=True)
|
||||
else:
|
||||
await run_manager.set_status(run_id, RunStatus.interrupted)
|
||||
else:
|
||||
await run_manager.set_status(run_id, RunStatus.success)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
action = record.abort_action
|
||||
if action == "rollback":
|
||||
await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user")
|
||||
try:
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
pre_run_checkpoint_id=pre_run_checkpoint_id,
|
||||
pre_run_snapshot=pre_run_snapshot,
|
||||
snapshot_capture_failed=snapshot_capture_failed,
|
||||
)
|
||||
logger.info("Run %s was cancelled and rolled back", run_id)
|
||||
except Exception:
|
||||
logger.warning("Run %s cancellation rollback failed", run_id, exc_info=True)
|
||||
else:
|
||||
await run_manager.set_status(run_id, RunStatus.interrupted)
|
||||
logger.info("Run %s was cancelled", run_id)
|
||||
|
||||
except Exception as exc:
|
||||
error_msg = f"{exc}"
|
||||
logger.exception("Run %s failed: %s", run_id, error_msg)
|
||||
await run_manager.set_status(run_id, RunStatus.error, error=error_msg)
|
||||
await bridge.publish(
|
||||
run_id,
|
||||
"error",
|
||||
{
|
||||
"message": error_msg,
|
||||
"name": type(exc).__name__,
|
||||
},
|
||||
)
|
||||
|
||||
finally:
|
||||
await bridge.publish_end(run_id)
|
||||
asyncio.create_task(bridge.cleanup(run_id, delay=60))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _call_checkpointer_method(checkpointer: Any, async_name: str, sync_name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Call a checkpointer method, supporting async and sync variants."""
|
||||
method = getattr(checkpointer, async_name, None) or getattr(checkpointer, sync_name, None)
|
||||
if method is None:
|
||||
raise AttributeError(f"Missing checkpointer method: {async_name}/{sync_name}")
|
||||
result = method(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
|
||||
async def _rollback_to_pre_run_checkpoint(
|
||||
*,
|
||||
checkpointer: Any,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
pre_run_checkpoint_id: str | None,
|
||||
pre_run_snapshot: dict[str, Any] | None,
|
||||
snapshot_capture_failed: bool,
|
||||
) -> None:
|
||||
"""Restore thread state to the checkpoint snapshot captured before run start."""
|
||||
if checkpointer is None:
|
||||
logger.info("Run %s rollback requested but no checkpointer is configured", run_id)
|
||||
return
|
||||
|
||||
if snapshot_capture_failed:
|
||||
logger.warning("Run %s rollback skipped: pre-run checkpoint snapshot capture failed", run_id)
|
||||
return
|
||||
|
||||
if pre_run_snapshot is None:
|
||||
await _call_checkpointer_method(checkpointer, "adelete_thread", "delete_thread", thread_id)
|
||||
logger.info("Run %s rollback reset thread %s to empty state", run_id, thread_id)
|
||||
return
|
||||
|
||||
checkpoint_to_restore = None
|
||||
metadata_to_restore: dict[str, Any] = {}
|
||||
checkpoint_ns = ""
|
||||
checkpoint = pre_run_snapshot.get("checkpoint")
|
||||
if not isinstance(checkpoint, dict):
|
||||
logger.warning("Run %s rollback skipped: invalid pre-run checkpoint snapshot", run_id)
|
||||
return
|
||||
checkpoint_to_restore = checkpoint
|
||||
if checkpoint_to_restore.get("id") is None and pre_run_checkpoint_id is not None:
|
||||
checkpoint_to_restore = {**checkpoint_to_restore, "id": pre_run_checkpoint_id}
|
||||
if checkpoint_to_restore.get("id") is None:
|
||||
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
|
||||
return
|
||||
metadata = pre_run_snapshot.get("metadata", {})
|
||||
metadata_to_restore = metadata if isinstance(metadata, dict) else {}
|
||||
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
|
||||
checkpoint_ns = raw_checkpoint_ns if isinstance(raw_checkpoint_ns, str) else ""
|
||||
|
||||
channel_versions = checkpoint_to_restore.get("channel_versions")
|
||||
new_versions = dict(channel_versions) if isinstance(channel_versions, dict) else {}
|
||||
|
||||
restore_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}}
|
||||
restored_config = await _call_checkpointer_method(
|
||||
checkpointer,
|
||||
"aput",
|
||||
"put",
|
||||
restore_config,
|
||||
checkpoint_to_restore,
|
||||
metadata_to_restore if isinstance(metadata_to_restore, dict) else {},
|
||||
new_versions,
|
||||
)
|
||||
if not isinstance(restored_config, dict):
|
||||
raise RuntimeError(f"Run {run_id} rollback restore returned invalid config: expected dict")
|
||||
restored_configurable = restored_config.get("configurable", {})
|
||||
if not isinstance(restored_configurable, dict):
|
||||
raise RuntimeError(f"Run {run_id} rollback restore returned invalid config payload")
|
||||
restored_checkpoint_id = restored_configurable.get("checkpoint_id")
|
||||
if not restored_checkpoint_id:
|
||||
raise RuntimeError(f"Run {run_id} rollback restore did not return checkpoint_id")
|
||||
|
||||
pending_writes = pre_run_snapshot.get("pending_writes", [])
|
||||
if not pending_writes:
|
||||
return
|
||||
|
||||
writes_by_task: dict[str, list[tuple[str, Any]]] = {}
|
||||
for item in pending_writes:
|
||||
if not isinstance(item, (tuple, list)) or len(item) != 3:
|
||||
raise RuntimeError(f"Run {run_id} rollback failed: pending_write is not a 3-tuple: {item!r}")
|
||||
task_id, channel, value = item
|
||||
if not isinstance(channel, str):
|
||||
raise RuntimeError(f"Run {run_id} rollback failed: pending_write has non-string channel: task_id={task_id!r}, channel={channel!r}")
|
||||
writes_by_task.setdefault(str(task_id), []).append((channel, value))
|
||||
|
||||
for task_id, writes in writes_by_task.items():
|
||||
await _call_checkpointer_method(
|
||||
checkpointer,
|
||||
"aput_writes",
|
||||
"put_writes",
|
||||
restored_config,
|
||||
writes,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
|
||||
def _lg_mode_to_sse_event(mode: str) -> str:
|
||||
"""Map LangGraph internal stream_mode name to SSE event name.
|
||||
|
||||
LangGraph's ``astream(stream_mode="messages")`` produces message
|
||||
tuples. The SSE protocol calls this ``messages-tuple`` when the
|
||||
client explicitly requests it, but the default SSE event name used
|
||||
by LangGraph Platform is simply ``"messages"``.
|
||||
"""
|
||||
# All LG modes map 1:1 to SSE event names — "messages" stays "messages"
|
||||
return mode
|
||||
|
||||
|
||||
def _unpack_stream_item(
|
||||
item: Any,
|
||||
lg_modes: list[str],
|
||||
stream_subgraphs: bool,
|
||||
) -> tuple[str | None, Any]:
|
||||
"""Unpack a multi-mode or subgraph stream item into (mode, chunk).
|
||||
|
||||
Returns ``(None, None)`` if the item cannot be parsed.
|
||||
"""
|
||||
if stream_subgraphs:
|
||||
if isinstance(item, tuple) and len(item) == 3:
|
||||
_ns, mode, chunk = item
|
||||
return str(mode), chunk
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
mode, chunk = item
|
||||
return str(mode), chunk
|
||||
return None, None
|
||||
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
mode, chunk = item
|
||||
return str(mode), chunk
|
||||
|
||||
# Fallback: single-element output from first mode
|
||||
return lg_modes[0] if lg_modes else None, item
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user