Initial commit: hardened DeerFlow factory

Vendored deer-flow upstream (bytedance/deer-flow) plus prompt-injection
hardening:

- New deerflow.security package: content_delimiter, html_cleaner,
  sanitizer (8 layers — invisible chars, control chars, symbols, NFC,
  PUA, tag chars, horizontal whitespace collapse with newline/tab
  preservation, length cap)
- New deerflow.community.searx package: web_search, web_fetch,
  image_search backed by a private SearX instance, every external
  string sanitized and wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>>
  delimiters
- All native community web providers (ddg_search, tavily, exa,
  firecrawl, jina_ai, infoquest, image_search) replaced with hard-fail
  stubs that raise NativeWebToolDisabledError at import time, so a
  misconfigured tool.use path fails loud rather than silently falling
  back to unsanitized output
- Native client back-doors (jina_client.py, infoquest_client.py)
  stubbed too
- Native-tool tests quarantined under tests/_disabled_native/
  (collect_ignore_glob via local conftest.py)
- Sanitizer Layer 7 fix: only collapse horizontal whitespace, preserve
  newlines and tabs so list/table structure survives
- Hardened runtime config.yaml references only the searx-backed tools
- Factory overlay (backend/) kept in sync with deer-flow tree as a
  reference / source

See HARDENING.md for the full audit trail and verification steps.
This commit is contained in:
2026-04-12 14:23:57 +02:00
commit 6de0bf9f5b
889 changed files with 173052 additions and 0 deletions

View File

