Initial commit: hardened DeerFlow factory
Vendored deer-flow upstream (bytedance/deer-flow) plus prompt-injection hardening: - New deerflow.security package: content_delimiter, html_cleaner, sanitizer (8 layers — invisible chars, control chars, symbols, NFC, PUA, tag chars, horizontal whitespace collapse with newline/tab preservation, length cap) - New deerflow.community.searx package: web_search, web_fetch, image_search backed by a private SearX instance, every external string sanitized and wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>> delimiters - All native community web providers (ddg_search, tavily, exa, firecrawl, jina_ai, infoquest, image_search) replaced with hard-fail stubs that raise NativeWebToolDisabledError at import time, so a misconfigured tool.use path fails loud rather than silently falling back to unsanitized output - Native client back-doors (jina_client.py, infoquest_client.py) stubbed too - Native-tool tests quarantined under tests/_disabled_native/ (collect_ignore_glob via local conftest.py) - Sanitizer Layer 7 fix: only collapse horizontal whitespace, preserve newlines and tabs so list/table structure survives - Hardened runtime config.yaml references only the searx-backed tools - Factory overlay (backend/) kept in sync with deer-flow tree as a reference / source See HARDENING.md for the full audit trail and verification steps.
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
from .tools import get_available_tools

__all__ = ["get_available_tools", "skill_manage_tool"]


def __getattr__(name: str):
    """Lazily resolve ``skill_manage_tool`` on first attribute access.

    Deferring the import keeps the package import cheap; any other
    attribute name raises ``AttributeError`` as usual.
    """
    if name != "skill_manage_tool":
        raise AttributeError(name)
    from .skill_manage_tool import skill_manage_tool

    return skill_manage_tool
|
||||
@@ -0,0 +1,13 @@
|
||||
from .clarification_tool import ask_clarification_tool
|
||||
from .present_file_tool import present_file_tool
|
||||
from .setup_agent_tool import setup_agent
|
||||
from .task_tool import task_tool
|
||||
from .view_image_tool import view_image_tool
|
||||
|
||||
__all__ = [
|
||||
"setup_agent",
|
||||
"present_file_tool",
|
||||
"ask_clarification_tool",
|
||||
"view_image_tool",
|
||||
"task_tool",
|
||||
]
|
||||
@@ -0,0 +1,55 @@
|
||||
from typing import Literal
|
||||
|
||||
from langchain.tools import tool
|
||||
|
||||
|
||||
@tool("ask_clarification", parse_docstring=True, return_direct=True)
def ask_clarification_tool(
    question: str,
    clarification_type: Literal[
        "missing_info",
        "ambiguous_requirement",
        "approach_choice",
        "risk_confirmation",
        "suggestion",
    ],
    context: str | None = None,
    options: list[str] | None = None,
) -> str:
    """Ask the user for clarification when you need more information to proceed.

    Use this tool when you encounter situations where you cannot proceed without user input:

    - **Missing information**: Required details not provided (e.g., file paths, URLs, specific requirements)
    - **Ambiguous requirements**: Multiple valid interpretations exist
    - **Approach choices**: Several valid approaches exist and you need user preference
    - **Risky operations**: Destructive actions that need explicit confirmation (e.g., deleting files, modifying production)
    - **Suggestions**: You have a recommendation but want user approval before proceeding

    The execution will be interrupted and the question will be presented to the user.
    Wait for the user's response before continuing.

    When to use ask_clarification:
    - You need information that wasn't provided in the user's request
    - The requirement can be interpreted in multiple ways
    - Multiple valid implementation approaches exist
    - You're about to perform a potentially dangerous operation
    - You have a recommendation but need user approval

    Best practices:
    - Ask ONE clarification at a time for clarity
    - Be specific and clear in your question
    - Don't make assumptions when clarification is needed
    - For risky operations, ALWAYS ask for confirmation
    - After calling this tool, execution will be interrupted automatically

    Args:
        question: The clarification question to ask the user. Be specific and clear.
        clarification_type: The type of clarification needed (missing_info, ambiguous_requirement, approach_choice, risk_confirmation, suggestion).
        context: Optional context explaining why clarification is needed. Helps the user understand the situation.
        options: Optional list of choices (for approach_choice or suggestion types). Present clear options for the user to choose from.
    """
    # NOTE: with parse_docstring=True the docstring above is parsed into the
    # tool schema seen by the LLM, so its text is part of the runtime contract.
    # return_direct=True means this tool's return value ends the agent turn.
    # This is a placeholder implementation
    # The actual logic is handled by ClarificationMiddleware which intercepts this tool call
    # and interrupts execution to present the question to the user
    return "Clarification request processed by middleware"
