Initial commit: hardened DeerFlow factory
Vendored deer-flow upstream (bytedance/deer-flow) plus prompt-injection hardening: - New deerflow.security package: content_delimiter, html_cleaner, sanitizer (8 layers — invisible chars, control chars, symbols, NFC, PUA, tag chars, horizontal whitespace collapse with newline/tab preservation, length cap) - New deerflow.community.searx package: web_search, web_fetch, image_search backed by a private SearX instance, every external string sanitized and wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>> delimiters - All native community web providers (ddg_search, tavily, exa, firecrawl, jina_ai, infoquest, image_search) replaced with hard-fail stubs that raise NativeWebToolDisabledError at import time, so a misconfigured tool.use path fails loud rather than silently falling back to unsanitized output - Native client back-doors (jina_client.py, infoquest_client.py) stubbed too - Native-tool tests quarantined under tests/_disabled_native/ (collect_ignore_glob via local conftest.py) - Sanitizer Layer 7 fix: only collapse horizontal whitespace, preserve newlines and tabs so list/table structure survives - Hardened runtime config.yaml references only the searx-backed tools - Factory overlay (backend/) kept in sync with deer-flow tree as a reference / source See HARDENING.md for the full audit trail and verification steps.
This commit is contained in:
14
deer-flow/backend/packages/harness/deerflow/mcp/__init__.py
Normal file
14
deer-flow/backend/packages/harness/deerflow/mcp/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""MCP (Model Context Protocol) integration using langchain-mcp-adapters."""
|
||||
|
||||
from .cache import get_cached_mcp_tools, initialize_mcp_tools, reset_mcp_tools_cache
|
||||
from .client import build_server_params, build_servers_config
|
||||
from .tools import get_mcp_tools
|
||||
|
||||
__all__ = [
|
||||
"build_server_params",
|
||||
"build_servers_config",
|
||||
"get_mcp_tools",
|
||||
"initialize_mcp_tools",
|
||||
"get_cached_mcp_tools",
|
||||
"reset_mcp_tools_cache",
|
||||
]
|
||||
138
deer-flow/backend/packages/harness/deerflow/mcp/cache.py
Normal file
138
deer-flow/backend/packages/harness/deerflow/mcp/cache.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Cache for MCP tools to avoid repeated loading."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_mcp_tools_cache: list[BaseTool] | None = None
|
||||
_cache_initialized = False
|
||||
_initialization_lock = asyncio.Lock()
|
||||
_config_mtime: float | None = None # Track config file modification time
|
||||
|
||||
|
||||
def _get_config_mtime() -> float | None:
|
||||
"""Get the modification time of the extensions config file.
|
||||
|
||||
Returns:
|
||||
The modification time as a float, or None if the file doesn't exist.
|
||||
"""
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
|
||||
config_path = ExtensionsConfig.resolve_config_path()
|
||||
if config_path and config_path.exists():
|
||||
return os.path.getmtime(config_path)
|
||||
return None
|
||||
|
||||
|
||||
def _is_cache_stale() -> bool:
|
||||
"""Check if the cache is stale due to config file changes.
|
||||
|
||||
Returns:
|
||||
True if the cache should be invalidated, False otherwise.
|
||||
"""
|
||||
global _config_mtime
|
||||
|
||||
if not _cache_initialized:
|
||||
return False # Not initialized yet, not stale
|
||||
|
||||
current_mtime = _get_config_mtime()
|
||||
|
||||
# If we couldn't get mtime before or now, assume not stale
|
||||
if _config_mtime is None or current_mtime is None:
|
||||
return False
|
||||
|
||||
# If the config file has been modified since we cached, it's stale
|
||||
if current_mtime > _config_mtime:
|
||||
logger.info(f"MCP config file has been modified (mtime: {_config_mtime} -> {current_mtime}), cache is stale")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def initialize_mcp_tools() -> list[BaseTool]:
|
||||
"""Initialize and cache MCP tools.
|
||||
|
||||
This should be called once at application startup.
|
||||
|
||||
Returns:
|
||||
List of LangChain tools from all enabled MCP servers.
|
||||
"""
|
||||
global _mcp_tools_cache, _cache_initialized, _config_mtime
|
||||
|
||||
async with _initialization_lock:
|
||||
if _cache_initialized:
|
||||
logger.info("MCP tools already initialized")
|
||||
return _mcp_tools_cache or []
|
||||
|
||||
from deerflow.mcp.tools import get_mcp_tools
|
||||
|
||||
logger.info("Initializing MCP tools...")
|
||||
_mcp_tools_cache = await get_mcp_tools()
|
||||
_cache_initialized = True
|
||||
_config_mtime = _get_config_mtime() # Record config file mtime
|
||||
logger.info(f"MCP tools initialized: {len(_mcp_tools_cache)} tool(s) loaded (config mtime: {_config_mtime})")
|
||||
|
||||
return _mcp_tools_cache
|
||||
|
||||
|
||||
def get_cached_mcp_tools() -> list[BaseTool]:
|
||||
"""Get cached MCP tools with lazy initialization.
|
||||
|
||||
If tools are not initialized, automatically initializes them.
|
||||
This ensures MCP tools work in both FastAPI and LangGraph Studio contexts.
|
||||
|
||||
Also checks if the config file has been modified since last initialization,
|
||||
and re-initializes if needed. This ensures that changes made through the
|
||||
Gateway API (which runs in a separate process) are reflected in the
|
||||
LangGraph Server.
|
||||
|
||||
Returns:
|
||||
List of cached MCP tools.
|
||||
"""
|
||||
global _cache_initialized
|
||||
|
||||
# Check if cache is stale due to config file changes
|
||||
if _is_cache_stale():
|
||||
logger.info("MCP cache is stale, resetting for re-initialization...")
|
||||
reset_mcp_tools_cache()
|
||||
|
||||
if not _cache_initialized:
|
||||
logger.info("MCP tools not initialized, performing lazy initialization...")
|
||||
try:
|
||||
# Try to initialize in the current event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# If loop is already running (e.g., in LangGraph Studio),
|
||||
# we need to create a new loop in a thread
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(asyncio.run, initialize_mcp_tools())
|
||||
future.result()
|
||||
else:
|
||||
# If no loop is running, we can use the current loop
|
||||
loop.run_until_complete(initialize_mcp_tools())
|
||||
except RuntimeError:
|
||||
# No event loop exists, create one
|
||||
asyncio.run(initialize_mcp_tools())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to lazy-initialize MCP tools: {e}")
|
||||
return []
|
||||
|
||||
return _mcp_tools_cache or []
|
||||
|
||||
|
||||
def reset_mcp_tools_cache() -> None:
|
||||
"""Reset the MCP tools cache.
|
||||
|
||||
This is useful for testing or when you want to reload MCP tools.
|
||||
"""
|
||||
global _mcp_tools_cache, _cache_initialized, _config_mtime
|
||||
_mcp_tools_cache = None
|
||||
_cache_initialized = False
|
||||
_config_mtime = None
|
||||
logger.info("MCP tools cache reset")
|
||||
68
deer-flow/backend/packages/harness/deerflow/mcp/client.py
Normal file
68
deer-flow/backend/packages/harness/deerflow/mcp/client.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""MCP client using langchain-mcp-adapters."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_server_params(server_name: str, config: McpServerConfig) -> dict[str, Any]:
|
||||
"""Build server parameters for MultiServerMCPClient.
|
||||
|
||||
Args:
|
||||
server_name: Name of the MCP server.
|
||||
config: Configuration for the MCP server.
|
||||
|
||||
Returns:
|
||||
Dictionary of server parameters for langchain-mcp-adapters.
|
||||
"""
|
||||
transport_type = config.type or "stdio"
|
||||
params: dict[str, Any] = {"transport": transport_type}
|
||||
|
||||
if transport_type == "stdio":
|
||||
if not config.command:
|
||||
raise ValueError(f"MCP server '{server_name}' with stdio transport requires 'command' field")
|
||||
params["command"] = config.command
|
||||
params["args"] = config.args
|
||||
# Add environment variables if present
|
||||
if config.env:
|
||||
params["env"] = config.env
|
||||
elif transport_type in ("sse", "http"):
|
||||
if not config.url:
|
||||
raise ValueError(f"MCP server '{server_name}' with {transport_type} transport requires 'url' field")
|
||||
params["url"] = config.url
|
||||
# Add headers if present
|
||||
if config.headers:
|
||||
params["headers"] = config.headers
|
||||
else:
|
||||
raise ValueError(f"MCP server '{server_name}' has unsupported transport type: {transport_type}")
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def build_servers_config(extensions_config: ExtensionsConfig) -> dict[str, dict[str, Any]]:
|
||||
"""Build servers configuration for MultiServerMCPClient.
|
||||
|
||||
Args:
|
||||
extensions_config: Extensions configuration containing all MCP servers.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping server names to their parameters.
|
||||
"""
|
||||
enabled_servers = extensions_config.get_enabled_mcp_servers()
|
||||
|
||||
if not enabled_servers:
|
||||
logger.info("No enabled MCP servers found")
|
||||
return {}
|
||||
|
||||
servers_config = {}
|
||||
for server_name, server_config in enabled_servers.items():
|
||||
try:
|
||||
servers_config[server_name] = build_server_params(server_name, server_config)
|
||||
logger.info(f"Configured MCP server: {server_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure MCP server '{server_name}': {e}")
|
||||
|
||||
return servers_config
|
||||
150
deer-flow/backend/packages/harness/deerflow/mcp/oauth.py
Normal file
150
deer-flow/backend/packages/harness/deerflow/mcp/oauth.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""OAuth token support for MCP HTTP/SSE servers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpOAuthConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OAuthToken:
|
||||
"""Cached OAuth token."""
|
||||
|
||||
access_token: str
|
||||
token_type: str
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
class OAuthTokenManager:
|
||||
"""Acquire/cache/refresh OAuth tokens for MCP servers."""
|
||||
|
||||
def __init__(self, oauth_by_server: dict[str, McpOAuthConfig]):
|
||||
self._oauth_by_server = oauth_by_server
|
||||
self._tokens: dict[str, _OAuthToken] = {}
|
||||
self._locks: dict[str, asyncio.Lock] = {name: asyncio.Lock() for name in oauth_by_server}
|
||||
|
||||
@classmethod
|
||||
def from_extensions_config(cls, extensions_config: ExtensionsConfig) -> OAuthTokenManager:
|
||||
oauth_by_server: dict[str, McpOAuthConfig] = {}
|
||||
for server_name, server_config in extensions_config.get_enabled_mcp_servers().items():
|
||||
if server_config.oauth and server_config.oauth.enabled:
|
||||
oauth_by_server[server_name] = server_config.oauth
|
||||
return cls(oauth_by_server)
|
||||
|
||||
def has_oauth_servers(self) -> bool:
|
||||
return bool(self._oauth_by_server)
|
||||
|
||||
def oauth_server_names(self) -> list[str]:
|
||||
return list(self._oauth_by_server.keys())
|
||||
|
||||
async def get_authorization_header(self, server_name: str) -> str | None:
|
||||
oauth = self._oauth_by_server.get(server_name)
|
||||
if not oauth:
|
||||
return None
|
||||
|
||||
token = self._tokens.get(server_name)
|
||||
if token and not self._is_expiring(token, oauth):
|
||||
return f"{token.token_type} {token.access_token}"
|
||||
|
||||
lock = self._locks[server_name]
|
||||
async with lock:
|
||||
token = self._tokens.get(server_name)
|
||||
if token and not self._is_expiring(token, oauth):
|
||||
return f"{token.token_type} {token.access_token}"
|
||||
|
||||
fresh = await self._fetch_token(oauth)
|
||||
self._tokens[server_name] = fresh
|
||||
logger.info(f"Refreshed OAuth access token for MCP server: {server_name}")
|
||||
return f"{fresh.token_type} {fresh.access_token}"
|
||||
|
||||
@staticmethod
|
||||
def _is_expiring(token: _OAuthToken, oauth: McpOAuthConfig) -> bool:
|
||||
now = datetime.now(UTC)
|
||||
return token.expires_at <= now + timedelta(seconds=max(oauth.refresh_skew_seconds, 0))
|
||||
|
||||
async def _fetch_token(self, oauth: McpOAuthConfig) -> _OAuthToken:
|
||||
import httpx # pyright: ignore[reportMissingImports]
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": oauth.grant_type,
|
||||
**oauth.extra_token_params,
|
||||
}
|
||||
|
||||
if oauth.scope:
|
||||
data["scope"] = oauth.scope
|
||||
if oauth.audience:
|
||||
data["audience"] = oauth.audience
|
||||
|
||||
if oauth.grant_type == "client_credentials":
|
||||
if not oauth.client_id or not oauth.client_secret:
|
||||
raise ValueError("OAuth client_credentials requires client_id and client_secret")
|
||||
data["client_id"] = oauth.client_id
|
||||
data["client_secret"] = oauth.client_secret
|
||||
elif oauth.grant_type == "refresh_token":
|
||||
if not oauth.refresh_token:
|
||||
raise ValueError("OAuth refresh_token grant requires refresh_token")
|
||||
data["refresh_token"] = oauth.refresh_token
|
||||
if oauth.client_id:
|
||||
data["client_id"] = oauth.client_id
|
||||
if oauth.client_secret:
|
||||
data["client_secret"] = oauth.client_secret
|
||||
else:
|
||||
raise ValueError(f"Unsupported OAuth grant type: {oauth.grant_type}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
response = await client.post(oauth.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
access_token = payload.get(oauth.token_field)
|
||||
if not access_token:
|
||||
raise ValueError(f"OAuth token response missing '{oauth.token_field}'")
|
||||
|
||||
token_type = str(payload.get(oauth.token_type_field, oauth.default_token_type) or oauth.default_token_type)
|
||||
|
||||
expires_in_raw = payload.get(oauth.expires_in_field, 3600)
|
||||
try:
|
||||
expires_in = int(expires_in_raw)
|
||||
except (TypeError, ValueError):
|
||||
expires_in = 3600
|
||||
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=max(expires_in, 1))
|
||||
return _OAuthToken(access_token=access_token, token_type=token_type, expires_at=expires_at)
|
||||
|
||||
|
||||
def build_oauth_tool_interceptor(extensions_config: ExtensionsConfig) -> Any | None:
|
||||
"""Build a tool interceptor that injects OAuth Authorization headers."""
|
||||
token_manager = OAuthTokenManager.from_extensions_config(extensions_config)
|
||||
if not token_manager.has_oauth_servers():
|
||||
return None
|
||||
|
||||
async def oauth_interceptor(request: Any, handler: Any) -> Any:
|
||||
header = await token_manager.get_authorization_header(request.server_name)
|
||||
if not header:
|
||||
return await handler(request)
|
||||
|
||||
updated_headers = dict(request.headers or {})
|
||||
updated_headers["Authorization"] = header
|
||||
return await handler(request.override(headers=updated_headers))
|
||||
|
||||
return oauth_interceptor
|
||||
|
||||
|
||||
async def get_initial_oauth_headers(extensions_config: ExtensionsConfig) -> dict[str, str]:
|
||||
"""Get initial OAuth Authorization headers for MCP server connections."""
|
||||
token_manager = OAuthTokenManager.from_extensions_config(extensions_config)
|
||||
if not token_manager.has_oauth_servers():
|
||||
return {}
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
for server_name in token_manager.oauth_server_names():
|
||||
headers[server_name] = await token_manager.get_authorization_header(server_name) or ""
|
||||
|
||||
return {name: value for name, value in headers.items() if value}
|
||||
113
deer-flow/backend/packages/harness/deerflow/mcp/tools.py
Normal file
113
deer-flow/backend/packages/harness/deerflow/mcp/tools.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Load MCP tools using langchain-mcp-adapters."""
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import concurrent.futures
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.mcp.client import build_servers_config
|
||||
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global thread pool for sync tool invocation in async environments
|
||||
_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="mcp-sync-tool")
|
||||
|
||||
# Register shutdown hook for the global executor
|
||||
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
|
||||
|
||||
|
||||
def _make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
|
||||
"""Build a synchronous wrapper for an asynchronous tool coroutine.
|
||||
|
||||
Args:
|
||||
coro: The tool's asynchronous coroutine.
|
||||
tool_name: Name of the tool (for logging).
|
||||
|
||||
Returns:
|
||||
A synchronous function that correctly handles nested event loops.
|
||||
"""
|
||||
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
try:
|
||||
if loop is not None and loop.is_running():
|
||||
# Use global executor to avoid nested loop issues and improve performance
|
||||
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
|
||||
return future.result()
|
||||
else:
|
||||
return asyncio.run(coro(*args, **kwargs))
|
||||
except Exception as e:
|
||||
logger.error(f"Error invoking MCP tool '{tool_name}' via sync wrapper: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
async def get_mcp_tools() -> list[BaseTool]:
|
||||
"""Get all tools from enabled MCP servers.
|
||||
|
||||
Returns:
|
||||
List of LangChain tools from all enabled MCP servers.
|
||||
"""
|
||||
try:
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
except ImportError:
|
||||
logger.warning("langchain-mcp-adapters not installed. Install it to enable MCP tools: pip install langchain-mcp-adapters")
|
||||
return []
|
||||
|
||||
# NOTE: We use ExtensionsConfig.from_file() instead of get_extensions_config()
|
||||
# to always read the latest configuration from disk. This ensures that changes
|
||||
# made through the Gateway API (which runs in a separate process) are immediately
|
||||
# reflected when initializing MCP tools.
|
||||
extensions_config = ExtensionsConfig.from_file()
|
||||
servers_config = build_servers_config(extensions_config)
|
||||
|
||||
if not servers_config:
|
||||
logger.info("No enabled MCP servers configured")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Create the multi-server MCP client
|
||||
logger.info(f"Initializing MCP client with {len(servers_config)} server(s)")
|
||||
|
||||
# Inject initial OAuth headers for server connections (tool discovery/session init)
|
||||
initial_oauth_headers = await get_initial_oauth_headers(extensions_config)
|
||||
for server_name, auth_header in initial_oauth_headers.items():
|
||||
if server_name not in servers_config:
|
||||
continue
|
||||
if servers_config[server_name].get("transport") in ("sse", "http"):
|
||||
existing_headers = dict(servers_config[server_name].get("headers", {}))
|
||||
existing_headers["Authorization"] = auth_header
|
||||
servers_config[server_name]["headers"] = existing_headers
|
||||
|
||||
tool_interceptors = []
|
||||
oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
|
||||
if oauth_interceptor is not None:
|
||||
tool_interceptors.append(oauth_interceptor)
|
||||
|
||||
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True)
|
||||
|
||||
# Get all tools from all servers
|
||||
tools = await client.get_tools()
|
||||
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
|
||||
|
||||
# Patch tools to support sync invocation, as deerflow client streams synchronously
|
||||
for tool in tools:
|
||||
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
|
||||
tool.func = _make_sync_tool_wrapper(tool.coroutine, tool.name)
|
||||
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load MCP tools: {e}", exc_info=True)
|
||||
return []
|
||||
Reference in New Issue
Block a user