@@ -0,0 +1,191 @@
"""Middleware for intercepting clarification requests and presenting them to the user."""
import json
import logging
from collections.abc import Callable
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.graph import END
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
logger = logging.getLogger(__name__)
class ClarificationMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
pass
class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
"""Intercepts clarification tool calls and interrupts execution to present questions to the user.
When the model calls the `ask_clarification` tool, this middleware:
1. Intercepts the tool call before execution
2. Extracts the clarification question and metadata
3. Formats a user-friendly message
4. Returns a Command that interrupts execution and presents the question
5. Waits for user response before continuing
This replaces the tool-based approach where clarification continued the conversation flow.
"""
state_schema = ClarificationMiddlewareState
def _is_chinese(self, text: str) -> bool:
"""Check if text contains Chinese characters.
Args:
text: Text to check
Returns:
True if text contains Chinese characters
"""
return any("\u4e00" <= char <= "\u9fff" for char in text)
def _format_clarification_message(self, args: dict) -> str:
"""Format the clarification arguments into a user-friendly message.
Args:
args: The tool call arguments containing clarification details
Returns:
Formatted message string
"""
question = args.get("question", "")
clarification_type = args.get("clarification_type", "missing_info")
context = args.get("context")
options = args.get("options", [])
# Some models (e.g. Qwen3-Max) serialize array parameters as JSON strings
# instead of native arrays. Deserialize and normalize so `options`
# is always a list for the rendering logic below.
if isinstance(options, str):
try:
options = json.loads(options)
except (json.JSONDecodeError, TypeError):
options = [options]
if options is None:
options = []
elif not isinstance(options, list):
options = [options]
# Type-specific icons
type_icons = {
"missing_info": "",
"ambiguous_requirement": "🤔",
"approach_choice": "🔀",
"risk_confirmation": "⚠️",
"suggestion": "💡",
}
icon = type_icons.get(clarification_type, "")
# Build the message naturally
message_parts = []
# Add icon and question together for a more natural flow
if context:
# If there's context, present it first as background
message_parts.append(f"{icon} {context}")
message_parts.append(f"\n{question}")
else:
# Just the question with icon
message_parts.append(f"{icon} {question}")
# Add options in a cleaner format
if options and len(options) > 0:
message_parts.append("") # blank line for spacing
for i, option in enumerate(options, 1):
message_parts.append(f" {i}. {option}")
return "\n".join(message_parts)
def _handle_clarification(self, request: ToolCallRequest) -> Command:
"""Handle clarification request and return command to interrupt execution.
Args:
request: Tool call request
Returns:
Command that interrupts execution with the formatted clarification message
"""
# Extract clarification arguments
args = request.tool_call.get("args", {})
question = args.get("question", "")
logger.info("Intercepted clarification request")
logger.debug("Clarification question: %s", question)
# Format the clarification message
formatted_message = self._format_clarification_message(args)
# Get the tool call ID
tool_call_id = request.tool_call.get("id", "")
# Create a ToolMessage with the formatted question
# This will be added to the message history
tool_message = ToolMessage(
content=formatted_message,
tool_call_id=tool_call_id,
name="ask_clarification",
)
# Return a Command that:
# 1. Adds the formatted tool message
# 2. Interrupts execution by going to __end__
# Note: We don't add an extra AIMessage here - the frontend will detect
# and display ask_clarification tool messages directly
return Command(
update={"messages": [tool_message]},
goto=END,
)
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept ask_clarification tool calls and interrupt execution (sync version).
Args:
request: Tool call request
handler: Original tool execution handler
Returns:
Command that interrupts execution with the formatted clarification message
"""
# Check if this is an ask_clarification tool call
if request.tool_call.get("name") != "ask_clarification":
# Not a clarification call, execute normally
return handler(request)
return self._handle_clarification(request)
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
"""Intercept ask_clarification tool calls and interrupt execution (async version).
Args:
request: Tool call request
handler: Original tool execution handler (async)
Returns:
Command that interrupts execution with the formatted clarification message
"""
# Check if this is an ask_clarification tool call
if request.tool_call.get("name") != "ask_clarification":
# Not a clarification call, execute normally
return await handler(request)
return self._handle_clarification(request)

View File

@@ -0,0 +1,110 @@
"""Middleware to fix dangling tool calls in message history.
A dangling tool call occurs when an AIMessage contains tool_calls but there are
no corresponding ToolMessages in the history (e.g., due to user interruption or
request cancellation). This causes LLM errors due to incomplete message format.
This middleware intercepts the model call to detect and patch such gaps by
inserting synthetic ToolMessages with an error indicator immediately after the
AIMessage that made the tool calls, ensuring correct message ordering.
Note: Uses wrap_model_call instead of before_model to ensure patches are inserted
at the correct positions (immediately after each dangling AIMessage), not appended
to the end of the message list as before_model + add_messages reducer would do.
"""
import logging
from collections.abc import Awaitable, Callable
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import ToolMessage
logger = logging.getLogger(__name__)
class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
"""Inserts placeholder ToolMessages for dangling tool calls before model invocation.
Scans the message history for AIMessages whose tool_calls lack corresponding
ToolMessages, and injects synthetic error responses immediately after the
offending AIMessage so the LLM receives a well-formed conversation.
"""
def _build_patched_messages(self, messages: list) -> list | None:
"""Return a new message list with patches inserted at the correct positions.
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
a synthetic ToolMessage is inserted immediately after that AIMessage.
Returns None if no patches are needed.
"""
# Collect IDs of all existing ToolMessages
existing_tool_msg_ids: set[str] = set()
for msg in messages:
if isinstance(msg, ToolMessage):
existing_tool_msg_ids.add(msg.tool_call_id)
# Check if any patching is needed
needs_patch = False
for msg in messages:
if getattr(msg, "type", None) != "ai":
continue
for tc in getattr(msg, "tool_calls", None) or []:
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids:
needs_patch = True
break
if needs_patch:
break
if not needs_patch:
return None
# Build new list with patches inserted right after each dangling AIMessage
patched: list = []
patched_ids: set[str] = set()
patch_count = 0
for msg in messages:
patched.append(msg)
if getattr(msg, "type", None) != "ai":
continue
for tc in getattr(msg, "tool_calls", None) or []:
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
patched.append(
ToolMessage(
content="[Tool call was interrupted and did not return a result.]",
tool_call_id=tc_id,
name=tc.get("name", "unknown"),
status="error",
)
)
patched_ids.add(tc_id)
patch_count += 1
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
return patched
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
patched = self._build_patched_messages(request.messages)
if patched is not None:
request = request.override(messages=patched)
return handler(request)
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
patched = self._build_patched_messages(request.messages)
if patched is not None:
request = request.override(messages=patched)
return await handler(request)

View File

@@ -0,0 +1,60 @@
"""Middleware to filter deferred tool schemas from model binding.
When tool_search is enabled, MCP tools are registered in the DeferredToolRegistry
and passed to ToolNode for execution, but their schemas should NOT be sent to the
LLM via bind_tools (that's the whole point of deferral — saving context tokens).
This middleware intercepts wrap_model_call and removes deferred tools from
request.tools so that model.bind_tools only receives active tool schemas.
The agent discovers deferred tools at runtime via the tool_search tool.
"""
import logging
from collections.abc import Awaitable, Callable
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
logger = logging.getLogger(__name__)
class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
"""Remove deferred tools from request.tools before model binding.
ToolNode still holds all tools (including deferred) for execution routing,
but the LLM only sees active tool schemas — deferred tools are discoverable
via tool_search at runtime.
"""
def _filter_tools(self, request: ModelRequest) -> ModelRequest:
from deerflow.tools.builtins.tool_search import get_deferred_registry
registry = get_deferred_registry()
if not registry:
return request
deferred_names = {e.name for e in registry.entries}
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
if len(active_tools) < len(request.tools):
logger.debug(f"Filtered {len(request.tools) - len(active_tools)} deferred tool schema(s) from model binding")
return request.override(tools=active_tools)
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(self._filter_tools(request))
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
return await handler(self._filter_tools(request))

View File

@@ -0,0 +1,275 @@
"""LLM error handling middleware with retry/backoff and user-facing fallbacks."""
from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Awaitable, Callable
from email.utils import parsedate_to_datetime
from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import (
ModelCallResult,
ModelRequest,
ModelResponse,
)
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
logger = logging.getLogger(__name__)
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
_BUSY_PATTERNS = (
"server busy",
"temporarily unavailable",
"try again later",
"please retry",
"please try again",
"overloaded",
"high demand",
"rate limit",
"负载较高",
"服务繁忙",
"稍后重试",
"请稍后重试",
)
_QUOTA_PATTERNS = (
"insufficient_quota",
"quota",
"billing",
"credit",
"payment",
"余额不足",
"超出限额",
"额度不足",
"欠费",
)
_AUTH_PATTERNS = (
"authentication",
"unauthorized",
"invalid api key",
"invalid_api_key",
"permission",
"forbidden",
"access denied",
"无权",
"未授权",
)
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"""Retry transient LLM errors and surface graceful assistant messages."""
retry_max_attempts: int = 3
retry_base_delay_ms: int = 1000
retry_cap_delay_ms: int = 8000
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
detail = _extract_error_detail(exc)
lowered = detail.lower()
error_code = _extract_error_code(exc)
status_code = _extract_status_code(exc)
if _matches_any(lowered, _QUOTA_PATTERNS) or _matches_any(str(error_code).lower(), _QUOTA_PATTERNS):
return False, "quota"
if _matches_any(lowered, _AUTH_PATTERNS):
return False, "auth"
exc_name = exc.__class__.__name__
if exc_name in {
"APITimeoutError",
"APIConnectionError",
"InternalServerError",
}:
return True, "transient"
if status_code in _RETRIABLE_STATUS_CODES:
return True, "transient"
if _matches_any(lowered, _BUSY_PATTERNS):
return True, "busy"
return False, "generic"
def _build_retry_delay_ms(self, attempt: int, exc: BaseException) -> int:
retry_after = _extract_retry_after_ms(exc)
if retry_after is not None:
return retry_after
backoff = self.retry_base_delay_ms * (2 ** max(0, attempt - 1))
return min(backoff, self.retry_cap_delay_ms)
def _build_retry_message(self, attempt: int, wait_ms: int, reason: str) -> str:
seconds = max(1, round(wait_ms / 1000))
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
def _build_user_message(self, exc: BaseException, reason: str) -> str:
detail = _extract_error_detail(exc)
if reason == "quota":
return "The configured LLM provider rejected the request because the account is out of quota, billing is unavailable, or usage is restricted. Please fix the provider account and try again."
if reason == "auth":
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
if reason in {"busy", "transient"}:
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
return f"LLM request failed: {detail}"
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
try:
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer(
{
"type": "llm_retry",
"attempt": attempt,
"max_attempts": self.retry_max_attempts,
"wait_ms": wait_ms,
"reason": reason,
"message": self._build_retry_message(attempt, wait_ms, reason),
}
)
except Exception:
logger.debug("Failed to emit llm_retry event", exc_info=True)
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
attempt = 1
while True:
try:
return handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
if retriable and attempt < self.retry_max_attempts:
wait_ms = self._build_retry_delay_ms(attempt, exc)
logger.warning(
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
attempt,
self.retry_max_attempts,
wait_ms,
_extract_error_detail(exc),
)
self._emit_retry_event(attempt, wait_ms, reason)
time.sleep(wait_ms / 1000)
attempt += 1
continue
logger.warning(
"LLM call failed after %d attempt(s): %s",
attempt,
_extract_error_detail(exc),
exc_info=exc,
)
return AIMessage(content=self._build_user_message(exc, reason))
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
attempt = 1
while True:
try:
return await handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
retriable, reason = self._classify_error(exc)
if retriable and attempt < self.retry_max_attempts:
wait_ms = self._build_retry_delay_ms(attempt, exc)
logger.warning(
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
attempt,
self.retry_max_attempts,
wait_ms,
_extract_error_detail(exc),
)
self._emit_retry_event(attempt, wait_ms, reason)
await asyncio.sleep(wait_ms / 1000)
attempt += 1
continue
logger.warning(
"LLM call failed after %d attempt(s): %s",
attempt,
_extract_error_detail(exc),
exc_info=exc,
)
return AIMessage(content=self._build_user_message(exc, reason))
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
return any(pattern in detail for pattern in patterns)
def _extract_error_code(exc: BaseException) -> Any:
for attr in ("code", "error_code"):
value = getattr(exc, attr, None)
if value not in (None, ""):
return value
body = getattr(exc, "body", None)
if isinstance(body, dict):
error = body.get("error")
if isinstance(error, dict):
for key in ("code", "type"):
value = error.get(key)
if value not in (None, ""):
return value
return None
def _extract_status_code(exc: BaseException) -> int | None:
for attr in ("status_code", "status"):
value = getattr(exc, attr, None)
if isinstance(value, int):
return value
response = getattr(exc, "response", None)
status = getattr(response, "status_code", None)
return status if isinstance(status, int) else None
def _extract_retry_after_ms(exc: BaseException) -> int | None:
response = getattr(exc, "response", None)
headers = getattr(response, "headers", None)
if headers is None:
return None
raw = None
header_name = ""
for key in ("retry-after-ms", "Retry-After-Ms", "retry-after", "Retry-After"):
header_name = key
if hasattr(headers, "get"):
raw = headers.get(key)
if raw:
break
if not raw:
return None
try:
multiplier = 1 if "ms" in header_name.lower() else 1000
return max(0, int(float(raw) * multiplier))
except (TypeError, ValueError):
try:
target = parsedate_to_datetime(str(raw))
delta = target.timestamp() - time.time()
return max(0, int(delta * 1000))
except (TypeError, ValueError, OverflowError):
return None
def _extract_error_detail(exc: BaseException) -> str:
detail = str(exc).strip()
if detail:
return detail
message = getattr(exc, "message", None)
if isinstance(message, str) and message.strip():
return message.strip()
return exc.__class__.__name__

View File

@@ -0,0 +1,372 @@
"""Middleware to detect and break repetitive tool call loops.
P0 safety: prevents the agent from calling the same tool with the same
arguments indefinitely until the recursion limit kills the run.
Detection strategy:
1. After each model response, hash the tool calls (name + args).
2. Track recent hashes in a sliding window.
3. If the same hash appears >= warn_threshold times, inject a
"you are repeating yourself — wrap up" system message (once per hash).
4. If it appears >= hard_limit times, strip all tool_calls from the
response so the agent is forced to produce a final text answer.
"""
import hashlib
import json
import logging
import threading
from collections import OrderedDict, defaultdict
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor
_DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls
_DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls
_DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
_DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type
_DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
"""Normalize tool call args to a dict plus an optional fallback key.
Some providers serialize ``args`` as a JSON string instead of a dict.
We defensively parse those cases so loop detection does not crash while
still preserving a stable fallback key for non-dict payloads.
"""
if isinstance(raw_args, dict):
return raw_args, None
if isinstance(raw_args, str):
try:
parsed = json.loads(raw_args)
except (TypeError, ValueError, json.JSONDecodeError):
return {}, raw_args
if isinstance(parsed, dict):
return parsed, None
return {}, json.dumps(parsed, sort_keys=True, default=str)
if raw_args is None:
return {}, None
return {}, json.dumps(raw_args, sort_keys=True, default=str)
def _stable_tool_key(name: str, args: dict, fallback_key: str | None) -> str:
"""Derive a stable key from salient args without overfitting to noise."""
if name == "read_file" and fallback_key is None:
path = args.get("path") or ""
start_line = args.get("start_line")
end_line = args.get("end_line")
bucket_size = 200
try:
start_line = int(start_line) if start_line is not None else 1
except (TypeError, ValueError):
start_line = 1
try:
end_line = int(end_line) if end_line is not None else start_line
except (TypeError, ValueError):
end_line = start_line
start_line, end_line = sorted((start_line, end_line))
bucket_start = max(start_line, 1)
bucket_end = max(end_line, 1)
bucket_start = (bucket_start - 1) // bucket_size
bucket_end = (bucket_end - 1) // bucket_size
return f"{path}:{bucket_start}-{bucket_end}"
# write_file / str_replace are content-sensitive: same path may be updated
# with different payloads during iteration. Using only salient fields (path)
# can collapse distinct calls, so we hash full args to reduce false positives.
if name in {"write_file", "str_replace"}:
if fallback_key is not None:
return fallback_key
return json.dumps(args, sort_keys=True, default=str)
salient_fields = ("path", "url", "query", "command", "pattern", "glob", "cmd")
stable_args = {field: args[field] for field in salient_fields if args.get(field) is not None}
if stable_args:
return json.dumps(stable_args, sort_keys=True, default=str)
if fallback_key is not None:
return fallback_key
return json.dumps(args, sort_keys=True, default=str)
def _hash_tool_calls(tool_calls: list[dict]) -> str:
"""Deterministic hash of a set of tool calls (name + stable key).
This is intended to be order-independent: the same multiset of tool calls
should always produce the same hash, regardless of their input order.
"""
# Normalize each tool call to a stable (name, key) structure.
normalized: list[str] = []
for tc in tool_calls:
name = tc.get("name", "")
args, fallback_key = _normalize_tool_call_args(tc.get("args", {}))
key = _stable_tool_key(name, args, fallback_key)
normalized.append(f"{name}:{key}")
# Sort so permutations of the same multiset of calls yield the same ordering.
normalized.sort()
blob = json.dumps(normalized, sort_keys=True, default=str)
return hashlib.md5(blob.encode()).hexdigest()[:12]
_WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
_TOOL_FREQ_WARNING_MSG = (
"[LOOP DETECTED] You have called {tool_name} {count} times without producing a final answer. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
)
_HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far."
_TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times — exceeded the per-tool safety limit. Producing final answer with results collected so far."
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
"""Detects and breaks repetitive tool call loops.
Args:
warn_threshold: Number of identical tool call sets before injecting
a warning message. Default: 3.
hard_limit: Number of identical tool call sets before stripping
tool_calls entirely. Default: 5.
window_size: Size of the sliding window for tracking calls.
Default: 20.
max_tracked_threads: Maximum number of threads to track before
evicting the least recently used. Default: 100.
tool_freq_warn: Number of calls to the same tool *type* (regardless
of arguments) before injecting a frequency warning. Catches
cross-file read loops that hash-based detection misses.
Default: 30.
tool_freq_hard_limit: Number of calls to the same tool type before
forcing a stop. Default: 50.
"""
def __init__(
self,
warn_threshold: int = _DEFAULT_WARN_THRESHOLD,
hard_limit: int = _DEFAULT_HARD_LIMIT,
window_size: int = _DEFAULT_WINDOW_SIZE,
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
):
super().__init__()
self.warn_threshold = warn_threshold
self.hard_limit = hard_limit
self.window_size = window_size
self.max_tracked_threads = max_tracked_threads
self.tool_freq_warn = tool_freq_warn
self.tool_freq_hard_limit = tool_freq_hard_limit
self._lock = threading.Lock()
# Per-thread tracking using OrderedDict for LRU eviction
self._history: OrderedDict[str, list[str]] = OrderedDict()
self._warned: dict[str, set[str]] = defaultdict(set)
# Per-thread, per-tool-type cumulative call counts
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
def _get_thread_id(self, runtime: Runtime) -> str:
"""Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id:
return thread_id
return "default"
def _evict_if_needed(self) -> None:
"""Evict least recently used threads if over the limit.
Must be called while holding self._lock.
"""
while len(self._history) > self.max_tracked_threads:
evicted_id, _ = self._history.popitem(last=False)
self._warned.pop(evicted_id, None)
self._tool_freq.pop(evicted_id, None)
self._tool_freq_warned.pop(evicted_id, None)
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
"""Track tool calls and check for loops.
Two detection layers:
1. **Hash-based** (existing): catches identical tool call sets.
2. **Frequency-based** (new): catches the same *tool type* being
called many times with varying arguments (e.g. ``read_file``
on 40 different files).
Returns:
(warning_message_or_none, should_hard_stop)
"""
messages = state.get("messages", [])
if not messages:
return None, False
last_msg = messages[-1]
if getattr(last_msg, "type", None) != "ai":
return None, False
tool_calls = getattr(last_msg, "tool_calls", None)
if not tool_calls:
return None, False
thread_id = self._get_thread_id(runtime)
call_hash = _hash_tool_calls(tool_calls)
with self._lock:
# Touch / create entry (move to end for LRU)
if thread_id in self._history:
self._history.move_to_end(thread_id)
else:
self._history[thread_id] = []
self._evict_if_needed()
history = self._history[thread_id]
history.append(call_hash)
if len(history) > self.window_size:
history[:] = history[-self.window_size :]
count = history.count(call_hash)
tool_names = [tc.get("name", "?") for tc in tool_calls]
# --- Layer 1: hash-based (identical call sets) ---
if count >= self.hard_limit:
logger.error(
"Loop hard limit reached — forcing stop",
extra={
"thread_id": thread_id,
"call_hash": call_hash,
"count": count,
"tools": tool_names,
},
)
return _HARD_STOP_MSG, True
if count >= self.warn_threshold:
warned = self._warned[thread_id]
if call_hash not in warned:
warned.add(call_hash)
logger.warning(
"Repetitive tool calls detected — injecting warning",
extra={
"thread_id": thread_id,
"call_hash": call_hash,
"count": count,
"tools": tool_names,
},
)
return _WARNING_MSG, False
# --- Layer 2: per-tool-type frequency ---
freq = self._tool_freq[thread_id]
for tc in tool_calls:
name = tc.get("name", "")
if not name:
continue
freq[name] += 1
tc_count = freq[name]
if tc_count >= self.tool_freq_hard_limit:
logger.error(
"Tool frequency hard limit reached — forcing stop",
extra={
"thread_id": thread_id,
"tool_name": name,
"count": tc_count,
},
)
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
if tc_count >= self.tool_freq_warn:
warned = self._tool_freq_warned[thread_id]
if name not in warned:
warned.add(name)
logger.warning(
"Tool frequency warning — too many calls to same tool type",
extra={
"thread_id": thread_id,
"tool_name": name,
"count": tc_count,
},
)
return _TOOL_FREQ_WARNING_MSG.format(tool_name=name, count=tc_count), False
return None, False
@staticmethod
def _append_text(content: str | list | None, text: str) -> str | list:
"""Append *text* to AIMessage content, handling str, list, and None.
When content is a list of content blocks (e.g. Anthropic thinking mode),
we append a new ``{"type": "text", ...}`` block instead of concatenating
a string to a list, which would raise ``TypeError``.
"""
if content is None:
return text
if isinstance(content, list):
return [*content, {"type": "text", "text": f"\n\n{text}"}]
if isinstance(content, str):
return content + f"\n\n{text}"
# Fallback: coerce unexpected types to str to avoid TypeError
return str(content) + f"\n\n{text}"
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
warning, hard_stop = self._track_and_check(state, runtime)
if hard_stop:
# Strip tool_calls from the last AIMessage to force text output
messages = state.get("messages", [])
last_msg = messages[-1]
stripped_msg = last_msg.model_copy(
update={
"tool_calls": [],
"content": self._append_text(last_msg.content, warning),
}
)
return {"messages": [stripped_msg]}
if warning:
# Inject as HumanMessage instead of SystemMessage to avoid
# Anthropic's "multiple non-consecutive system messages" error.
# Anthropic models require system messages only at the start of
# the conversation; injecting one mid-conversation crashes
# langchain_anthropic's _format_messages(). HumanMessage works
# with all providers. See #1299.
return {"messages": [HumanMessage(content=warning)]}
return None
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
def reset(self, thread_id: str | None = None) -> None:
"""Clear tracking state. If thread_id given, clear only that thread."""
with self._lock:
if thread_id:
self._history.pop(thread_id, None)
self._warned.pop(thread_id, None)
self._tool_freq.pop(thread_id, None)
self._tool_freq_warned.pop(thread_id, None)
else:
self._history.clear()
self._warned.clear()
self._tool_freq.clear()
self._tool_freq_warned.clear()

View File

@@ -0,0 +1,248 @@
"""Middleware for memory mechanism."""
import logging
import re
from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config
logger = logging.getLogger(__name__)
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
_CORRECTION_PATTERNS = (
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
re.compile(r"\btry again\b", re.IGNORECASE),
re.compile(r"\bredo\b", re.IGNORECASE),
re.compile(r"不对"),
re.compile(r"你理解错了"),
re.compile(r"你理解有误"),
re.compile(r"重试"),
re.compile(r"重新来"),
re.compile(r"换一种"),
re.compile(r"改用"),
)
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
pass
def _extract_message_text(message: Any) -> str:
"""Extract plain text from message content for filtering and signal detection."""
content = getattr(message, "content", "")
if isinstance(content, list):
text_parts: list[str] = []
for part in content:
if isinstance(part, str):
text_parts.append(part)
elif isinstance(part, dict):
text_val = part.get("text")
if isinstance(text_val, str):
text_parts.append(text_val)
return " ".join(text_parts)
return str(content)
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
"""Filter messages to keep only user inputs and final assistant responses.
This filters out:
- Tool messages (intermediate tool call results)
- AI messages with tool_calls (intermediate steps, not final responses)
- The <uploaded_files> block injected by UploadsMiddleware into human messages
(file paths are session-scoped and must not persist in long-term memory).
The user's actual question is preserved; only turns whose content is entirely
the upload block (nothing remains after stripping) are dropped along with
their paired assistant response.
Only keeps:
- Human messages (with the ephemeral upload block removed)
- AI messages without tool_calls (final assistant responses), unless the
paired human turn was upload-only and had no real user text.
Args:
messages: List of all conversation messages.
Returns:
Filtered list containing only user inputs and final assistant responses.
"""
filtered = []
skip_next_ai = False
for msg in messages:
msg_type = getattr(msg, "type", None)
if msg_type == "human":
content_str = _extract_message_text(msg)
if "<uploaded_files>" in content_str:
# Strip the ephemeral upload block; keep the user's real question.
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
if not stripped:
# Nothing left — the entire turn was upload bookkeeping;
# skip it and the paired assistant response.
skip_next_ai = True
continue
# Rebuild the message with cleaned content so the user's question
# is still available for memory summarisation.
from copy import copy
clean_msg = copy(msg)
clean_msg.content = stripped
filtered.append(clean_msg)
skip_next_ai = False
else:
filtered.append(msg)
skip_next_ai = False
elif msg_type == "ai":
tool_calls = getattr(msg, "tool_calls", None)
if not tool_calls:
if skip_next_ai:
skip_next_ai = False
continue
filtered.append(msg)
# Skip tool messages and AI messages with tool_calls
return filtered
def detect_correction(messages: list[Any]) -> bool:
"""Detect explicit user corrections in recent conversation turns.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale corrections from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
return True
return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns.
Complements detect_correction() by identifying when the user confirms the
agent's approach was correct. This allows the memory system to record what
worked well, not just what went wrong.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale signals from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution.
This middleware:
1. After each agent execution, queues the conversation for memory update
2. Only includes user inputs and final assistant responses (ignores tool calls)
3. The queue uses debouncing to batch multiple updates together
4. Memory is updated asynchronously via LLM summarization
"""
state_schema = MemoryMiddlewareState
def __init__(self, agent_name: str | None = None):
"""Initialize the MemoryMiddleware.
Args:
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
"""
super().__init__()
self._agent_name = agent_name
@override
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
"""Queue conversation for memory update after agent completes.
Args:
state: The current agent state.
runtime: The runtime context.
Returns:
None (no state changes needed from this middleware).
"""
config = get_memory_config()
if not config.enabled:
return None
# Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
config_data = get_config()
thread_id = config_data.get("configurable", {}).get("thread_id")
if not thread_id:
logger.debug("No thread_id in context, skipping memory update")
return None
# Get messages from state
messages = state.get("messages", [])
if not messages:
logger.debug("No messages in state, skipping memory update")
return None
# Filter to only keep user inputs and final assistant responses
filtered_messages = _filter_messages_for_memory(messages)
# Only queue if there's meaningful conversation
# At minimum need one user message and one assistant response
user_messages = [m for m in filtered_messages if getattr(m, "type", None) == "human"]
assistant_messages = [m for m in filtered_messages if getattr(m, "type", None) == "ai"]
if not user_messages or not assistant_messages:
return None
# Queue the filtered conversation for memory update
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue()
queue.add(
thread_id=thread_id,
messages=filtered_messages,
agent_name=self._agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
return None

View File

@@ -0,0 +1,363 @@
"""SandboxAuditMiddleware - bash command security auditing."""
import json
import logging
import re
import shlex
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from typing import override
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
from deerflow.agents.thread_state import ThreadState
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Command classification rules
# ---------------------------------------------------------------------------
# Each pattern is compiled once at import time.
_HIGH_RISK_PATTERNS: list[re.Pattern[str]] = [
# --- original rules (retained) ---
re.compile(r"rm\s+-[^\s]*r[^\s]*\s+(/\*?|~/?\*?|/home\b|/root\b)\s*$"),
re.compile(r"dd\s+if="),
re.compile(r"mkfs"),
re.compile(r"cat\s+/etc/shadow"),
re.compile(r">+\s*/etc/"),
# --- pipe to sh/bash (generalised, replaces old curl|sh rule) ---
re.compile(r"\|\s*(ba)?sh\b"),
# --- command substitution (targeted only dangerous executables) ---
re.compile(r"[`$]\(?\s*(curl|wget|bash|sh|python|ruby|perl|base64)"),
# --- base64 decode piped to execution ---
re.compile(r"base64\s+.*-d.*\|"),
# --- overwrite system binaries ---
re.compile(r">+\s*(/usr/bin/|/bin/|/sbin/)"),
# --- overwrite shell startup files ---
re.compile(r">+\s*~/?\.(bashrc|profile|zshrc|bash_profile)"),
# --- process environment leakage ---
re.compile(r"/proc/[^/]+/environ"),
# --- dynamic linker hijack (one-step escalation) ---
re.compile(r"\b(LD_PRELOAD|LD_LIBRARY_PATH)\s*="),
# --- bash built-in networking (bypasses tool allowlists) ---
re.compile(r"/dev/tcp/"),
# --- fork bomb ---
re.compile(r"\S+\(\)\s*\{[^}]*\|\s*\S+\s*&"), # :(){ :|:& };:
re.compile(r"while\s+true.*&\s*done"), # while true; do bash & done
]
_MEDIUM_RISK_PATTERNS: list[re.Pattern[str]] = [
re.compile(r"chmod\s+777"),
re.compile(r"pip3?\s+install"),
re.compile(r"apt(-get)?\s+install"),
# sudo/su: no-op under Docker root; warn so LLM is aware
re.compile(r"\b(sudo|su)\b"),
# PATH modification: long attack chain, warn rather than block
re.compile(r"\bPATH\s*="),
]
def _split_compound_command(command: str) -> list[str]:
"""Split a compound command into sub-commands (quote-aware).
Scans the raw command string so unquoted shell control operators are
recognised even when they are not surrounded by whitespace
(e.g. ``safe;rm -rf /`` or ``rm -rf /&&echo ok``). Operators inside
quotes are ignored. If the command ends with an unclosed quote or a
dangling escape, return the whole command unchanged (fail-closed —
safer to classify the unsplit string than silently drop parts).
"""
parts: list[str] = []
current: list[str] = []
in_single_quote = False
in_double_quote = False
escaping = False
index = 0
while index < len(command):
char = command[index]
if escaping:
current.append(char)
escaping = False
index += 1
continue
if char == "\\" and not in_single_quote:
current.append(char)
escaping = True
index += 1
continue
if char == "'" and not in_double_quote:
in_single_quote = not in_single_quote
current.append(char)
index += 1
continue
if char == '"' and not in_single_quote:
in_double_quote = not in_double_quote
current.append(char)
index += 1
continue
if not in_single_quote and not in_double_quote:
if command.startswith("&&", index) or command.startswith("||", index):
part = "".join(current).strip()
if part:
parts.append(part)
current = []
index += 2
continue
if char == ";":
part = "".join(current).strip()
if part:
parts.append(part)
current = []
index += 1
continue
current.append(char)
index += 1
# Unclosed quote or dangling escape → fail-closed, return whole command
if in_single_quote or in_double_quote or escaping:
return [command]
part = "".join(current).strip()
if part:
parts.append(part)
return parts if parts else [command]
def _classify_single_command(command: str) -> str:
"""Classify a single (non-compound) command. Return 'block', 'warn', or 'pass'."""
normalized = " ".join(command.split())
for pattern in _HIGH_RISK_PATTERNS:
if pattern.search(normalized):
return "block"
# Also try shlex-parsed tokens for high-risk detection
try:
tokens = shlex.split(command)
joined = " ".join(tokens)
for pattern in _HIGH_RISK_PATTERNS:
if pattern.search(joined):
return "block"
except ValueError:
# shlex.split fails on unclosed quotes — treat as suspicious
return "block"
for pattern in _MEDIUM_RISK_PATTERNS:
if pattern.search(normalized):
return "warn"
return "pass"
def _classify_command(command: str) -> str:
"""Return 'block', 'warn', or 'pass'.
Strategy:
1. First scan the *whole* raw command against high-risk patterns. This
catches structural attacks like ``while true; do bash & done`` or
``:(){ :|:& };:`` that span multiple shell statements — splitting them
on ``;`` would destroy the pattern context.
2. Then split compound commands (e.g. ``cmd1 && cmd2 ; cmd3``) and
classify each sub-command independently. The most severe verdict wins.
"""
# Pass 1: whole-command high-risk scan (catches multi-statement patterns)
normalized = " ".join(command.split())
for pattern in _HIGH_RISK_PATTERNS:
if pattern.search(normalized):
return "block"
# Pass 2: per-sub-command classification
sub_commands = _split_compound_command(command)
worst = "pass"
for sub in sub_commands:
verdict = _classify_single_command(sub)
if verdict == "block":
return "block" # short-circuit: can't get worse
if verdict == "warn":
worst = "warn"
return worst
# ---------------------------------------------------------------------------
# Middleware
# ---------------------------------------------------------------------------
class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
"""Bash command security auditing middleware.
For every ``bash`` tool call:
1. **Command classification**: regex + shlex analysis grades commands as
high-risk (block), medium-risk (warn), or safe (pass).
2. **Audit log**: every bash call is recorded as a structured JSON entry
via the standard logger (visible in langgraph.log).
High-risk commands (e.g. ``rm -rf /``, ``curl url | bash``) are blocked:
the handler is not called and an error ``ToolMessage`` is returned so the
agent loop can continue gracefully.
Medium-risk commands (e.g. ``pip install``, ``chmod 777``) are executed
normally; a warning is appended to the tool result so the LLM is aware.
"""
state_schema = ThreadState
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _get_thread_id(self, request: ToolCallRequest) -> str | None:
runtime = request.runtime # ToolRuntime; may be None-like in tests
if runtime is None:
return None
ctx = getattr(runtime, "context", None) or {}
thread_id = ctx.get("thread_id") if isinstance(ctx, dict) else None
if thread_id is None:
cfg = getattr(runtime, "config", None) or {}
thread_id = cfg.get("configurable", {}).get("thread_id")
return thread_id
_AUDIT_COMMAND_LIMIT = 200
def _write_audit(self, thread_id: str | None, command: str, verdict: str, *, truncate: bool = False) -> None:
audited_command = command
if truncate and len(command) > self._AUDIT_COMMAND_LIMIT:
audited_command = f"{command[: self._AUDIT_COMMAND_LIMIT]}... ({len(command)} chars)"
record = {
"timestamp": datetime.now(UTC).isoformat(),
"thread_id": thread_id or "unknown",
"command": audited_command,
"verdict": verdict,
}
logger.info("[SandboxAudit] %s", json.dumps(record, ensure_ascii=False))
def _build_block_message(self, request: ToolCallRequest, reason: str) -> ToolMessage:
tool_call_id = str(request.tool_call.get("id") or "missing_id")
return ToolMessage(
content=f"Command blocked: {reason}. Please use a safer alternative approach.",
tool_call_id=tool_call_id,
name="bash",
status="error",
)
def _append_warn_to_result(self, result: ToolMessage | Command, command: str) -> ToolMessage | Command:
"""Append a warning note to the tool result for medium-risk commands."""
if not isinstance(result, ToolMessage):
return result
warning = f"\n\n⚠️ Warning: `{command}` is a medium-risk command that may modify the runtime environment."
if isinstance(result.content, list):
new_content = list(result.content) + [{"type": "text", "text": warning}]
else:
new_content = str(result.content) + warning
return ToolMessage(
content=new_content,
tool_call_id=result.tool_call_id,
name=result.name,
status=result.status,
)
# ------------------------------------------------------------------
# Input sanitisation
# ------------------------------------------------------------------
# Normal bash commands rarely exceed a few hundred characters. 10 000 is
# well above any legitimate use case yet a tiny fraction of Linux ARG_MAX.
# Anything longer is almost certainly a payload injection or base64-encoded
# attack string.
_MAX_COMMAND_LENGTH = 10_000
def _validate_input(self, command: str) -> str | None:
"""Return ``None`` if *command* is acceptable, else a rejection reason."""
if not command.strip():
return "empty command"
if len(command) > self._MAX_COMMAND_LENGTH:
return "command too long"
if "\x00" in command:
return "null byte detected"
return None
# ------------------------------------------------------------------
# Core logic (shared between sync and async paths)
# ------------------------------------------------------------------
def _pre_process(self, request: ToolCallRequest) -> tuple[str, str | None, str, str | None]:
"""
Returns (command, thread_id, verdict, reject_reason).
verdict is 'block', 'warn', or 'pass'.
reject_reason is non-None only for input sanitisation rejections.
"""
args = request.tool_call.get("args", {})
raw_command = args.get("command")
command = raw_command if isinstance(raw_command, str) else ""
thread_id = self._get_thread_id(request)
# ① input sanitisation — reject malformed input before regex analysis
reject_reason = self._validate_input(command)
if reject_reason:
self._write_audit(thread_id, command, "block", truncate=True)
logger.warning("[SandboxAudit] INVALID INPUT thread=%s reason=%s", thread_id, reject_reason)
return command, thread_id, "block", reject_reason
# ② classify command
verdict = _classify_command(command)
# ③ audit log
self._write_audit(thread_id, command, verdict)
if verdict == "block":
logger.warning("[SandboxAudit] BLOCKED thread=%s cmd=%r", thread_id, command)
elif verdict == "warn":
logger.warning("[SandboxAudit] WARN (medium-risk) thread=%s cmd=%r", thread_id, command)
return command, thread_id, verdict, None
# ------------------------------------------------------------------
# wrap_tool_call hooks
# ------------------------------------------------------------------
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
if request.tool_call.get("name") != "bash":
return handler(request)
command, _, verdict, reject_reason = self._pre_process(request)
if verdict == "block":
reason = reject_reason or "security violation detected"
return self._build_block_message(request, reason)
result = handler(request)
if verdict == "warn":
result = self._append_warn_to_result(result, command)
return result
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
if request.tool_call.get("name") != "bash":
return await handler(request)
command, _, verdict, reject_reason = self._pre_process(request)
if verdict == "block":
reason = reject_reason or "security violation detected"
return self._build_block_message(request, reason)
result = await handler(request)
if verdict == "warn":
result = self._append_warn_to_result(result, command)
return result

View File

@@ -0,0 +1,75 @@
"""Middleware to enforce maximum concurrent subagent tool calls per model response."""
import logging
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
logger = logging.getLogger(__name__)
# Valid range for max_concurrent_subagents
MIN_SUBAGENT_LIMIT = 2
MAX_SUBAGENT_LIMIT = 4
def _clamp_subagent_limit(value: int) -> int:
"""Clamp subagent limit to valid range [2, 4]."""
return max(MIN_SUBAGENT_LIMIT, min(MAX_SUBAGENT_LIMIT, value))
class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
"""Truncates excess 'task' tool calls from a single model response.
When an LLM generates more than max_concurrent parallel task tool calls
in one response, this middleware keeps only the first max_concurrent and
discards the rest. This is more reliable than prompt-based limits.
Args:
max_concurrent: Maximum number of concurrent subagent calls allowed.
Defaults to MAX_CONCURRENT_SUBAGENTS (3). Clamped to [2, 4].
"""
def __init__(self, max_concurrent: int = MAX_CONCURRENT_SUBAGENTS):
super().__init__()
self.max_concurrent = _clamp_subagent_limit(max_concurrent)
def _truncate_task_calls(self, state: AgentState) -> dict | None:
messages = state.get("messages", [])
if not messages:
return None
last_msg = messages[-1]
if getattr(last_msg, "type", None) != "ai":
return None
tool_calls = getattr(last_msg, "tool_calls", None)
if not tool_calls:
return None
# Count task tool calls
task_indices = [i for i, tc in enumerate(tool_calls) if tc.get("name") == "task"]
if len(task_indices) <= self.max_concurrent:
return None
# Build set of indices to drop (excess task calls beyond the limit)
indices_to_drop = set(task_indices[self.max_concurrent :])
truncated_tool_calls = [tc for i, tc in enumerate(tool_calls) if i not in indices_to_drop]
dropped_count = len(indices_to_drop)
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
return {"messages": [updated_msg]}
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._truncate_task_calls(state)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._truncate_task_calls(state)

View File

@@ -0,0 +1,99 @@
import logging
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.paths import Paths, get_paths
logger = logging.getLogger(__name__)
class ThreadDataMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
thread_data: NotRequired[ThreadDataState | None]
class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
"""Create thread data directories for each thread execution.
Creates the following directory structure:
- {base_dir}/threads/{thread_id}/user-data/workspace
- {base_dir}/threads/{thread_id}/user-data/uploads
- {base_dir}/threads/{thread_id}/user-data/outputs
Lifecycle Management:
- With lazy_init=True (default): Only compute paths, directories created on-demand
- With lazy_init=False: Eagerly create directories in before_agent()
"""
state_schema = ThreadDataMiddlewareState
def __init__(self, base_dir: str | None = None, lazy_init: bool = True):
"""Initialize the middleware.
Args:
base_dir: Base directory for thread data. Defaults to Paths resolution.
lazy_init: If True, defer directory creation until needed.
If False, create directories eagerly in before_agent().
Default is True for optimal performance.
"""
super().__init__()
self._paths = Paths(base_dir) if base_dir else get_paths()
self._lazy_init = lazy_init
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
"""Get the paths for a thread's data directories.
Args:
thread_id: The thread ID.
Returns:
Dictionary with workspace_path, uploads_path, and outputs_path.
"""
return {
"workspace_path": str(self._paths.sandbox_work_dir(thread_id)),
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)),
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)),
}
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
"""Create the thread data directories.
Args:
thread_id: The thread ID.
Returns:
Dictionary with the created directory paths.
"""
self._paths.ensure_thread_dirs(thread_id)
return self._get_thread_paths(thread_id)
@override
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
context = runtime.context or {}
thread_id = context.get("thread_id")
if thread_id is None:
config = get_config()
thread_id = config.get("configurable", {}).get("thread_id")
if thread_id is None:
raise ValueError("Thread ID is required in runtime context or config.configurable")
if self._lazy_init:
# Lazy initialization: only compute paths, don't create directories
paths = self._get_thread_paths(thread_id)
else:
# Eager initialization: create directories immediately
paths = self._create_thread_directories(thread_id)
logger.debug("Created thread data directories for thread %s", thread_id)
return {
"thread_data": {
**paths,
}
}