|
||||
@@ -0,0 +1,256 @@
|
||||
"""Built-in tool for invoking external ACP-compatible agents."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Annotated, Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import BaseTool, InjectedToolArg, StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _InvokeACPAgentInput(BaseModel):
    """Arguments schema for the ``invoke_acp_agent`` tool.

    The ``Field`` descriptions are surfaced to the LLM through the tool
    schema, so their wording is part of the runtime contract.
    """

    agent: str = Field(description="Name of the ACP agent to invoke")
    prompt: str = Field(description="The concise task prompt to send to the agent")
|
||||
|
||||
|
||||
def _get_work_dir(thread_id: str | None) -> str:
    """Return the ACP workspace directory for ``thread_id``, creating it if needed.

    When a thread id is available, the thread gets an isolated workspace via
    ``paths.acp_workspace_dir(thread_id)`` so concurrent sessions cannot read
    or overwrite each other's ACP agent outputs. When no thread id is given
    (e.g. embedded / direct invocation), or the thread id is rejected as
    invalid, the legacy shared ``{base_dir}/acp-workspace/`` directory is
    used instead.

    Returns:
        The chosen directory as a string path.
    """
    from deerflow.config.paths import get_paths

    paths = get_paths()
    legacy_dir = paths.base_dir / "acp-workspace"

    if not thread_id:
        work_dir = legacy_dir
    else:
        try:
            work_dir = paths.acp_workspace_dir(thread_id)
        except ValueError:
            # Invalid thread ids fall back to the shared legacy location.
            logger.warning("Invalid thread_id %r for ACP workspace, falling back to global", thread_id)
            work_dir = legacy_dir

    work_dir.mkdir(parents=True, exist_ok=True)
    logger.info("ACP agent work_dir: %s", work_dir)
    return str(work_dir)
|
||||
|
||||
|
||||
def _build_mcp_servers() -> dict[str, dict[str, Any]]:
    """Build ACP ``mcpServers`` config from DeerFlow's enabled MCP servers.

    Returns a name -> config mapping produced by ``build_servers_config``
    (contrast with ``_build_acp_mcp_servers``, which emits the list-shaped
    ACP wire format).
    """
    # Lazy imports, matching the pattern used elsewhere in this module —
    # presumably to avoid import-time cost/cycles; confirm before hoisting.
    from deerflow.config.extensions_config import ExtensionsConfig
    from deerflow.mcp.client import build_servers_config

    return build_servers_config(ExtensionsConfig.from_file())
|
||||
|
||||
|
||||
def _build_acp_mcp_servers() -> list[dict[str, Any]]:
    """Convert DeerFlow's enabled MCP servers into the ACP ``new_session`` payload.

    DeerFlow's MCP helper keeps a name -> config mapping for the LangChain
    adapter, whereas the ACP client expects a list of server objects; this
    function performs that translation.

    Raises:
        ValueError: If an enabled server is missing a required field or uses
            an unsupported transport type.
    """
    from deerflow.config.extensions_config import ExtensionsConfig

    def _as_name_value_list(mapping):
        # ACP wants env/headers as [{"name": ..., "value": ...}] objects.
        return [{"name": key, "value": value} for key, value in mapping.items()]

    enabled = ExtensionsConfig.from_file().get_enabled_mcp_servers()
    servers: list[dict[str, Any]] = []

    for name, server_config in enabled.items():
        transport_type = server_config.type or "stdio"
        entry: dict[str, Any] = {"name": name, "type": transport_type}

        if transport_type == "stdio":
            if not server_config.command:
                raise ValueError(f"MCP server '{name}' with stdio transport requires 'command' field")
            entry["command"] = server_config.command
            entry["args"] = server_config.args
            entry["env"] = _as_name_value_list(server_config.env)
        elif transport_type in ("http", "sse"):
            if not server_config.url:
                raise ValueError(f"MCP server '{name}' with {transport_type} transport requires 'url' field")
            entry["url"] = server_config.url
            entry["headers"] = _as_name_value_list(server_config.headers)
        else:
            raise ValueError(f"MCP server '{name}' has unsupported transport type: {transport_type}")

        servers.append(entry)

    return servers
|
||||
|
||||
|
||||
def _build_permission_response(options: list[Any], *, auto_approve: bool) -> Any:
    """Build an ACP permission response.

    With ``auto_approve`` set, scans the options for an ``allow_once`` choice
    first, then ``allow_always``, and selects the first one that carries a
    usable option id. Otherwise (the default) the request is cancelled —
    permission handling is left to the ACP agent's own policy, or the agent
    must be configured not to request permissions at all.
    """
    from acp import RequestPermissionResponse
    from acp.schema import AllowedOutcome, DeniedOutcome

    if auto_approve:
        for wanted_kind in ("allow_once", "allow_always"):
            for candidate in options:
                if getattr(candidate, "kind", None) != wanted_kind:
                    continue

                # Accept either snake_case or camelCase id attributes.
                chosen_id = getattr(candidate, "option_id", None)
                if chosen_id is None:
                    chosen_id = getattr(candidate, "optionId", None)
                if chosen_id is None:
                    continue

                return RequestPermissionResponse(
                    outcome=AllowedOutcome(outcome="selected", optionId=chosen_id),
                )

    return RequestPermissionResponse(outcome=DeniedOutcome(outcome="cancelled"))
|
||||
|
||||
|
||||
def _format_invocation_error(agent: str, cmd: str, exc: Exception) -> str:
|
||||
"""Return a user-facing ACP invocation error with actionable remediation."""
|
||||
if not isinstance(exc, FileNotFoundError):
|
||||
return f"Error invoking ACP agent '{agent}': {exc}"
|
||||
|
||||
message = f"Error invoking ACP agent '{agent}': Command '{cmd}' was not found on PATH."
|
||||
if cmd == "codex-acp" and shutil.which("codex"):
|
||||
return f"{message} The installed `codex` CLI does not speak ACP directly. Install a Codex ACP adapter (for example `npx @zed-industries/codex-acp`) or update `acp_agents.codex.command` and `args` in config.yaml."
|
||||
|
||||
return f"{message} Install the agent binary or update `acp_agents.{agent}.command` in config.yaml."
|
||||
|
||||
|
||||
def build_invoke_acp_agent_tool(agents: dict) -> BaseTool:
    """Create the ``invoke_acp_agent`` tool with a description generated from configured agents.

    The tool description includes the list of available agents so that the LLM
    knows which agents it can invoke without requiring hardcoded names.

    Args:
        agents: Mapping of agent name -> ``ACPAgentConfig``.

    Returns:
        A LangChain ``BaseTool`` ready to be included in the tool list.
    """
    agent_lines = "\n".join(f"- {name}: {cfg.description}" for name, cfg in agents.items())
    description = (
        "Invoke an external ACP-compatible agent and return its final response.\n\n"
        "Available agents:\n"
        f"{agent_lines}\n\n"
        "IMPORTANT: ACP agents operate in their own independent workspace. "
        "Do NOT include /mnt/user-data paths in the prompt. "
        "Give the agent a self-contained task description — it will produce results in its own workspace. "
        "After the agent completes, its output files are accessible at /mnt/acp-workspace/ (read-only)."
    )

    # Capture agents in closure so the function can reference it
    _agents = dict(agents)

    # NOTE(review): the default of None does not match the non-Optional
    # RunnableConfig annotation — relies on the framework always injecting
    # the config; confirm before tightening the annotation.
    async def _invoke(agent: str, prompt: str, config: Annotated[RunnableConfig, InjectedToolArg] = None) -> str:
        logger.info("Invoking ACP agent %s (prompt length: %d)", agent, len(prompt))
        # Debug log truncates the prompt preview to 200 characters.
        logger.debug("Invoking ACP agent %s with prompt: %.200s%s", agent, prompt, "..." if len(prompt) > 200 else "")
        if agent not in _agents:
            available = ", ".join(_agents.keys())
            return f"Error: Unknown agent '{agent}'. Available: {available}"

        agent_config = _agents[agent]
        # thread_id drives per-thread workspace isolation in _get_work_dir.
        thread_id: str | None = ((config or {}).get("configurable") or {}).get("thread_id")

        # The acp package is an optional dependency; fail with guidance
        # rather than an ImportError traceback.
        try:
            from acp import PROTOCOL_VERSION, Client, text_block
            from acp.schema import ClientCapabilities, Implementation
        except ImportError:
            return "Error: agent-client-protocol package is not installed. Run `uv sync` to install project dependencies."

        class _CollectingClient(Client):
            """Minimal ACP Client that collects streamed text from session updates."""

            def __init__(self) -> None:
                self._chunks: list[str] = []  # streamed text fragments, in arrival order

            @property
            def collected_text(self) -> str:
                return "".join(self._chunks)

            async def session_update(self, session_id: str, update, **kwargs) -> None:  # type: ignore[override]
                # Best-effort collection: non-text updates and schema
                # surprises are silently ignored.
                try:
                    from acp.schema import TextContentBlock

                    if hasattr(update, "content") and isinstance(update.content, TextContentBlock):
                        self._chunks.append(update.content.text)
                except Exception:
                    pass

            async def request_permission(self, options, session_id: str, tool_call, **kwargs):  # type: ignore[override]
                # Delegate the allow/deny decision to the module-level policy
                # helper; log the outcome either way.
                response = _build_permission_response(options, auto_approve=agent_config.auto_approve_permissions)
                outcome = response.outcome.outcome
                if outcome == "selected":
                    logger.info("ACP permission auto-approved for tool call %s in session %s", tool_call.tool_call_id, session_id)
                else:
                    logger.warning("ACP permission denied for tool call %s in session %s (set auto_approve_permissions: true in config.yaml to enable)", tool_call.tool_call_id, session_id)
                return response

        client = _CollectingClient()
        cmd = agent_config.command
        args = agent_config.args or []
        physical_cwd = _get_work_dir(thread_id)
        # Invalid MCP config is non-fatal: the agent still runs, just
        # without MCP servers attached.
        try:
            mcp_servers = _build_acp_mcp_servers()
        except ValueError as exc:
            logger.warning(
                "Invalid MCP server configuration for ACP agent '%s'; continuing without MCP servers: %s",
                agent,
                exc,
            )
            mcp_servers = []
        # Values starting with "$" are resolved from the host environment
        # (missing variables become empty strings).
        agent_env: dict[str, str] | None = None
        if agent_config.env:
            agent_env = {k: (os.environ.get(v[1:], "") if v.startswith("$") else v) for k, v in agent_config.env.items()}

        try:
            from acp import spawn_agent_process

            async with spawn_agent_process(client, cmd, *args, env=agent_env, cwd=physical_cwd) as (conn, proc):
                logger.info("Spawning ACP agent '%s' with command '%s' and args %s in cwd %s", agent, cmd, args, physical_cwd)
                await conn.initialize(
                    protocol_version=PROTOCOL_VERSION,
                    client_capabilities=ClientCapabilities(),
                    client_info=Implementation(name="deerflow", title="DeerFlow", version="0.1.0"),
                )
                session_kwargs: dict[str, Any] = {"cwd": physical_cwd, "mcp_servers": mcp_servers}
                if agent_config.model:
                    session_kwargs["model"] = agent_config.model
                session = await conn.new_session(**session_kwargs)
                await conn.prompt(
                    session_id=session.session_id,
                    prompt=[text_block(prompt)],
                )
                result = client.collected_text
                # Logs a 1000-character preview of the agent output at INFO.
                logger.info("ACP agent '%s' returned %s", agent, result[:1000])
                logger.info("ACP agent '%s' returned %d characters", agent, len(result))
                return result or "(no response)"
        except Exception as e:
            # Spawn/protocol failures are reported as tool output rather
            # than raised, so the calling agent can react to them.
            logger.error("ACP agent '%s' invocation failed: %s", agent, e)
            return _format_invocation_error(agent, cmd, e)

    return StructuredTool.from_function(
        name="invoke_acp_agent",
        description=description,
        coroutine=_invoke,
        args_schema=_InvokeACPAgentInput,
    )
|
||||
@@ -0,0 +1,100 @@
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||
|
||||
OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs"
|
||||
|
||||
|
||||
def _normalize_presented_filepath(
    runtime: ToolRuntime[ContextT, ThreadState],
    filepath: str,
) -> str:
    """Normalize a presented file path to the `/mnt/user-data/outputs/*` contract.

    Both a virtual sandbox path (e.g. `/mnt/user-data/outputs/report.md`) and
    a host-side thread outputs path (e.g.
    `/app/backend/.deer-flow/threads/<thread>/user-data/outputs/report.md`)
    are accepted; either form is resolved and re-expressed as a virtual path.

    Returns:
        The normalized virtual path.

    Raises:
        ValueError: If runtime metadata is missing or the path is outside the
            current thread's outputs directory.
    """
    if runtime.state is None:
        raise ValueError("Thread runtime state is not available")

    thread_id = runtime.context.get("thread_id") if runtime.context else None
    if not thread_id:
        raise ValueError("Thread ID is not available in runtime context")

    outputs_path = (runtime.state.get("thread_data") or {}).get("outputs_path")
    if not outputs_path:
        raise ValueError("Thread outputs path is not available in runtime state")

    outputs_dir = Path(outputs_path).resolve()

    # Classify the input: virtual sandbox path vs. host filesystem path.
    candidate = filepath.lstrip("/")
    sandbox_root = VIRTUAL_PATH_PREFIX.lstrip("/")
    is_virtual = candidate == sandbox_root or candidate.startswith(sandbox_root + "/")

    if is_virtual:
        actual_path = get_paths().resolve_virtual_path(thread_id, filepath)
    else:
        actual_path = Path(filepath).expanduser().resolve()

    # Containment check: the resolved path must live under this thread's
    # outputs directory, otherwise presenting it is refused.
    try:
        relative_path = actual_path.relative_to(outputs_dir)
    except ValueError as exc:
        raise ValueError(f"Only files in {OUTPUTS_VIRTUAL_PREFIX} can be presented: {filepath}") from exc

    return f"{OUTPUTS_VIRTUAL_PREFIX}/{relative_path.as_posix()}"
|
||||
|
||||
|
||||
@tool("present_files", parse_docstring=True)
def present_file_tool(
    runtime: ToolRuntime[ContextT, ThreadState],
    filepaths: list[str],
    tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
    """Make files visible to the user for viewing and rendering in the client interface.

    When to use the present_files tool:

    - Making any file available for the user to view, download, or interact with
    - Presenting multiple related files at once
    - After creating files that should be presented to the user

    When NOT to use the present_files tool:
    - When you only need to read file contents for your own processing
    - For temporary or intermediate files not meant for user viewing

    Notes:
    - You should call this tool after creating files and moving them to the `/mnt/user-data/outputs` directory.
    - This tool can be safely called in parallel with other tools. State updates are handled by a reducer to prevent conflicts.

    Args:
        filepaths: List of absolute file paths to present to the user. **Only** files in `/mnt/user-data/outputs` can be presented.
    """
    # Normalize each path; the first invalid one aborts the whole call
    # with an error ToolMessage.
    normalized_paths: list[str] = []
    for filepath in filepaths:
        try:
            normalized_paths.append(_normalize_presented_filepath(runtime, filepath))
        except ValueError as exc:
            return Command(
                update={"messages": [ToolMessage(f"Error: {exc}", tool_call_id=tool_call_id)]},
            )

    # The merge_artifacts reducer will handle merging and deduplication
    return Command(
        update={
            "artifacts": normalized_paths,
            "messages": [ToolMessage("Successfully presented files", tool_call_id=tool_call_id)],
        },
    )
|
||||
@@ -0,0 +1,62 @@
|
||||
import logging
|
||||
|
||||
import yaml
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.prebuilt import ToolRuntime
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.config.paths import get_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool
def setup_agent(
    soul: str,
    description: str,
    runtime: ToolRuntime,
) -> Command:
    """Setup the custom DeerFlow agent.

    Writes `SOUL.md` (and, for named agents, `config.yaml`) into the agent
    directory. On failure, any directory *created by this call* is removed
    and an error ToolMessage is returned.

    Args:
        soul: Full SOUL.md content defining the agent's personality and behavior.
        description: One-line description of what the agent does.
    """

    agent_name: str | None = runtime.context.get("agent_name") if runtime.context else None
    # Initialized before the try so the except handler can never hit an
    # unbound name when get_paths()/agent_dir() itself raises.
    agent_dir = None
    created_dir = False

    try:
        paths = get_paths()
        agent_dir = paths.agent_dir(agent_name) if agent_name else paths.base_dir
        # Remember whether this call created the directory: failure cleanup
        # must never delete a pre-existing agent directory.
        created_dir = bool(agent_name) and not agent_dir.exists()
        agent_dir.mkdir(parents=True, exist_ok=True)

        if agent_name:
            # If agent_name is provided, we are creating a custom agent in the agents/ directory
            config_data: dict = {"name": agent_name}
            if description:
                config_data["description"] = description

            config_file = agent_dir / "config.yaml"
            with open(config_file, "w", encoding="utf-8") as f:
                yaml.dump(config_data, f, default_flow_style=False, allow_unicode=True)

        soul_file = agent_dir / "SOUL.md"
        soul_file.write_text(soul, encoding="utf-8")

        logger.info(f"[agent_creator] Created agent '{agent_name}' at {agent_dir}")
        return Command(
            update={
                "created_agent_name": agent_name,
                "messages": [ToolMessage(content=f"Agent '{agent_name}' created successfully!", tool_call_id=runtime.tool_call_id)],
            }
        )

    except Exception as e:
        import shutil

        if created_dir and agent_dir is not None and agent_dir.exists():
            # Cleanup the custom agent directory only if this call created it
            # but an error occurred during setup
            shutil.rmtree(agent_dir)
        logger.error(f"[agent_creator] Failed to create agent '{agent_name}': {e}", exc_info=True)
        return Command(update={"messages": [ToolMessage(content=f"Error: {e}", tool_call_id=runtime.tool_call_id)]})
|
||||
@@ -0,0 +1,248 @@
|
||||
"""Task tool for delegating work to subagents."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import replace
|
||||
from typing import Annotated
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langgraph.config import get_stream_writer
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import get_skills_prompt_section
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
|
||||
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
|
||||
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool("task", parse_docstring=True)
async def task_tool(
    runtime: ToolRuntime[ContextT, ThreadState],
    description: str,
    prompt: str,
    subagent_type: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
    max_turns: int | None = None,
) -> str:
    """Delegate a task to a specialized subagent that runs in its own context.

    Subagents help you:
    - Preserve context by keeping exploration and implementation separate
    - Handle complex multi-step tasks autonomously
    - Execute commands or operations in isolated contexts

    Available subagent types depend on the active sandbox configuration:
    - **general-purpose**: A capable agent for complex, multi-step tasks that require
      both exploration and action. Use when the task requires complex reasoning,
      multiple dependent steps, or would benefit from isolated context.
    - **bash**: Command execution specialist for running bash commands. This is only
      available when host bash is explicitly allowed or when using an isolated shell
      sandbox such as `AioSandboxProvider`.

    When to use this tool:
    - Complex tasks requiring multiple steps or tools
    - Tasks that produce verbose output
    - When you want to isolate context from the main conversation
    - Parallel research or exploration tasks

    When NOT to use this tool:
    - Simple, single-step operations (use tools directly)
    - Tasks requiring user interaction or clarification

    Args:
        description: A short (3-5 word) description of the task for logging/display. ALWAYS PROVIDE THIS PARAMETER FIRST.
        prompt: The task description for the subagent. Be specific and clear about what needs to be done. ALWAYS PROVIDE THIS PARAMETER SECOND.
        subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
        max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max.
    """
    # NOTE: the docstring above is parsed into the tool schema
    # (parse_docstring=True) and is therefore part of the runtime contract.
    available_subagent_names = get_available_subagent_names()

    # Get subagent configuration
    config = get_subagent_config(subagent_type)
    if config is None:
        available = ", ".join(available_subagent_names)
        return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}"
    # Hard gate: bash subagent is refused unless host bash is allowed.
    if subagent_type == "bash" and not is_host_bash_allowed():
        return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"

    # Build config overrides
    overrides: dict = {}

    skills_section = get_skills_prompt_section()
    if skills_section:
        overrides["system_prompt"] = config.system_prompt + "\n\n" + skills_section

    if max_turns is not None:
        overrides["max_turns"] = max_turns

    if overrides:
        # config is a dataclass; replace() produces an overridden copy.
        config = replace(config, **overrides)

    # Extract parent context from runtime
    sandbox_state = None
    thread_data = None
    thread_id = None
    parent_model = None
    trace_id = None

    if runtime is not None:
        sandbox_state = runtime.state.get("sandbox")
        thread_data = runtime.state.get("thread_data")
        thread_id = runtime.context.get("thread_id") if runtime.context else None
        if thread_id is None:
            thread_id = runtime.config.get("configurable", {}).get("thread_id")

        # Try to get parent model from configurable
        metadata = runtime.config.get("metadata", {})
        parent_model = metadata.get("model_name")

        # Get or generate trace_id for distributed tracing
        trace_id = metadata.get("trace_id") or str(uuid.uuid4())[:8]

    # Get available tools (excluding task tool to prevent nesting)
    # Lazy import to avoid circular dependency
    from deerflow.tools import get_available_tools

    # Subagents should not have subagent tools enabled (prevent recursive nesting)
    tools = get_available_tools(model_name=parent_model, subagent_enabled=False)

    # Create executor
    executor = SubagentExecutor(
        config=config,
        tools=tools,
        parent_model=parent_model,
        sandbox_state=sandbox_state,
        thread_data=thread_data,
        thread_id=thread_id,
        trace_id=trace_id,
    )

    # Start background execution (always async to prevent blocking)
    # Use tool_call_id as task_id for better traceability
    task_id = executor.execute_async(prompt, task_id=tool_call_id)

    # Poll for task completion in backend (removes need for LLM to poll)
    poll_count = 0
    last_status = None
    last_message_count = 0  # Track how many AI messages we've already sent
    # Polling timeout: execution timeout + 60s buffer, checked every 5s
    max_poll_count = (config.timeout_seconds + 60) // 5

    logger.info(f"[trace={trace_id}] Started background task {task_id} (subagent={subagent_type}, timeout={config.timeout_seconds}s, polling_limit={max_poll_count} polls)")

    writer = get_stream_writer()
    # Send Task Started message
    writer({"type": "task_started", "task_id": task_id, "description": description})

    try:
        while True:
            result = get_background_task_result(task_id)

            if result is None:
                logger.error(f"[trace={trace_id}] Task {task_id} not found in background tasks")
                writer({"type": "task_failed", "task_id": task_id, "error": "Task disappeared from background tasks"})
                cleanup_background_task(task_id)
                return f"Error: Task {task_id} disappeared from background tasks"

            # Log status changes for debugging
            if result.status != last_status:
                logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}")
                last_status = result.status

            # Check for new AI messages and send task_running events
            current_message_count = len(result.ai_messages)
            if current_message_count > last_message_count:
                # Send task_running event for each new message
                for i in range(last_message_count, current_message_count):
                    message = result.ai_messages[i]
                    writer(
                        {
                            "type": "task_running",
                            "task_id": task_id,
                            "message": message,
                            "message_index": i + 1,  # 1-based index for display
                            "total_messages": current_message_count,
                        }
                    )
                    logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}")
                last_message_count = current_message_count

            # Check if task completed, failed, or timed out
            if result.status == SubagentStatus.COMPLETED:
                writer({"type": "task_completed", "task_id": task_id, "result": result.result})
                logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
                cleanup_background_task(task_id)
                return f"Task Succeeded. Result: {result.result}"
            elif result.status == SubagentStatus.FAILED:
                writer({"type": "task_failed", "task_id": task_id, "error": result.error})
                logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
                cleanup_background_task(task_id)
                return f"Task failed. Error: {result.error}"
            elif result.status == SubagentStatus.CANCELLED:
                writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
                logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
                cleanup_background_task(task_id)
                return "Task cancelled by user."
            elif result.status == SubagentStatus.TIMED_OUT:
                writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
                logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
                cleanup_background_task(task_id)
                return f"Task timed out. Error: {result.error}"

            # Still running, wait before next poll
            await asyncio.sleep(5)
            poll_count += 1

            # Polling timeout as a safety net (in case thread pool timeout doesn't work)
            # Set to execution timeout + 60s buffer, in 5s poll intervals
            # This catches edge cases where the background task gets stuck
            # Note: We don't call cleanup_background_task here because the task may
            # still be running in the background. The cleanup will happen when the
            # executor completes and sets a terminal status.
            if poll_count > max_poll_count:
                timeout_minutes = config.timeout_seconds // 60
                logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
                writer({"type": "task_timed_out", "task_id": task_id})
                return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
    except asyncio.CancelledError:
        # Signal the background subagent thread to stop cooperatively.
        # Without this, the thread (running in ThreadPoolExecutor with its
        # own event loop via asyncio.run) would continue executing even
        # after the parent task is cancelled.
        request_cancel_background_task(task_id)

        async def cleanup_when_done() -> None:
            # Deferred cleanup: wait for the background task to reach a
            # terminal state, then remove its registry entry.
            max_cleanup_polls = max_poll_count
            cleanup_poll_count = 0

            while True:
                result = get_background_task_result(task_id)
                if result is None:
                    return

                if result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.CANCELLED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None:
                    cleanup_background_task(task_id)
                    return

                if cleanup_poll_count > max_cleanup_polls:
                    logger.warning(f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls")
                    return

                await asyncio.sleep(5)
                cleanup_poll_count += 1

        def log_cleanup_failure(cleanup_task: asyncio.Task[None]) -> None:
            # Done-callback: surface exceptions from the detached cleanup task.
            if cleanup_task.cancelled():
                return

            exc = cleanup_task.exception()
            if exc is not None:
                logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}")

        logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}")
        asyncio.create_task(cleanup_when_done()).add_done_callback(log_cleanup_failure)
        raise
||||
@@ -0,0 +1,193 @@
|
||||
"""Tool search — deferred tool discovery at runtime.
|
||||
|
||||
Contains:
|
||||
- DeferredToolRegistry: stores deferred tools and handles regex search
|
||||
- tool_search: the LangChain tool the agent calls to discover deferred tools
|
||||
|
||||
The agent sees deferred tool names in <available-deferred-tools> but cannot
|
||||
call them until it fetches their full schema via the tool_search tool.
|
||||
Source-agnostic: no mention of MCP or tool origin.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_RESULTS = 5 # Max tools returned per search
|
||||
|
||||
|
||||
# ── Registry ──
|
||||
|
||||
|
||||
@dataclass
class DeferredToolEntry:
    """Lightweight metadata for a deferred tool (no full schema in context)."""

    # Tool name as surfaced to the model in <available-deferred-tools>.
    name: str
    # Free-text description; searched together with the name during lookup.
    description: str
    tool: BaseTool  # Full tool object, returned only on search match
|
||||
|
||||
|
||||
class DeferredToolRegistry:
    """Registry of deferred tools, searchable by regex pattern."""

    def __init__(self):
        # Registration order is preserved; stable sorts below keep it for ties.
        self._entries: list[DeferredToolEntry] = []

    def register(self, tool: BaseTool) -> None:
        """Record *tool* as deferred; only name/description enter the context."""
        self._entries.append(
            DeferredToolEntry(
                name=tool.name,
                description=tool.description or "",
                tool=tool,
            )
        )

    def promote(self, names: set[str]) -> None:
        """Remove tools from the deferred registry so they pass through the filter.

        Called after tool_search returns a tool's schema — the LLM now knows
        the full definition, so the DeferredToolFilterMiddleware should stop
        stripping it from bind_tools on subsequent calls.
        """
        if not names:
            return
        before = len(self._entries)
        self._entries = [e for e in self._entries if e.name not in names]
        promoted = before - len(self._entries)
        if promoted:
            logger.debug(f"Promoted {promoted} tool(s) from deferred to active: {names}")

    def search(self, query: str) -> list[BaseTool]:
        """Search deferred tools by regex pattern against name + description.

        Supports three query forms (aligned with Claude Code):
        - "select:name1,name2" — exact name match
        - "+keyword rest" — name must contain keyword, rank by rest
        - "keyword query" — regex match against name + description

        Returns:
            List of matched BaseTool objects (up to MAX_RESULTS).
        """
        if query.startswith("select:"):
            names = {n.strip() for n in query[len("select:"):].split(",")}
            return [e.tool for e in self._entries if e.name in names][:MAX_RESULTS]

        if query.startswith("+"):
            parts = query[1:].split(None, 1)
            # BUGFIX: a bare "+" (or "+" followed only by whitespace) used to
            # raise IndexError on parts[0]. With no required keyword present,
            # fall through to the general regex search over the remainder.
            if parts:
                required = parts[0].lower()
                candidates = [e for e in self._entries if required in e.name.lower()]
                if len(parts) > 1:
                    candidates.sort(
                        key=lambda e: _regex_score(parts[1], e),
                        reverse=True,
                    )
                return [e.tool for e in candidates][:MAX_RESULTS]
            query = query[1:]

        # General regex search; an invalid pattern degrades to a literal match.
        try:
            regex = re.compile(query, re.IGNORECASE)
        except re.error:
            regex = re.compile(re.escape(query), re.IGNORECASE)

        scored = []
        for entry in self._entries:
            searchable = f"{entry.name} {entry.description}"
            if regex.search(searchable):
                # Name hits rank above description-only hits; sort is stable,
                # so registration order breaks ties.
                score = 2 if regex.search(entry.name) else 1
                scored.append((score, entry))

        scored.sort(key=lambda x: x[0], reverse=True)
        return [entry.tool for _, entry in scored][:MAX_RESULTS]

    @property
    def entries(self) -> list[DeferredToolEntry]:
        """Snapshot copy of the current deferred entries."""
        return list(self._entries)

    def __len__(self) -> int:
        return len(self._entries)
|
||||
|
||||
|
||||
def _regex_score(pattern: str, entry: DeferredToolEntry) -> int:
    """Count how often *pattern* matches in the entry's name + description.

    An invalid regex degrades to a literal (escaped) match instead of raising.
    """
    try:
        compiled = re.compile(pattern, re.IGNORECASE)
    except re.error:
        compiled = re.compile(re.escape(pattern), re.IGNORECASE)
    haystack = f"{entry.name} {entry.description}"
    return len(compiled.findall(haystack))
|
||||
|
||||
|
||||
# ── Per-request registry (ContextVar) ──
#
# Using a ContextVar instead of a module-level global prevents concurrent
# requests from clobbering each other's registry. In asyncio-based LangGraph
# each graph run executes in its own async context, so each request gets an
# independent registry value. For synchronous tools run via
# loop.run_in_executor, Python copies the current context to the worker thread,
# so the ContextVar value is correctly inherited there too.

# Default of None means "no deferred tools configured for this context".
_registry_var: contextvars.ContextVar[DeferredToolRegistry | None] = contextvars.ContextVar("deferred_tool_registry", default=None)
|
||||
|
||||
|
||||
def get_deferred_registry() -> DeferredToolRegistry | None:
    """Return the registry for the current async context, or None if unset."""
    return _registry_var.get()
|
||||
|
||||
|
||||
def set_deferred_registry(registry: DeferredToolRegistry) -> None:
    """Install *registry* as the deferred-tool registry for this context."""
    _registry_var.set(registry)
|
||||
|
||||
|
||||
def reset_deferred_registry() -> None:
    """Reset the deferred registry for the current async context."""
    # Setting None (the ContextVar default) rather than restoring a Token
    # keeps this callable from any frame, not just the one that set it.
    _registry_var.set(None)
|
||||
|
||||
|
||||
# ── Tool ──
|
||||
|
||||
|
||||
@tool
def tool_search(query: str) -> str:
    """Fetches full schema definitions for deferred tools so they can be called.

    Deferred tools appear by name in <available-deferred-tools> in the system
    prompt. Until fetched, only the name is known — there is no parameter
    schema, so the tool cannot be invoked. This tool takes a query, matches
    it against the deferred tool list, and returns the matched tools' complete
    definitions. Once a tool's schema appears in that result, it is callable.

    Query forms:
    - "select:Read,Edit,Grep" — fetch these exact tools by name
    - "notebook jupyter" — keyword search, up to max_results best matches
    - "+slack send" — require "slack" in the name, rank by remaining terms

    Args:
        query: Query to find deferred tools. Use "select:<tool_name>" for
            direct selection, or keywords to search.

    Returns:
        Matched tool definitions as JSON array.
    """
    registry = get_deferred_registry()
    if not registry:
        return "No deferred tools available."

    hits = registry.search(query)
    if not hits:
        return f"No tools found matching: {query}"

    selected = hits[:MAX_RESULTS]

    # LangChain's converter emits the OpenAI function-call schema, which all
    # supported models understand — keeps the result model-agnostic.
    definitions = [convert_to_openai_function(t) for t in selected]

    # Promote the returned tools so DeferredToolFilterMiddleware stops
    # stripping them from bind_tools — the LLM now holds their full schemas.
    registry.promote({t.name for t in selected})

    return json.dumps(definitions, indent=2, ensure_ascii=False)
|
||||
@@ -0,0 +1,95 @@
|
||||
import base64
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
|
||||
|
||||
@tool("view_image", parse_docstring=True)
def view_image_tool(
    runtime: ToolRuntime[ContextT, ThreadState],
    image_path: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
    """Read an image file.

    Use this tool to read an image file and make it available for display.

    When to use the view_image tool:
    - When you need to view an image file.

    When NOT to use the view_image tool:
    - For non-image files (use present_files instead)
    - For multiple files at once (use present_files instead)

    Args:
        image_path: Absolute path to the image file. Common formats supported: jpg, jpeg, png, webp.
    """
    from deerflow.sandbox.tools import get_thread_data, replace_virtual_path

    def fail(message: str) -> Command:
        # Every validation failure surfaces as a ToolMessage tied to this call.
        return Command(
            update={"messages": [ToolMessage(message, tool_call_id=tool_call_id)]},
        )

    # /mnt/user-data/* virtual paths are mapped to thread-specific directories.
    thread_data = get_thread_data(runtime)
    actual_path = replace_virtual_path(image_path, thread_data)
    path = Path(actual_path)

    # Guard clauses: absolute path, exists, and is a regular file.
    if not path.is_absolute():
        return fail(f"Error: Path must be absolute, got: {image_path}")
    if not path.exists():
        return fail(f"Error: Image file not found: {image_path}")
    if not path.is_file():
        return fail(f"Error: Path is not a file: {image_path}")

    valid_extensions = {".jpg", ".jpeg", ".png", ".webp"}
    suffix = path.suffix.lower()
    if suffix not in valid_extensions:
        return fail(f"Error: Unsupported image format: {path.suffix}. Supported formats: {', '.join(valid_extensions)}")

    # Prefer the platform MIME database, with a static fallback covering the
    # accepted formats.
    mime_type, _ = mimetypes.guess_type(actual_path)
    if mime_type is None:
        extension_to_mime = {
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
            ".png": "image/png",
            ".webp": "image/webp",
        }
        mime_type = extension_to_mime.get(suffix, "application/octet-stream")

    # Read the raw bytes and base64-encode them for transport in state.
    try:
        image_base64 = base64.b64encode(path.read_bytes()).decode("utf-8")
    except Exception as e:
        return fail(f"Error reading image file: {str(e)}")

    # The merge_viewed_images reducer merges this entry with existing images.
    new_viewed_images = {image_path: {"base64": image_base64, "mime_type": mime_type}}
    return Command(
        update={"viewed_images": new_viewed_images, "messages": [ToolMessage("Successfully read image", tool_call_id=tool_call_id)]},
    )
|
||||
@@ -0,0 +1,247 @@
|
||||
"""Tool for creating and evolving custom skills."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import shutil
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from langchain.tools import ToolRuntime, tool
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.mcp.tools import _make_sync_tool_wrapper
|
||||
from deerflow.skills.manager import (
|
||||
append_history,
|
||||
atomic_write,
|
||||
custom_skill_exists,
|
||||
ensure_custom_skill_is_editable,
|
||||
ensure_safe_support_path,
|
||||
get_custom_skill_dir,
|
||||
get_custom_skill_file,
|
||||
public_skill_exists,
|
||||
read_custom_skill_content,
|
||||
validate_skill_markdown_content,
|
||||
validate_skill_name,
|
||||
)
|
||||
from deerflow.skills.security_scanner import scan_skill_content
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_skill_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
|
||||
|
||||
|
||||
def _get_lock(name: str) -> asyncio.Lock:
|
||||
lock = _skill_locks.get(name)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_skill_locks[name] = lock
|
||||
return lock
|
||||
|
||||
|
||||
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None:
    """Best-effort extraction of the thread id from the tool runtime.

    Prefers the runtime context; falls back to LangGraph's configurable
    mapping. Returns None when no runtime or no thread id is available.
    """
    if runtime is None:
        return None
    context = runtime.context
    if context:
        from_context = context.get("thread_id")
        if from_context:
            return from_context
    return runtime.config.get("configurable", {}).get("thread_id")