View File

@@ -0,0 +1,138 @@
"""Middleware for automatic thread title generation."""
import logging
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
class TitleMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
title: NotRequired[str | None]
class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
"""Automatically generate a title for the thread after the first user message."""
state_schema = TitleMiddlewareState
def _normalize_content(self, content: object) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [self._normalize_content(item) for item in content]
return "\n".join(part for part in parts if part)
if isinstance(content, dict):
text_value = content.get("text")
if isinstance(text_value, str):
return text_value
nested_content = content.get("content")
if nested_content is not None:
return self._normalize_content(nested_content)
return ""
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
"""Check if we should generate a title for this thread."""
config = get_title_config()
if not config.enabled:
return False
# Check if thread already has a title in state
if state.get("title"):
return False
# Check if this is the first turn (has at least one user message and one assistant response)
messages = state.get("messages", [])
if len(messages) < 2:
return False
# Count user and assistant messages
user_messages = [m for m in messages if m.type == "human"]
assistant_messages = [m for m in messages if m.type == "ai"]
# Generate title after first complete exchange
return len(user_messages) == 1 and len(assistant_messages) >= 1
def _build_title_prompt(self, state: TitleMiddlewareState) -> tuple[str, str]:
"""Extract user/assistant messages and build the title prompt.
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
"""
config = get_title_config()
messages = state.get("messages", [])
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
user_msg = self._normalize_content(user_msg_content)
assistant_msg = self._normalize_content(assistant_msg_content)
prompt = config.prompt_template.format(
max_words=config.max_words,
user_msg=user_msg[:500],
assistant_msg=assistant_msg[:500],
)
return prompt, user_msg
def _parse_title(self, content: object) -> str:
"""Normalize model output into a clean title string."""
config = get_title_config()
title_content = self._normalize_content(content)
title = title_content.strip().strip('"').strip("'")
return title[: config.max_chars] if len(title) > config.max_chars else title
def _fallback_title(self, user_msg: str) -> str:
config = get_title_config()
fallback_chars = min(config.max_chars, 50)
if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..."
return user_msg if user_msg else "New Conversation"
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Generate a local fallback title without blocking on an LLM call."""
if not self._should_generate_title(state):
return None
_, user_msg = self._build_title_prompt(state)
return {"title": self._fallback_title(user_msg)}
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Generate a title asynchronously and fall back locally on failure."""
if not self._should_generate_title(state):
return None
config = get_title_config()
prompt, user_msg = self._build_title_prompt(state)
try:
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
else:
model = create_chat_model(thinking_enabled=False)
response = await model.ainvoke(prompt)
title = self._parse_title(response.content)
if title:
return {"title": title}
except Exception:
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
return {"title": self._fallback_title(user_msg)}
@override
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
return self._generate_title_result(state)
@override
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
return await self._agenerate_title_result(state)