|
||||
|
||||
|
||||
def _history_record(*, action: str, file_path: str, prev_content: str | None, new_content: str | None, thread_id: str | None, scanner: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
"action": action,
|
||||
"author": "agent",
|
||||
"thread_id": thread_id,
|
||||
"file_path": file_path,
|
||||
"prev_content": prev_content,
|
||||
"new_content": new_content,
|
||||
"scanner": scanner,
|
||||
}
|
||||
|
||||
|
||||
async def _scan_or_raise(content: str, *, executable: bool, location: str) -> dict[str, str]:
    """Run the security scanner; raise ValueError unless the write may proceed.

    Non-executable content is rejected only on an explicit "block" decision,
    while executable content must be explicitly allowed.
    """
    verdict = await scan_skill_content(content, executable=executable, location=location)
    if verdict.decision == "block":
        raise ValueError(f"Security scan blocked the write: {verdict.reason}")
    if executable and verdict.decision != "allow":
        raise ValueError(f"Security scan rejected executable content: {verdict.reason}")
    return {"decision": verdict.decision, "reason": verdict.reason}
|
||||
|
||||
|
||||
async def _to_thread(func, /, *args, **kwargs):
    # Thin alias for asyncio.to_thread so the blocking skill-manager helpers
    # can be awaited uniformly throughout this module.
    return await asyncio.to_thread(func, *args, **kwargs)
|
||||
|
||||
|
||||
async def _skill_manage_impl(
    runtime: ToolRuntime[ContextT, ThreadState],
    action: str,
    name: str,
    content: str | None = None,
    path: str | None = None,
    find: str | None = None,
    replace: str | None = None,
    expected_count: int | None = None,
) -> str:
    """Manage custom skills under skills/custom/.

    Args:
        action: One of create, patch, edit, delete, write_file, remove_file.
        name: Skill name in hyphen-case.
        content: New file content for create, edit, or write_file.
        path: Supporting file path for write_file or remove_file.
        find: Existing text to replace for patch.
        replace: Replacement text for patch.
        expected_count: Optional expected number of replacements for patch.

    Returns:
        A human-readable confirmation message for the completed action.

    Raises:
        ValueError: On an unknown action, missing required arguments,
            validation failure, or a security-scanner rejection.
        FileNotFoundError: When remove_file targets a missing supporting file.
    """
    # Validate the name first — everything below builds paths from it.
    name = validate_skill_name(name)
    lock = _get_lock(name)
    thread_id = _get_thread_id(runtime)

    # Serialize all mutations of a given skill. Blocking file work is pushed
    # to worker threads via _to_thread so the event loop stays responsive.
    async with lock:
        if action == "create":
            if await _to_thread(custom_skill_exists, name):
                raise ValueError(f"Custom skill '{name}' already exists.")
            if content is None:
                raise ValueError("content is required for create.")
            await _to_thread(validate_skill_markdown_content, name, content)
            # Markdown is non-executable: only an explicit "block" stops it.
            scan = await _scan_or_raise(content, executable=False, location=f"{name}/SKILL.md")
            skill_file = await _to_thread(get_custom_skill_file, name)
            await _to_thread(atomic_write, skill_file, content)
            await _to_thread(
                append_history,
                name,
                _history_record(action="create", file_path="SKILL.md", prev_content=None, new_content=content, thread_id=thread_id, scanner=scan),
            )
            # A new skill changes the advertised list in the system prompt.
            await refresh_skills_system_prompt_cache_async()
            return f"Created custom skill '{name}'."

        if action == "edit":
            await _to_thread(ensure_custom_skill_is_editable, name)
            if content is None:
                raise ValueError("content is required for edit.")
            await _to_thread(validate_skill_markdown_content, name, content)
            scan = await _scan_or_raise(content, executable=False, location=f"{name}/SKILL.md")
            skill_file = await _to_thread(get_custom_skill_file, name)
            # Capture the previous content so history records a full before/after pair.
            prev_content = await _to_thread(skill_file.read_text, encoding="utf-8")
            await _to_thread(atomic_write, skill_file, content)
            await _to_thread(
                append_history,
                name,
                _history_record(action="edit", file_path="SKILL.md", prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan),
            )
            await refresh_skills_system_prompt_cache_async()
            return f"Updated custom skill '{name}'."

        if action == "patch":
            await _to_thread(ensure_custom_skill_is_editable, name)
            if find is None or replace is None:
                raise ValueError("find and replace are required for patch.")
            skill_file = await _to_thread(get_custom_skill_file, name)
            prev_content = await _to_thread(skill_file.read_text, encoding="utf-8")
            occurrences = prev_content.count(find)
            if occurrences == 0:
                raise ValueError("Patch target not found in SKILL.md.")
            if expected_count is not None and occurrences != expected_count:
                raise ValueError(f"Expected {expected_count} replacements but found {occurrences}.")
            # Without expected_count only the first occurrence is replaced,
            # keeping an ambiguous patch from touching unintended matches.
            replacement_count = expected_count if expected_count is not None else 1
            new_content = prev_content.replace(find, replace, replacement_count)
            await _to_thread(validate_skill_markdown_content, name, new_content)
            scan = await _scan_or_raise(new_content, executable=False, location=f"{name}/SKILL.md")
            await _to_thread(atomic_write, skill_file, new_content)
            await _to_thread(
                append_history,
                name,
                _history_record(action="patch", file_path="SKILL.md", prev_content=prev_content, new_content=new_content, thread_id=thread_id, scanner=scan),
            )
            await refresh_skills_system_prompt_cache_async()
            return f"Patched custom skill '{name}' ({replacement_count} replacement(s) applied, {occurrences} match(es) found)."

        if action == "delete":
            await _to_thread(ensure_custom_skill_is_editable, name)
            skill_dir = await _to_thread(get_custom_skill_dir, name)
            prev_content = await _to_thread(read_custom_skill_content, name)
            # History is appended before rmtree so a failed delete still
            # leaves an audit record of the attempt.
            await _to_thread(
                append_history,
                name,
                _history_record(action="delete", file_path="SKILL.md", prev_content=prev_content, new_content=None, thread_id=thread_id, scanner={"decision": "allow", "reason": "Deletion requested."}),
            )
            await _to_thread(shutil.rmtree, skill_dir)
            await refresh_skills_system_prompt_cache_async()
            return f"Deleted custom skill '{name}'."

        if action == "write_file":
            await _to_thread(ensure_custom_skill_is_editable, name)
            if path is None or content is None:
                raise ValueError("path and content are required for write_file.")
            # ensure_safe_support_path guards against escaping the skill dir.
            target = await _to_thread(ensure_safe_support_path, name, path)
            exists = await _to_thread(target.exists)
            prev_content = await _to_thread(target.read_text, encoding="utf-8") if exists else None
            # Paths under scripts/ get the stricter executable-content scan.
            executable = "scripts/" in path or path.startswith("scripts/")
            scan = await _scan_or_raise(content, executable=executable, location=f"{name}/{path}")
            await _to_thread(atomic_write, target, content)
            await _to_thread(
                append_history,
                name,
                _history_record(action="write_file", file_path=path, prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan),
            )
            return f"Wrote '{path}' for custom skill '{name}'."

        if action == "remove_file":
            await _to_thread(ensure_custom_skill_is_editable, name)
            if path is None:
                raise ValueError("path is required for remove_file.")
            target = await _to_thread(ensure_safe_support_path, name, path)
            if not await _to_thread(target.exists):
                raise FileNotFoundError(f"Supporting file '{path}' not found for skill '{name}'.")
            prev_content = await _to_thread(target.read_text, encoding="utf-8")
            await _to_thread(target.unlink)
            await _to_thread(
                append_history,
                name,
                _history_record(action="remove_file", file_path=path, prev_content=prev_content, new_content=None, thread_id=thread_id, scanner={"decision": "allow", "reason": "Deletion requested."}),
            )
            return f"Removed '{path}' from custom skill '{name}'."

        # Unknown action: give a more helpful message when the name collides
        # with a built-in (public) skill.
        if await _to_thread(public_skill_exists, name):
            raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.")
        raise ValueError(f"Unsupported action '{action}'.")