View File

@@ -0,0 +1,100 @@
"""Middleware that extends TodoListMiddleware with context-loss detection.
When the message history is truncated (e.g., by SummarizationMiddleware), the
original `write_todos` tool call and its ToolMessage can be scrolled out of the
active context window. This middleware detects that situation and injects a
reminder message so the model still knows about the outstanding todo list.
"""
from __future__ import annotations
from typing import Any, override
from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime
def _todos_in_messages(messages: list[Any]) -> bool:
"""Return True if any AIMessage in *messages* contains a write_todos tool call."""
for msg in messages:
if isinstance(msg, AIMessage) and msg.tool_calls:
for tc in msg.tool_calls:
if tc.get("name") == "write_todos":
return True
return False
def _reminder_in_messages(messages: list[Any]) -> bool:
"""Return True if a todo_reminder HumanMessage is already present in *messages*."""
for msg in messages:
if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_reminder":
return True
return False
def _format_todos(todos: list[Todo]) -> str:
"""Format a list of Todo items into a human-readable string."""
lines: list[str] = []
for todo in todos:
status = todo.get("status", "pending")
content = todo.get("content", "")
lines.append(f"- [{status}] {content}")
return "\n".join(lines)
class TodoMiddleware(TodoListMiddleware):
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
When the original `write_todos` tool call has been truncated from the message
history (e.g., after summarization), the model loses awareness of the current
todo list. This middleware detects that gap in `before_model` / `abefore_model`
and injects a reminder message so the model can continue tracking progress.
"""
@override
def before_model(
self,
state: PlanningState,
runtime: Runtime, # noqa: ARG002
) -> dict[str, Any] | None:
"""Inject a todo-list reminder when write_todos has left the context window."""
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
if not todos:
return None
messages = state.get("messages") or []
if _todos_in_messages(messages):
# write_todos is still visible in context — nothing to do.
return None
if _reminder_in_messages(messages):
# A reminder was already injected and hasn't been truncated yet.
return None
# The todo list exists in state but the original write_todos call is gone.
# Inject a reminder as a HumanMessage so the model stays aware.
formatted = _format_todos(todos)
reminder = HumanMessage(
name="todo_reminder",
content=(
"<system_reminder>\n"
"Your todo list from earlier is no longer visible in the current context window, "
"but it is still active. Here is the current state:\n\n"
f"{formatted}\n\n"
"Continue tracking and updating this todo list as you work. "
"Call `write_todos` whenever the status of any item changes.\n"
"</system_reminder>"
),
)
return {"messages": [reminder]}
@override
async def abefore_model(
self,
state: PlanningState,
runtime: Runtime,
) -> dict[str, Any] | None:
"""Async version of before_model."""
return self.before_model(state, runtime)

View File

@@ -0,0 +1,37 @@
"""Middleware for logging LLM token usage."""
import logging
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
class TokenUsageMiddleware(AgentMiddleware):
"""Logs token usage from model response usage_metadata."""
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._log_usage(state)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._log_usage(state)
def _log_usage(self, state: AgentState) -> None:
messages = state.get("messages", [])
if not messages:
return None
last = messages[-1]
usage = getattr(last, "usage_metadata", None)
if usage:
logger.info(
"LLM token usage: input=%s output=%s total=%s",
usage.get("input_tokens", "?"),
usage.get("output_tokens", "?"),
usage.get("total_tokens", "?"),
)
return None

View File

@@ -0,0 +1,143 @@
"""Tool error handling middleware and shared runtime middleware builders."""
import logging
from collections.abc import Awaitable, Callable
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.errors import GraphBubbleUp
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
logger = logging.getLogger(__name__)
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"""Convert tool exceptions into error ToolMessages so the run can continue."""
def _build_error_message(self, request: ToolCallRequest, exc: Exception) -> ToolMessage:
tool_name = str(request.tool_call.get("name") or "unknown_tool")
tool_call_id = str(request.tool_call.get("id") or _MISSING_TOOL_CALL_ID)
detail = str(exc).strip() or exc.__class__.__name__
if len(detail) > 500:
detail = detail[:497] + "..."
content = f"Error: Tool '{tool_name}' failed with {exc.__class__.__name__}: {detail}. Continue with available context, or choose an alternative tool."
return ToolMessage(
content=content,
tool_call_id=tool_call_id,
name=tool_name,
status="error",
)
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
try:
return handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
logger.exception("Tool execution failed (sync): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
return self._build_error_message(request, exc)
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
try:
return await handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
logger.exception("Tool execution failed (async): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
return self._build_error_message(request, exc)
def _build_runtime_middlewares(
*,
include_uploads: bool,
include_dangling_tool_call_patch: bool,
lazy_init: bool = True,
) -> list[AgentMiddleware]:
"""Build shared base middlewares for agent execution."""
from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from deerflow.sandbox.middleware import SandboxMiddleware
middlewares: list[AgentMiddleware] = [
ThreadDataMiddleware(lazy_init=lazy_init),
SandboxMiddleware(lazy_init=lazy_init),
]
if include_uploads:
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
middlewares.insert(1, UploadsMiddleware())
if include_dangling_tool_call_patch:
from deerflow.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
middlewares.append(DanglingToolCallMiddleware())
middlewares.append(LLMErrorHandlingMiddleware())
# Guardrail middleware (if configured)
from deerflow.config.guardrails_config import get_guardrails_config
guardrails_config = get_guardrails_config()
if guardrails_config.enabled and guardrails_config.provider:
import inspect
from deerflow.guardrails.middleware import GuardrailMiddleware
from deerflow.reflection import resolve_variable
provider_cls = resolve_variable(guardrails_config.provider.use)
provider_kwargs = dict(guardrails_config.provider.config) if guardrails_config.provider.config else {}
# Pass framework hint if the provider accepts it (e.g. for config discovery).
# Built-in providers like AllowlistProvider don't need it, so only inject
# when the constructor accepts 'framework' or '**kwargs'.
if "framework" not in provider_kwargs:
try:
sig = inspect.signature(provider_cls.__init__)
if "framework" in sig.parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
provider_kwargs["framework"] = "deerflow"
except (ValueError, TypeError):
pass
provider = provider_cls(**provider_kwargs)
middlewares.append(GuardrailMiddleware(provider, fail_closed=guardrails_config.fail_closed, passport=guardrails_config.passport))
from deerflow.agents.middlewares.sandbox_audit_middleware import SandboxAuditMiddleware
middlewares.append(SandboxAuditMiddleware())
middlewares.append(ToolErrorHandlingMiddleware())
return middlewares
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
return _build_runtime_middlewares(
include_uploads=True,
include_dangling_tool_call_patch=True,
lazy_init=lazy_init,
)
def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
return _build_runtime_middlewares(
include_uploads=False,
include_dangling_tool_call_patch=True,
lazy_init=lazy_init,
)

View File

@@ -0,0 +1,293 @@
"""Middleware to inject uploaded files information into agent context."""
import logging
from pathlib import Path
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.config.paths import Paths, get_paths
from deerflow.utils.file_conversion import extract_outline
logger = logging.getLogger(__name__)
_OUTLINE_PREVIEW_LINES = 5
def _extract_outline_for_file(file_path: Path) -> tuple[list[dict], list[str]]:
"""Return the document outline and fallback preview for *file_path*.
Looks for a sibling ``<stem>.md`` file produced by the upload conversion
pipeline.
Returns:
(outline, preview) where:
- outline: list of ``{title, line}`` dicts (plus optional sentinel).
Empty when no headings are found or no .md exists.
- preview: first few non-empty lines of the .md, used as a content
anchor when outline is empty so the agent has some context.
Empty when outline is non-empty (no fallback needed).
"""
md_path = file_path.with_suffix(".md")
if not md_path.is_file():
return [], []
outline = extract_outline(md_path)
if outline:
logger.debug("Extracted %d outline entries from %s", len(outline), file_path.name)
return outline, []
# outline is empty — read the first few non-empty lines as a content preview
preview: list[str] = []
try:
with md_path.open(encoding="utf-8") as f:
for line in f:
stripped = line.strip()
if stripped:
preview.append(stripped)
if len(preview) >= _OUTLINE_PREVIEW_LINES:
break
except Exception:
logger.debug("Failed to read preview lines from %s", md_path, exc_info=True)
return [], preview
class UploadsMiddlewareState(AgentState):
"""State schema for uploads middleware."""
uploaded_files: NotRequired[list[dict] | None]
class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
"""Middleware to inject uploaded files information into the agent context.
Reads file metadata from the current message's additional_kwargs.files
(set by the frontend after upload) and prepends an <uploaded_files> block
to the last human message so the model knows which files are available.
"""
state_schema = UploadsMiddlewareState
def __init__(self, base_dir: str | None = None):
"""Initialize the middleware.
Args:
base_dir: Base directory for thread data. Defaults to Paths resolution.
"""
super().__init__()
self._paths = Paths(base_dir) if base_dir else get_paths()
def _format_file_entry(self, file: dict, lines: list[str]) -> None:
"""Append a single file entry (name, size, path, optional outline) to lines."""
size_kb = file["size"] / 1024
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
lines.append(f"- {file['filename']} ({size_str})")
lines.append(f" Path: {file['path']}")
outline = file.get("outline") or []
if outline:
truncated = outline[-1].get("truncated", False)
visible = [e for e in outline if not e.get("truncated")]
lines.append(" Document outline (use `read_file` with line ranges to read sections):")
for entry in visible:
lines.append(f" L{entry['line']}: {entry['title']}")
if truncated:
lines.append(f" ... (showing first {len(visible)} headings; use `read_file` to explore further)")
else:
preview = file.get("outline_preview") or []
if preview:
lines.append(" No structural headings detected. Document begins with:")
for text in preview:
lines.append(f" > {text}")
lines.append(" Use `grep` to search for keywords (e.g. `grep(pattern='keyword', path='/mnt/user-data/uploads/')`).")
lines.append("")
def _create_files_message(self, new_files: list[dict], historical_files: list[dict]) -> str:
"""Create a formatted message listing uploaded files.
Args:
new_files: Files uploaded in the current message.
historical_files: Files uploaded in previous messages.
Each file dict may contain an optional ``outline`` key — a list of
``{title, line}`` dicts extracted from the converted Markdown file.
Returns:
Formatted string inside <uploaded_files> tags.
"""
lines = ["<uploaded_files>"]
lines.append("The following files were uploaded in this message:")
lines.append("")
if new_files:
for file in new_files:
self._format_file_entry(file, lines)
else:
lines.append("(empty)")
lines.append("")
if historical_files:
lines.append("The following files were uploaded in previous messages and are still available:")
lines.append("")
for file in historical_files:
self._format_file_entry(file, lines)
lines.append("To work with these files:")
lines.append("- Read from the file first — use the outline line numbers and `read_file` to locate relevant sections.")
lines.append("- Use `grep` to search for keywords when you are not sure which section to look at")
lines.append(" (e.g. `grep(pattern='revenue', path='/mnt/user-data/uploads/')`).")
lines.append("- Use `glob` to find files by name pattern")
lines.append(" (e.g. `glob(pattern='**/*.md', path='/mnt/user-data/uploads/')`).")
lines.append("- Only fall back to web search if the file content is clearly insufficient to answer the question.")
lines.append("</uploaded_files>")
return "\n".join(lines)
def _files_from_kwargs(self, message: HumanMessage, uploads_dir: Path | None = None) -> list[dict] | None:
"""Extract file info from message additional_kwargs.files.
The frontend sends uploaded file metadata in additional_kwargs.files
after a successful upload. Each entry has: filename, size (bytes),
path (virtual path), status.
Args:
message: The human message to inspect.
uploads_dir: Physical uploads directory used to verify file existence.
When provided, entries whose files no longer exist are skipped.
Returns:
List of file dicts with virtual paths, or None if the field is absent or empty.
"""
kwargs_files = (message.additional_kwargs or {}).get("files")
if not isinstance(kwargs_files, list) or not kwargs_files:
return None
files = []
for f in kwargs_files:
if not isinstance(f, dict):
continue
filename = f.get("filename") or ""
if not filename or Path(filename).name != filename:
continue
if uploads_dir is not None and not (uploads_dir / filename).is_file():
continue
files.append(
{
"filename": filename,
"size": int(f.get("size") or 0),
"path": f"/mnt/user-data/uploads/{filename}",
"extension": Path(filename).suffix,
}
)
return files if files else None
@override
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None:
"""Inject uploaded files information before agent execution.
New files come from the current message's additional_kwargs.files.
Historical files are scanned from the thread's uploads directory,
excluding the new ones.
Prepends <uploaded_files> context to the last human message content.
The original additional_kwargs (including files metadata) is preserved
on the updated message so the frontend can read it from the stream.
Args:
state: Current agent state.
runtime: Runtime context containing thread_id.
Returns:
State updates including uploaded files list.
"""
messages = list(state.get("messages", []))
if not messages:
return None
last_message_index = len(messages) - 1
last_message = messages[last_message_index]
if not isinstance(last_message, HumanMessage):
return None
# Resolve uploads directory for existence checks
thread_id = (runtime.context or {}).get("thread_id")
if thread_id is None:
try:
from langgraph.config import get_config
thread_id = get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
pass # get_config() raises outside a runnable context (e.g. unit tests)
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files
new_files = self._files_from_kwargs(last_message, uploads_dir) or []
# Collect historical files from the uploads directory (all except the new ones)
new_filenames = {f["filename"] for f in new_files}
historical_files: list[dict] = []
if uploads_dir and uploads_dir.exists():
for file_path in sorted(uploads_dir.iterdir()):
if file_path.is_file() and file_path.name not in new_filenames:
stat = file_path.stat()
outline, preview = _extract_outline_for_file(file_path)
historical_files.append(
{
"filename": file_path.name,
"size": stat.st_size,
"path": f"/mnt/user-data/uploads/{file_path.name}",
"extension": file_path.suffix,
"outline": outline,
"outline_preview": preview,
}
)
# Attach outlines to new files as well
if uploads_dir:
for file in new_files:
phys_path = uploads_dir / file["filename"]
outline, preview = _extract_outline_for_file(phys_path)
file["outline"] = outline
file["outline_preview"] = preview
if not new_files and not historical_files:
return None
logger.debug(f"New files: {[f['filename'] for f in new_files]}, historical: {[f['filename'] for f in historical_files]}")
# Create files message and prepend to the last human message content
files_message = self._create_files_message(new_files, historical_files)
# Extract original content - handle both string and list formats
original_content = last_message.content
if isinstance(original_content, str):
# Simple case: string content, just prepend files message
updated_content = f"{files_message}\n\n{original_content}"
elif isinstance(original_content, list):
# Complex case: list content (multimodal), preserve all blocks
# Prepend files message as the first text block
files_block = {"type": "text", "text": f"{files_message}\n\n"}
# Keep all original blocks (including images)
updated_content = [files_block, *original_content]
else:
# Other types, preserve as-is
updated_content = original_content
# Create new message with combined content.
# Preserve additional_kwargs (including files metadata) so the frontend
# can read structured file info from the streamed message.
updated_message = HumanMessage(
content=updated_content,
id=last_message.id,
additional_kwargs=last_message.additional_kwargs,
)
messages[last_message_index] = updated_message
return {
"uploaded_files": new_files,
"messages": messages,
}

View File

@@ -0,0 +1,222 @@
"""Middleware for injecting image details into conversation before LLM call."""
import logging
from typing import override
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadState
logger = logging.getLogger(__name__)
class ViewImageMiddlewareState(ThreadState):
"""Reuse the thread state so reducer-backed keys keep their annotations."""
class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]):
"""Injects image details as a human message before LLM calls when view_image tools have completed.
This middleware:
1. Runs before each LLM call
2. Checks if the last assistant message contains view_image tool calls
3. Verifies all tool calls in that message have been completed (have corresponding ToolMessages)
4. If conditions are met, creates a human message with all viewed image details (including base64 data)
5. Adds the message to state so the LLM can see and analyze the images
This enables the LLM to automatically receive and analyze images that were loaded via view_image tool,
without requiring explicit user prompts to describe the images.
"""
state_schema = ViewImageMiddlewareState
def _get_last_assistant_message(self, messages: list) -> AIMessage | None:
"""Get the last assistant message from the message list.
Args:
messages: List of messages
Returns:
Last AIMessage or None if not found
"""
for msg in reversed(messages):
if isinstance(msg, AIMessage):
return msg
return None
def _has_view_image_tool(self, message: AIMessage) -> bool:
"""Check if the assistant message contains view_image tool calls.
Args:
message: Assistant message to check
Returns:
True if message contains view_image tool calls
"""
if not hasattr(message, "tool_calls") or not message.tool_calls:
return False
return any(tool_call.get("name") == "view_image" for tool_call in message.tool_calls)
def _all_tools_completed(self, messages: list, assistant_msg: AIMessage) -> bool:
"""Check if all tool calls in the assistant message have been completed.
Args:
messages: List of all messages
assistant_msg: The assistant message containing tool calls
Returns:
True if all tool calls have corresponding ToolMessages
"""
if not hasattr(assistant_msg, "tool_calls") or not assistant_msg.tool_calls:
return False
# Get all tool call IDs from the assistant message
tool_call_ids = {tool_call.get("id") for tool_call in assistant_msg.tool_calls if tool_call.get("id")}
# Find the index of the assistant message
try:
assistant_idx = messages.index(assistant_msg)
except ValueError:
return False
# Get all ToolMessages after the assistant message
completed_tool_ids = set()
for msg in messages[assistant_idx + 1 :]:
if isinstance(msg, ToolMessage) and msg.tool_call_id:
completed_tool_ids.add(msg.tool_call_id)
# Check if all tool calls have been completed
return tool_call_ids.issubset(completed_tool_ids)
def _create_image_details_message(self, state: ViewImageMiddlewareState) -> list[str | dict]:
"""Create a formatted message with all viewed image details.
Args:
state: Current state containing viewed_images
Returns:
List of content blocks (text and images) for the HumanMessage
"""
viewed_images = state.get("viewed_images", {})
if not viewed_images:
# Return a properly formatted text block, not a plain string array
return [{"type": "text", "text": "No images have been viewed."}]
# Build the message with image information
content_blocks: list[str | dict] = [{"type": "text", "text": "Here are the images you've viewed:"}]
for image_path, image_data in viewed_images.items():
mime_type = image_data.get("mime_type", "unknown")
base64_data = image_data.get("base64", "")
# Add text description
content_blocks.append({"type": "text", "text": f"\n- **{image_path}** ({mime_type})"})
# Add the actual image data so LLM can "see" it
if base64_data:
content_blocks.append(
{
"type": "image_url",
"image_url": {"url": f"data:{mime_type};base64,{base64_data}"},
}
)
return content_blocks
def _should_inject_image_message(self, state: ViewImageMiddlewareState) -> bool:
"""Determine if we should inject an image details message.
Args:
state: Current state
Returns:
True if we should inject the message
"""
messages = state.get("messages", [])
if not messages:
return False
# Get the last assistant message
last_assistant_msg = self._get_last_assistant_message(messages)
if not last_assistant_msg:
return False
# Check if it has view_image tool calls
if not self._has_view_image_tool(last_assistant_msg):
return False
# Check if all tools have been completed
if not self._all_tools_completed(messages, last_assistant_msg):
return False
# Check if we've already added an image details message
# Look for a human message after the last assistant message that contains image details
assistant_idx = messages.index(last_assistant_msg)
for msg in messages[assistant_idx + 1 :]:
if isinstance(msg, HumanMessage):
content_str = str(msg.content)
if "Here are the images you've viewed" in content_str or "Here are the details of the images you've viewed" in content_str:
# Already added, don't add again
return False
return True
def _inject_image_message(self, state: ViewImageMiddlewareState) -> dict | None:
"""Internal helper to inject image details message.
Args:
state: Current state
Returns:
State update with additional human message, or None if no update needed
"""
if not self._should_inject_image_message(state):
return None
# Create the image details message with text and image content
image_content = self._create_image_details_message(state)
# Create a new human message with mixed content (text + images)
human_msg = HumanMessage(content=image_content)
logger.debug("Injecting image details message with images before LLM call")
# Return state update with the new message
return {"messages": [human_msg]}
@override
def before_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
"""Inject image details message before LLM call if view_image tools have completed (sync version).
This runs before each LLM call, checking if the previous turn included view_image
tool calls that have all completed. If so, it injects a human message with the image
details so the LLM can see and analyze the images.
Args:
state: Current state
runtime: Runtime context (unused but required by interface)
Returns:
State update with additional human message, or None if no update needed
"""
return self._inject_image_message(state)
@override
async def abefore_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
"""Inject image details message before LLM call if view_image tools have completed (async version).
This runs before each LLM call, checking if the previous turn included view_image
tool calls that have all completed. If so, it injects a human message with the image
details so the LLM can see and analyze the images.
Args:
state: Current state
runtime: Runtime context (unused but required by interface)
Returns:
State update with additional human message, or None if no update needed
"""
return self._inject_image_message(state)