|
||||
|
||||
|
||||
@tool("skill_manage", parse_docstring=True)
async def skill_manage_tool(
    runtime: ToolRuntime[ContextT, ThreadState],
    action: str,
    name: str,
    content: str | None = None,
    path: str | None = None,
    find: str | None = None,
    replace: str | None = None,
    expected_count: int | None = None,
) -> str:
    """Manage custom skills under skills/custom/.

    Args:
        action: One of create, patch, edit, delete, write_file, remove_file.
        name: Skill name in hyphen-case.
        content: New file content for create, edit, or write_file.
        path: Supporting file path for write_file or remove_file.
        find: Existing text to replace for patch.
        replace: Replacement text for patch.
        expected_count: Optional expected number of replacements for patch.
    """
    # Thin async tool wrapper: validation, locking, and file work all live in
    # _skill_manage_impl. NOTE: parse_docstring=True turns the docstring above
    # into the tool's schema, so keep it in sync with the implementation.
    return await _skill_manage_impl(
        runtime=runtime,
        action=action,
        name=name,
        content=content,
        path=path,
        find=find,
        replace=replace,
        expected_count=expected_count,
    )


# Attach a synchronous entry point via the MCP sync-wrapper helper so the
# tool can also be invoked without an event loop — presumably for sync tool
# call paths; TODO(review): confirm which callers rely on .func.
skill_manage_tool.func = _make_sync_tool_wrapper(_skill_manage_impl, "skill_manage")
|
||||
137
deer-flow/backend/packages/harness/deerflow/tools/tools.py
Normal file
137
deer-flow/backend/packages/harness/deerflow/tools/tools.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import logging
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.reflection import resolve_variable
|
||||
from deerflow.sandbox.security import is_host_bash_allowed
|
||||
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
|
||||
from deerflow.tools.builtins.tool_search import reset_deferred_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools always offered to the lead agent regardless of config.
BUILTIN_TOOLS = [
    present_file_tool,
    ask_clarification_tool,
]

# Tools added only when subagent execution is enabled at runtime.
SUBAGENT_TOOLS = [
    task_tool,
    # task_status_tool is no longer exposed to LLM (backend handles polling internally)
]
|
||||
|
||||
|
||||
def _is_host_bash_tool(tool: object) -> bool:
|
||||
"""Return True if the tool config represents a host-bash execution surface."""
|
||||
group = getattr(tool, "group", None)
|
||||
use = getattr(tool, "use", None)
|
||||
if group == "bash":
|
||||
return True
|
||||
if use == "deerflow.sandbox.tools:bash_tool":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_available_tools(
    groups: list[str] | None = None,
    include_mcp: bool = True,
    model_name: str | None = None,
    subagent_enabled: bool = False,
) -> list[BaseTool]:
    """Get all available tools from config.

    Note: MCP tools should be initialized at application startup using
    `initialize_mcp_tools()` from deerflow.mcp module.

    Args:
        groups: Optional list of tool groups to filter by.
        include_mcp: Whether to include tools from MCP servers (default: True).
        model_name: Optional model name to determine if vision tools should be included.
        subagent_enabled: Whether to include subagent tools (task, task_status).

    Returns:
        List of available tools.
    """
    config = get_app_config()
    # Config-declared tools, optionally narrowed to the requested groups.
    tool_configs = [tool for tool in config.tools if groups is None or tool.group in groups]

    # Do not expose host bash by default when LocalSandboxProvider is active.
    if not is_host_bash_allowed(config):
        tool_configs = [tool for tool in tool_configs if not _is_host_bash_tool(tool)]

    # Resolve each config entry's "use" reference to its BaseTool instance.
    loaded_tools = [resolve_variable(tool.use, BaseTool) for tool in tool_configs]

    # Conditionally add tools based on config
    builtin_tools = BUILTIN_TOOLS.copy()
    skill_evolution_config = getattr(config, "skill_evolution", None)
    if getattr(skill_evolution_config, "enabled", False):
        # Imported lazily so the skill-evolution stack loads only when enabled.
        from deerflow.tools.skill_manage_tool import skill_manage_tool

        builtin_tools.append(skill_manage_tool)

    # Add subagent tools only if enabled via runtime parameter
    if subagent_enabled:
        builtin_tools.extend(SUBAGENT_TOOLS)
        logger.info("Including subagent tools (task)")

    # If no model_name specified, use the first model (default)
    if model_name is None and config.models:
        model_name = config.models[0].name

    # Add view_image_tool only if the model supports vision
    model_config = config.get_model_config(model_name) if model_name else None
    if model_config is not None and model_config.supports_vision:
        builtin_tools.append(view_image_tool)
        logger.info(f"Including view_image_tool for model '{model_name}' (supports_vision=True)")

    # Get cached MCP tools if enabled
    # NOTE: We use ExtensionsConfig.from_file() instead of config.extensions
    # to always read the latest configuration from disk. This ensures that changes
    # made through the Gateway API (which runs in a separate process) are immediately
    # reflected when loading MCP tools.
    mcp_tools = []
    # Reset deferred registry upfront to prevent stale state from previous calls
    reset_deferred_registry()
    if include_mcp:
        try:
            from deerflow.config.extensions_config import ExtensionsConfig
            from deerflow.mcp.cache import get_cached_mcp_tools

            extensions_config = ExtensionsConfig.from_file()
            if extensions_config.get_enabled_mcp_servers():
                mcp_tools = get_cached_mcp_tools()
                if mcp_tools:
                    logger.info(f"Using {len(mcp_tools)} cached MCP tool(s)")

                    # When tool_search is enabled, register MCP tools in the
                    # deferred registry and add tool_search to builtin tools.
                    if config.tool_search.enabled:
                        from deerflow.tools.builtins.tool_search import DeferredToolRegistry, set_deferred_registry
                        from deerflow.tools.builtins.tool_search import tool_search as tool_search_tool

                        registry = DeferredToolRegistry()
                        for t in mcp_tools:
                            registry.register(t)
                        set_deferred_registry(registry)
                        builtin_tools.append(tool_search_tool)
                        logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
        except ImportError:
            logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.")
        except Exception as e:
            # Best-effort: MCP failures must not break core tool loading.
            logger.error(f"Failed to get cached MCP tools: {e}")

    # Add invoke_acp_agent tool if any ACP agents are configured
    acp_tools: list[BaseTool] = []
    try:
        from deerflow.config.acp_config import get_acp_agents
        from deerflow.tools.builtins.invoke_acp_agent_tool import build_invoke_acp_agent_tool

        acp_agents = get_acp_agents()
        if acp_agents:
            acp_tools.append(build_invoke_acp_agent_tool(acp_agents))
            logger.info(f"Including invoke_acp_agent tool ({len(acp_agents)} agent(s): {list(acp_agents.keys())})")
    except Exception as e:
        logger.warning(f"Failed to load ACP tool: {e}")

    logger.info(f"Total tools loaded: {len(loaded_tools)}, built-in tools: {len(builtin_tools)}, MCP tools: {len(mcp_tools)}, ACP tools: {len(acp_tools)}")
    return loaded_tools + builtin_tools + mcp_tools + acp_tools
|
||||
Reference in New Issue
Block a user