Initial commit: hardened DeerFlow factory
Vendored deer-flow upstream (bytedance/deer-flow) plus prompt-injection hardening: - New deerflow.security package: content_delimiter, html_cleaner, sanitizer (8 layers — invisible chars, control chars, symbols, NFC, PUA, tag chars, horizontal whitespace collapse with newline/tab preservation, length cap) - New deerflow.community.searx package: web_search, web_fetch, image_search backed by a private SearX instance, every external string sanitized and wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>> delimiters - All native community web providers (ddg_search, tavily, exa, firecrawl, jina_ai, infoquest, image_search) replaced with hard-fail stubs that raise NativeWebToolDisabledError at import time, so a misconfigured tool.use path fails loud rather than silently falling back to unsanitized output - Native client back-doors (jina_client.py, infoquest_client.py) stubbed too - Native-tool tests quarantined under tests/_disabled_native/ (collect_ignore_glob via local conftest.py) - Sanitizer Layer 7 fix: only collapse horizontal whitespace, preserve newlines and tabs so list/table structure survives - Hardened runtime config.yaml references only the searx-backed tools - Factory overlay (backend/) kept in sync with deer-flow tree as a reference / source See HARDENING.md for the full audit trail and verification steps.
This commit is contained in:
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
"""Middleware for intercepting clarification requests and presenting them to the user."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.graph import END
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClarificationMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ClarificationMiddleware(AgentMiddleware[ClarificationMiddlewareState]):
|
||||
"""Intercepts clarification tool calls and interrupts execution to present questions to the user.
|
||||
|
||||
When the model calls the `ask_clarification` tool, this middleware:
|
||||
1. Intercepts the tool call before execution
|
||||
2. Extracts the clarification question and metadata
|
||||
3. Formats a user-friendly message
|
||||
4. Returns a Command that interrupts execution and presents the question
|
||||
5. Waits for user response before continuing
|
||||
|
||||
This replaces the tool-based approach where clarification continued the conversation flow.
|
||||
"""
|
||||
|
||||
state_schema = ClarificationMiddlewareState
|
||||
|
||||
def _is_chinese(self, text: str) -> bool:
|
||||
"""Check if text contains Chinese characters.
|
||||
|
||||
Args:
|
||||
text: Text to check
|
||||
|
||||
Returns:
|
||||
True if text contains Chinese characters
|
||||
"""
|
||||
return any("\u4e00" <= char <= "\u9fff" for char in text)
|
||||
|
||||
def _format_clarification_message(self, args: dict) -> str:
|
||||
"""Format the clarification arguments into a user-friendly message.
|
||||
|
||||
Args:
|
||||
args: The tool call arguments containing clarification details
|
||||
|
||||
Returns:
|
||||
Formatted message string
|
||||
"""
|
||||
question = args.get("question", "")
|
||||
clarification_type = args.get("clarification_type", "missing_info")
|
||||
context = args.get("context")
|
||||
options = args.get("options", [])
|
||||
|
||||
# Some models (e.g. Qwen3-Max) serialize array parameters as JSON strings
|
||||
# instead of native arrays. Deserialize and normalize so `options`
|
||||
# is always a list for the rendering logic below.
|
||||
if isinstance(options, str):
|
||||
try:
|
||||
options = json.loads(options)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
options = [options]
|
||||
|
||||
if options is None:
|
||||
options = []
|
||||
elif not isinstance(options, list):
|
||||
options = [options]
|
||||
|
||||
# Type-specific icons
|
||||
type_icons = {
|
||||
"missing_info": "❓",
|
||||
"ambiguous_requirement": "🤔",
|
||||
"approach_choice": "🔀",
|
||||
"risk_confirmation": "⚠️",
|
||||
"suggestion": "💡",
|
||||
}
|
||||
|
||||
icon = type_icons.get(clarification_type, "❓")
|
||||
|
||||
# Build the message naturally
|
||||
message_parts = []
|
||||
|
||||
# Add icon and question together for a more natural flow
|
||||
if context:
|
||||
# If there's context, present it first as background
|
||||
message_parts.append(f"{icon} {context}")
|
||||
message_parts.append(f"\n{question}")
|
||||
else:
|
||||
# Just the question with icon
|
||||
message_parts.append(f"{icon} {question}")
|
||||
|
||||
# Add options in a cleaner format
|
||||
if options and len(options) > 0:
|
||||
message_parts.append("") # blank line for spacing
|
||||
for i, option in enumerate(options, 1):
|
||||
message_parts.append(f" {i}. {option}")
|
||||
|
||||
return "\n".join(message_parts)
|
||||
|
||||
def _handle_clarification(self, request: ToolCallRequest) -> Command:
|
||||
"""Handle clarification request and return command to interrupt execution.
|
||||
|
||||
Args:
|
||||
request: Tool call request
|
||||
|
||||
Returns:
|
||||
Command that interrupts execution with the formatted clarification message
|
||||
"""
|
||||
# Extract clarification arguments
|
||||
args = request.tool_call.get("args", {})
|
||||
question = args.get("question", "")
|
||||
|
||||
logger.info("Intercepted clarification request")
|
||||
logger.debug("Clarification question: %s", question)
|
||||
|
||||
# Format the clarification message
|
||||
formatted_message = self._format_clarification_message(args)
|
||||
|
||||
# Get the tool call ID
|
||||
tool_call_id = request.tool_call.get("id", "")
|
||||
|
||||
# Create a ToolMessage with the formatted question
|
||||
# This will be added to the message history
|
||||
tool_message = ToolMessage(
|
||||
content=formatted_message,
|
||||
tool_call_id=tool_call_id,
|
||||
name="ask_clarification",
|
||||
)
|
||||
|
||||
# Return a Command that:
|
||||
# 1. Adds the formatted tool message
|
||||
# 2. Interrupts execution by going to __end__
|
||||
# Note: We don't add an extra AIMessage here - the frontend will detect
|
||||
# and display ask_clarification tool messages directly
|
||||
return Command(
|
||||
update={"messages": [tool_message]},
|
||||
goto=END,
|
||||
)
|
||||
|
||||
@override
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept ask_clarification tool calls and interrupt execution (sync version).
|
||||
|
||||
Args:
|
||||
request: Tool call request
|
||||
handler: Original tool execution handler
|
||||
|
||||
Returns:
|
||||
Command that interrupts execution with the formatted clarification message
|
||||
"""
|
||||
# Check if this is an ask_clarification tool call
|
||||
if request.tool_call.get("name") != "ask_clarification":
|
||||
# Not a clarification call, execute normally
|
||||
return handler(request)
|
||||
|
||||
return self._handle_clarification(request)
|
||||
|
||||
@override
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept ask_clarification tool calls and interrupt execution (async version).
|
||||
|
||||
Args:
|
||||
request: Tool call request
|
||||
handler: Original tool execution handler (async)
|
||||
|
||||
Returns:
|
||||
Command that interrupts execution with the formatted clarification message
|
||||
"""
|
||||
# Check if this is an ask_clarification tool call
|
||||
if request.tool_call.get("name") != "ask_clarification":
|
||||
# Not a clarification call, execute normally
|
||||
return await handler(request)
|
||||
|
||||
return self._handle_clarification(request)
|
||||
@@ -0,0 +1,110 @@
|
||||
"""Middleware to fix dangling tool calls in message history.
|
||||
|
||||
A dangling tool call occurs when an AIMessage contains tool_calls but there are
|
||||
no corresponding ToolMessages in the history (e.g., due to user interruption or
|
||||
request cancellation). This causes LLM errors due to incomplete message format.
|
||||
|
||||
This middleware intercepts the model call to detect and patch such gaps by
|
||||
inserting synthetic ToolMessages with an error indicator immediately after the
|
||||
AIMessage that made the tool calls, ensuring correct message ordering.
|
||||
|
||||
Note: Uses wrap_model_call instead of before_model to ensure patches are inserted
|
||||
at the correct positions (immediately after each dangling AIMessage), not appended
|
||||
to the end of the message list as before_model + add_messages reducer would do.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Inserts placeholder ToolMessages for dangling tool calls before model invocation.
|
||||
|
||||
Scans the message history for AIMessages whose tool_calls lack corresponding
|
||||
ToolMessages, and injects synthetic error responses immediately after the
|
||||
offending AIMessage so the LLM receives a well-formed conversation.
|
||||
"""
|
||||
|
||||
def _build_patched_messages(self, messages: list) -> list | None:
|
||||
"""Return a new message list with patches inserted at the correct positions.
|
||||
|
||||
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
|
||||
a synthetic ToolMessage is inserted immediately after that AIMessage.
|
||||
Returns None if no patches are needed.
|
||||
"""
|
||||
# Collect IDs of all existing ToolMessages
|
||||
existing_tool_msg_ids: set[str] = set()
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage):
|
||||
existing_tool_msg_ids.add(msg.tool_call_id)
|
||||
|
||||
# Check if any patching is needed
|
||||
needs_patch = False
|
||||
for msg in messages:
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
for tc in getattr(msg, "tool_calls", None) or []:
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in existing_tool_msg_ids:
|
||||
needs_patch = True
|
||||
break
|
||||
if needs_patch:
|
||||
break
|
||||
|
||||
if not needs_patch:
|
||||
return None
|
||||
|
||||
# Build new list with patches inserted right after each dangling AIMessage
|
||||
patched: list = []
|
||||
patched_ids: set[str] = set()
|
||||
patch_count = 0
|
||||
for msg in messages:
|
||||
patched.append(msg)
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
for tc in getattr(msg, "tool_calls", None) or []:
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
|
||||
patched.append(
|
||||
ToolMessage(
|
||||
content="[Tool call was interrupted and did not return a result.]",
|
||||
tool_call_id=tc_id,
|
||||
name=tc.get("name", "unknown"),
|
||||
status="error",
|
||||
)
|
||||
)
|
||||
patched_ids.add(tc_id)
|
||||
patch_count += 1
|
||||
|
||||
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
||||
return patched
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
patched = self._build_patched_messages(request.messages)
|
||||
if patched is not None:
|
||||
request = request.override(messages=patched)
|
||||
return handler(request)
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
patched = self._build_patched_messages(request.messages)
|
||||
if patched is not None:
|
||||
request = request.override(messages=patched)
|
||||
return await handler(request)
|
||||
@@ -0,0 +1,60 @@
|
||||
"""Middleware to filter deferred tool schemas from model binding.
|
||||
|
||||
When tool_search is enabled, MCP tools are registered in the DeferredToolRegistry
|
||||
and passed to ToolNode for execution, but their schemas should NOT be sent to the
|
||||
LLM via bind_tools (that's the whole point of deferral — saving context tokens).
|
||||
|
||||
This middleware intercepts wrap_model_call and removes deferred tools from
|
||||
request.tools so that model.bind_tools only receives active tool schemas.
|
||||
The agent discovers deferred tools at runtime via the tool_search tool.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Remove deferred tools from request.tools before model binding.
|
||||
|
||||
ToolNode still holds all tools (including deferred) for execution routing,
|
||||
but the LLM only sees active tool schemas — deferred tools are discoverable
|
||||
via tool_search at runtime.
|
||||
"""
|
||||
|
||||
def _filter_tools(self, request: ModelRequest) -> ModelRequest:
|
||||
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
||||
|
||||
registry = get_deferred_registry()
|
||||
if not registry:
|
||||
return request
|
||||
|
||||
deferred_names = {e.name for e in registry.entries}
|
||||
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
|
||||
|
||||
if len(active_tools) < len(request.tools):
|
||||
logger.debug(f"Filtered {len(request.tools) - len(active_tools)} deferred tool schema(s) from model binding")
|
||||
|
||||
return request.override(tools=active_tools)
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(self._filter_tools(request))
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
return await handler(self._filter_tools(request))
|
||||
@@ -0,0 +1,275 @@
|
||||
"""LLM error handling middleware with retry/backoff and user-facing fallbacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_RETRIABLE_STATUS_CODES = {408, 409, 425, 429, 500, 502, 503, 504}
|
||||
_BUSY_PATTERNS = (
|
||||
"server busy",
|
||||
"temporarily unavailable",
|
||||
"try again later",
|
||||
"please retry",
|
||||
"please try again",
|
||||
"overloaded",
|
||||
"high demand",
|
||||
"rate limit",
|
||||
"负载较高",
|
||||
"服务繁忙",
|
||||
"稍后重试",
|
||||
"请稍后重试",
|
||||
)
|
||||
_QUOTA_PATTERNS = (
|
||||
"insufficient_quota",
|
||||
"quota",
|
||||
"billing",
|
||||
"credit",
|
||||
"payment",
|
||||
"余额不足",
|
||||
"超出限额",
|
||||
"额度不足",
|
||||
"欠费",
|
||||
)
|
||||
_AUTH_PATTERNS = (
|
||||
"authentication",
|
||||
"unauthorized",
|
||||
"invalid api key",
|
||||
"invalid_api_key",
|
||||
"permission",
|
||||
"forbidden",
|
||||
"access denied",
|
||||
"无权",
|
||||
"未授权",
|
||||
)
|
||||
|
||||
|
||||
class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Retry transient LLM errors and surface graceful assistant messages."""
|
||||
|
||||
retry_max_attempts: int = 3
|
||||
retry_base_delay_ms: int = 1000
|
||||
retry_cap_delay_ms: int = 8000
|
||||
|
||||
def _classify_error(self, exc: BaseException) -> tuple[bool, str]:
|
||||
detail = _extract_error_detail(exc)
|
||||
lowered = detail.lower()
|
||||
error_code = _extract_error_code(exc)
|
||||
status_code = _extract_status_code(exc)
|
||||
|
||||
if _matches_any(lowered, _QUOTA_PATTERNS) or _matches_any(str(error_code).lower(), _QUOTA_PATTERNS):
|
||||
return False, "quota"
|
||||
if _matches_any(lowered, _AUTH_PATTERNS):
|
||||
return False, "auth"
|
||||
|
||||
exc_name = exc.__class__.__name__
|
||||
if exc_name in {
|
||||
"APITimeoutError",
|
||||
"APIConnectionError",
|
||||
"InternalServerError",
|
||||
}:
|
||||
return True, "transient"
|
||||
if status_code in _RETRIABLE_STATUS_CODES:
|
||||
return True, "transient"
|
||||
if _matches_any(lowered, _BUSY_PATTERNS):
|
||||
return True, "busy"
|
||||
|
||||
return False, "generic"
|
||||
|
||||
def _build_retry_delay_ms(self, attempt: int, exc: BaseException) -> int:
|
||||
retry_after = _extract_retry_after_ms(exc)
|
||||
if retry_after is not None:
|
||||
return retry_after
|
||||
backoff = self.retry_base_delay_ms * (2 ** max(0, attempt - 1))
|
||||
return min(backoff, self.retry_cap_delay_ms)
|
||||
|
||||
def _build_retry_message(self, attempt: int, wait_ms: int, reason: str) -> str:
|
||||
seconds = max(1, round(wait_ms / 1000))
|
||||
reason_text = "provider is busy" if reason == "busy" else "provider request failed temporarily"
|
||||
return f"LLM request retry {attempt}/{self.retry_max_attempts}: {reason_text}. Retrying in {seconds}s."
|
||||
|
||||
def _build_user_message(self, exc: BaseException, reason: str) -> str:
|
||||
detail = _extract_error_detail(exc)
|
||||
if reason == "quota":
|
||||
return "The configured LLM provider rejected the request because the account is out of quota, billing is unavailable, or usage is restricted. Please fix the provider account and try again."
|
||||
if reason == "auth":
|
||||
return "The configured LLM provider rejected the request because authentication or access is invalid. Please check the provider credentials and try again."
|
||||
if reason in {"busy", "transient"}:
|
||||
return "The configured LLM provider is temporarily unavailable after multiple retries. Please wait a moment and continue the conversation."
|
||||
return f"LLM request failed: {detail}"
|
||||
|
||||
def _emit_retry_event(self, attempt: int, wait_ms: int, reason: str) -> None:
|
||||
try:
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
writer = get_stream_writer()
|
||||
writer(
|
||||
{
|
||||
"type": "llm_retry",
|
||||
"attempt": attempt,
|
||||
"max_attempts": self.retry_max_attempts,
|
||||
"wait_ms": wait_ms,
|
||||
"reason": reason,
|
||||
"message": self._build_retry_message(attempt, wait_ms, reason),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to emit llm_retry event", exc_info=True)
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
if retriable and attempt < self.retry_max_attempts:
|
||||
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
||||
logger.warning(
|
||||
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
||||
attempt,
|
||||
self.retry_max_attempts,
|
||||
wait_ms,
|
||||
_extract_error_detail(exc),
|
||||
)
|
||||
self._emit_retry_event(attempt, wait_ms, reason)
|
||||
time.sleep(wait_ms / 1000)
|
||||
attempt += 1
|
||||
continue
|
||||
logger.warning(
|
||||
"LLM call failed after %d attempt(s): %s",
|
||||
attempt,
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
attempt = 1
|
||||
while True:
|
||||
try:
|
||||
return await handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
retriable, reason = self._classify_error(exc)
|
||||
if retriable and attempt < self.retry_max_attempts:
|
||||
wait_ms = self._build_retry_delay_ms(attempt, exc)
|
||||
logger.warning(
|
||||
"Transient LLM error on attempt %d/%d; retrying in %dms: %s",
|
||||
attempt,
|
||||
self.retry_max_attempts,
|
||||
wait_ms,
|
||||
_extract_error_detail(exc),
|
||||
)
|
||||
self._emit_retry_event(attempt, wait_ms, reason)
|
||||
await asyncio.sleep(wait_ms / 1000)
|
||||
attempt += 1
|
||||
continue
|
||||
logger.warning(
|
||||
"LLM call failed after %d attempt(s): %s",
|
||||
attempt,
|
||||
_extract_error_detail(exc),
|
||||
exc_info=exc,
|
||||
)
|
||||
return AIMessage(content=self._build_user_message(exc, reason))
|
||||
|
||||
|
||||
def _matches_any(detail: str, patterns: tuple[str, ...]) -> bool:
|
||||
return any(pattern in detail for pattern in patterns)
|
||||
|
||||
|
||||
def _extract_error_code(exc: BaseException) -> Any:
|
||||
for attr in ("code", "error_code"):
|
||||
value = getattr(exc, attr, None)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
|
||||
body = getattr(exc, "body", None)
|
||||
if isinstance(body, dict):
|
||||
error = body.get("error")
|
||||
if isinstance(error, dict):
|
||||
for key in ("code", "type"):
|
||||
value = error.get(key)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _extract_status_code(exc: BaseException) -> int | None:
|
||||
for attr in ("status_code", "status"):
|
||||
value = getattr(exc, attr, None)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
response = getattr(exc, "response", None)
|
||||
status = getattr(response, "status_code", None)
|
||||
return status if isinstance(status, int) else None
|
||||
|
||||
|
||||
def _extract_retry_after_ms(exc: BaseException) -> int | None:
|
||||
response = getattr(exc, "response", None)
|
||||
headers = getattr(response, "headers", None)
|
||||
if headers is None:
|
||||
return None
|
||||
|
||||
raw = None
|
||||
header_name = ""
|
||||
for key in ("retry-after-ms", "Retry-After-Ms", "retry-after", "Retry-After"):
|
||||
header_name = key
|
||||
if hasattr(headers, "get"):
|
||||
raw = headers.get(key)
|
||||
if raw:
|
||||
break
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
try:
|
||||
multiplier = 1 if "ms" in header_name.lower() else 1000
|
||||
return max(0, int(float(raw) * multiplier))
|
||||
except (TypeError, ValueError):
|
||||
try:
|
||||
target = parsedate_to_datetime(str(raw))
|
||||
delta = target.timestamp() - time.time()
|
||||
return max(0, int(delta * 1000))
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
return None
|
||||
|
||||
|
||||
def _extract_error_detail(exc: BaseException) -> str:
|
||||
detail = str(exc).strip()
|
||||
if detail:
|
||||
return detail
|
||||
message = getattr(exc, "message", None)
|
||||
if isinstance(message, str) and message.strip():
|
||||
return message.strip()
|
||||
return exc.__class__.__name__
|
||||
@@ -0,0 +1,372 @@
|
||||
"""Middleware to detect and break repetitive tool call loops.
|
||||
|
||||
P0 safety: prevents the agent from calling the same tool with the same
|
||||
arguments indefinitely until the recursion limit kills the run.
|
||||
|
||||
Detection strategy:
|
||||
1. After each model response, hash the tool calls (name + args).
|
||||
2. Track recent hashes in a sliding window.
|
||||
3. If the same hash appears >= warn_threshold times, inject a
|
||||
"you are repeating yourself — wrap up" system message (once per hash).
|
||||
4. If it appears >= hard_limit times, strip all tool_calls from the
|
||||
response so the agent is forced to produce a final text answer.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict, defaultdict
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Defaults — can be overridden via constructor
|
||||
_DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls
|
||||
_DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls
|
||||
_DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
|
||||
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
|
||||
_DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type
|
||||
_DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type
|
||||
|
||||
|
||||
def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]:
|
||||
"""Normalize tool call args to a dict plus an optional fallback key.
|
||||
|
||||
Some providers serialize ``args`` as a JSON string instead of a dict.
|
||||
We defensively parse those cases so loop detection does not crash while
|
||||
still preserving a stable fallback key for non-dict payloads.
|
||||
"""
|
||||
if isinstance(raw_args, dict):
|
||||
return raw_args, None
|
||||
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
parsed = json.loads(raw_args)
|
||||
except (TypeError, ValueError, json.JSONDecodeError):
|
||||
return {}, raw_args
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
return parsed, None
|
||||
return {}, json.dumps(parsed, sort_keys=True, default=str)
|
||||
|
||||
if raw_args is None:
|
||||
return {}, None
|
||||
|
||||
return {}, json.dumps(raw_args, sort_keys=True, default=str)
|
||||
|
||||
|
||||
def _stable_tool_key(name: str, args: dict, fallback_key: str | None) -> str:
|
||||
"""Derive a stable key from salient args without overfitting to noise."""
|
||||
if name == "read_file" and fallback_key is None:
|
||||
path = args.get("path") or ""
|
||||
start_line = args.get("start_line")
|
||||
end_line = args.get("end_line")
|
||||
|
||||
bucket_size = 200
|
||||
try:
|
||||
start_line = int(start_line) if start_line is not None else 1
|
||||
except (TypeError, ValueError):
|
||||
start_line = 1
|
||||
try:
|
||||
end_line = int(end_line) if end_line is not None else start_line
|
||||
except (TypeError, ValueError):
|
||||
end_line = start_line
|
||||
|
||||
start_line, end_line = sorted((start_line, end_line))
|
||||
bucket_start = max(start_line, 1)
|
||||
bucket_end = max(end_line, 1)
|
||||
bucket_start = (bucket_start - 1) // bucket_size
|
||||
bucket_end = (bucket_end - 1) // bucket_size
|
||||
return f"{path}:{bucket_start}-{bucket_end}"
|
||||
|
||||
# write_file / str_replace are content-sensitive: same path may be updated
|
||||
# with different payloads during iteration. Using only salient fields (path)
|
||||
# can collapse distinct calls, so we hash full args to reduce false positives.
|
||||
if name in {"write_file", "str_replace"}:
|
||||
if fallback_key is not None:
|
||||
return fallback_key
|
||||
return json.dumps(args, sort_keys=True, default=str)
|
||||
|
||||
salient_fields = ("path", "url", "query", "command", "pattern", "glob", "cmd")
|
||||
stable_args = {field: args[field] for field in salient_fields if args.get(field) is not None}
|
||||
if stable_args:
|
||||
return json.dumps(stable_args, sort_keys=True, default=str)
|
||||
|
||||
if fallback_key is not None:
|
||||
return fallback_key
|
||||
|
||||
return json.dumps(args, sort_keys=True, default=str)
|
||||
|
||||
|
||||
def _hash_tool_calls(tool_calls: list[dict]) -> str:
|
||||
"""Deterministic hash of a set of tool calls (name + stable key).
|
||||
|
||||
This is intended to be order-independent: the same multiset of tool calls
|
||||
should always produce the same hash, regardless of their input order.
|
||||
"""
|
||||
# Normalize each tool call to a stable (name, key) structure.
|
||||
normalized: list[str] = []
|
||||
for tc in tool_calls:
|
||||
name = tc.get("name", "")
|
||||
args, fallback_key = _normalize_tool_call_args(tc.get("args", {}))
|
||||
key = _stable_tool_key(name, args, fallback_key)
|
||||
|
||||
normalized.append(f"{name}:{key}")
|
||||
|
||||
# Sort so permutations of the same multiset of calls yield the same ordering.
|
||||
normalized.sort()
|
||||
blob = json.dumps(normalized, sort_keys=True, default=str)
|
||||
return hashlib.md5(blob.encode()).hexdigest()[:12]
|
||||
|
||||
|
||||
_WARNING_MSG = "[LOOP DETECTED] You are repeating the same tool calls. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
|
||||
|
||||
_TOOL_FREQ_WARNING_MSG = (
|
||||
"[LOOP DETECTED] You have called {tool_name} {count} times without producing a final answer. Stop calling tools and produce your final answer now. If you cannot complete the task, summarize what you accomplished so far."
|
||||
)
|
||||
|
||||
_HARD_STOP_MSG = "[FORCED STOP] Repeated tool calls exceeded the safety limit. Producing final answer with results collected so far."
|
||||
|
||||
_TOOL_FREQ_HARD_STOP_MSG = "[FORCED STOP] Tool {tool_name} called {count} times — exceeded the per-tool safety limit. Producing final answer with results collected so far."
|
||||
|
||||
|
||||
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Detects and breaks repetitive tool call loops.
|
||||
|
||||
Args:
|
||||
warn_threshold: Number of identical tool call sets before injecting
|
||||
a warning message. Default: 3.
|
||||
hard_limit: Number of identical tool call sets before stripping
|
||||
tool_calls entirely. Default: 5.
|
||||
window_size: Size of the sliding window for tracking calls.
|
||||
Default: 20.
|
||||
max_tracked_threads: Maximum number of threads to track before
|
||||
evicting the least recently used. Default: 100.
|
||||
tool_freq_warn: Number of calls to the same tool *type* (regardless
|
||||
of arguments) before injecting a frequency warning. Catches
|
||||
cross-file read loops that hash-based detection misses.
|
||||
Default: 30.
|
||||
tool_freq_hard_limit: Number of calls to the same tool type before
|
||||
forcing a stop. Default: 50.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
warn_threshold: int = _DEFAULT_WARN_THRESHOLD,
|
||||
hard_limit: int = _DEFAULT_HARD_LIMIT,
|
||||
window_size: int = _DEFAULT_WINDOW_SIZE,
|
||||
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
|
||||
tool_freq_warn: int = _DEFAULT_TOOL_FREQ_WARN,
|
||||
tool_freq_hard_limit: int = _DEFAULT_TOOL_FREQ_HARD_LIMIT,
|
||||
):
|
||||
super().__init__()
|
||||
self.warn_threshold = warn_threshold
|
||||
self.hard_limit = hard_limit
|
||||
self.window_size = window_size
|
||||
self.max_tracked_threads = max_tracked_threads
|
||||
self.tool_freq_warn = tool_freq_warn
|
||||
self.tool_freq_hard_limit = tool_freq_hard_limit
|
||||
self._lock = threading.Lock()
|
||||
# Per-thread tracking using OrderedDict for LRU eviction
|
||||
self._history: OrderedDict[str, list[str]] = OrderedDict()
|
||||
self._warned: dict[str, set[str]] = defaultdict(set)
|
||||
# Per-thread, per-tool-type cumulative call counts
|
||||
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
def _get_thread_id(self, runtime: Runtime) -> str:
|
||||
"""Extract thread_id from runtime context for per-thread tracking."""
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id:
|
||||
return thread_id
|
||||
return "default"
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict least recently used threads if over the limit.
|
||||
|
||||
Must be called while holding self._lock.
|
||||
"""
|
||||
while len(self._history) > self.max_tracked_threads:
|
||||
evicted_id, _ = self._history.popitem(last=False)
|
||||
self._warned.pop(evicted_id, None)
|
||||
self._tool_freq.pop(evicted_id, None)
|
||||
self._tool_freq_warned.pop(evicted_id, None)
|
||||
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
|
||||
|
||||
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
|
||||
"""Track tool calls and check for loops.
|
||||
|
||||
Two detection layers:
|
||||
1. **Hash-based** (existing): catches identical tool call sets.
|
||||
2. **Frequency-based** (new): catches the same *tool type* being
|
||||
called many times with varying arguments (e.g. ``read_file``
|
||||
on 40 different files).
|
||||
|
||||
Returns:
|
||||
(warning_message_or_none, should_hard_stop)
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None, False
|
||||
|
||||
last_msg = messages[-1]
|
||||
if getattr(last_msg, "type", None) != "ai":
|
||||
return None, False
|
||||
|
||||
tool_calls = getattr(last_msg, "tool_calls", None)
|
||||
if not tool_calls:
|
||||
return None, False
|
||||
|
||||
thread_id = self._get_thread_id(runtime)
|
||||
call_hash = _hash_tool_calls(tool_calls)
|
||||
|
||||
with self._lock:
|
||||
# Touch / create entry (move to end for LRU)
|
||||
if thread_id in self._history:
|
||||
self._history.move_to_end(thread_id)
|
||||
else:
|
||||
self._history[thread_id] = []
|
||||
self._evict_if_needed()
|
||||
|
||||
history = self._history[thread_id]
|
||||
history.append(call_hash)
|
||||
if len(history) > self.window_size:
|
||||
history[:] = history[-self.window_size :]
|
||||
|
||||
count = history.count(call_hash)
|
||||
tool_names = [tc.get("name", "?") for tc in tool_calls]
|
||||
|
||||
# --- Layer 1: hash-based (identical call sets) ---
|
||||
if count >= self.hard_limit:
|
||||
logger.error(
|
||||
"Loop hard limit reached — forcing stop",
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"call_hash": call_hash,
|
||||
"count": count,
|
||||
"tools": tool_names,
|
||||
},
|
||||
)
|
||||
return _HARD_STOP_MSG, True
|
||||
|
||||
if count >= self.warn_threshold:
|
||||
warned = self._warned[thread_id]
|
||||
if call_hash not in warned:
|
||||
warned.add(call_hash)
|
||||
logger.warning(
|
||||
"Repetitive tool calls detected — injecting warning",
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"call_hash": call_hash,
|
||||
"count": count,
|
||||
"tools": tool_names,
|
||||
},
|
||||
)
|
||||
return _WARNING_MSG, False
|
||||
|
||||
# --- Layer 2: per-tool-type frequency ---
|
||||
freq = self._tool_freq[thread_id]
|
||||
for tc in tool_calls:
|
||||
name = tc.get("name", "")
|
||||
if not name:
|
||||
continue
|
||||
freq[name] += 1
|
||||
tc_count = freq[name]
|
||||
|
||||
if tc_count >= self.tool_freq_hard_limit:
|
||||
logger.error(
|
||||
"Tool frequency hard limit reached — forcing stop",
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"tool_name": name,
|
||||
"count": tc_count,
|
||||
},
|
||||
)
|
||||
return _TOOL_FREQ_HARD_STOP_MSG.format(tool_name=name, count=tc_count), True
|
||||
|
||||
if tc_count >= self.tool_freq_warn:
|
||||
warned = self._tool_freq_warned[thread_id]
|
||||
if name not in warned:
|
||||
warned.add(name)
|
||||
logger.warning(
|
||||
"Tool frequency warning — too many calls to same tool type",
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"tool_name": name,
|
||||
"count": tc_count,
|
||||
},
|
||||
)
|
||||
return _TOOL_FREQ_WARNING_MSG.format(tool_name=name, count=tc_count), False
|
||||
|
||||
return None, False
|
||||
|
||||
@staticmethod
|
||||
def _append_text(content: str | list | None, text: str) -> str | list:
|
||||
"""Append *text* to AIMessage content, handling str, list, and None.
|
||||
|
||||
When content is a list of content blocks (e.g. Anthropic thinking mode),
|
||||
we append a new ``{"type": "text", ...}`` block instead of concatenating
|
||||
a string to a list, which would raise ``TypeError``.
|
||||
"""
|
||||
if content is None:
|
||||
return text
|
||||
if isinstance(content, list):
|
||||
return [*content, {"type": "text", "text": f"\n\n{text}"}]
|
||||
if isinstance(content, str):
|
||||
return content + f"\n\n{text}"
|
||||
# Fallback: coerce unexpected types to str to avoid TypeError
|
||||
return str(content) + f"\n\n{text}"
|
||||
|
||||
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
warning, hard_stop = self._track_and_check(state, runtime)
|
||||
|
||||
if hard_stop:
|
||||
# Strip tool_calls from the last AIMessage to force text output
|
||||
messages = state.get("messages", [])
|
||||
last_msg = messages[-1]
|
||||
stripped_msg = last_msg.model_copy(
|
||||
update={
|
||||
"tool_calls": [],
|
||||
"content": self._append_text(last_msg.content, warning),
|
||||
}
|
||||
)
|
||||
return {"messages": [stripped_msg]}
|
||||
|
||||
if warning:
|
||||
# Inject as HumanMessage instead of SystemMessage to avoid
|
||||
# Anthropic's "multiple non-consecutive system messages" error.
|
||||
# Anthropic models require system messages only at the start of
|
||||
# the conversation; injecting one mid-conversation crashes
|
||||
# langchain_anthropic's _format_messages(). HumanMessage works
|
||||
# with all providers. See #1299.
|
||||
return {"messages": [HumanMessage(content=warning)]}
|
||||
|
||||
return None
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
|
||||
def reset(self, thread_id: str | None = None) -> None:
|
||||
"""Clear tracking state. If thread_id given, clear only that thread."""
|
||||
with self._lock:
|
||||
if thread_id:
|
||||
self._history.pop(thread_id, None)
|
||||
self._warned.pop(thread_id, None)
|
||||
self._tool_freq.pop(thread_id, None)
|
||||
self._tool_freq_warned.pop(thread_id, None)
|
||||
else:
|
||||
self._history.clear()
|
||||
self._warned.clear()
|
||||
self._tool_freq.clear()
|
||||
self._tool_freq_warned.clear()
|
||||
@@ -0,0 +1,248 @@
|
||||
"""Middleware for memory mechanism."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.memory.queue import get_memory_queue
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_UPLOAD_BLOCK_RE = re.compile(r"<uploaded_files>[\s\S]*?</uploaded_files>\n*", re.IGNORECASE)
|
||||
_CORRECTION_PATTERNS = (
|
||||
re.compile(r"\bthat(?:'s| is) (?:wrong|incorrect)\b", re.IGNORECASE),
|
||||
re.compile(r"\byou misunderstood\b", re.IGNORECASE),
|
||||
re.compile(r"\btry again\b", re.IGNORECASE),
|
||||
re.compile(r"\bredo\b", re.IGNORECASE),
|
||||
re.compile(r"不对"),
|
||||
re.compile(r"你理解错了"),
|
||||
re.compile(r"你理解有误"),
|
||||
re.compile(r"重试"),
|
||||
re.compile(r"重新来"),
|
||||
re.compile(r"换一种"),
|
||||
re.compile(r"改用"),
|
||||
)
|
||||
|
||||
_REINFORCEMENT_PATTERNS = (
|
||||
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
|
||||
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
|
||||
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
|
||||
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
|
||||
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
|
||||
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
|
||||
re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
|
||||
re.compile(r"完全正确(?:[。!?!?.]|$)"),
|
||||
re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
|
||||
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
|
||||
re.compile(r"继续保持(?:[。!?!?.]|$)"),
|
||||
)
|
||||
|
||||
|
||||
class MemoryMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _extract_message_text(message: Any) -> str:
|
||||
"""Extract plain text from message content for filtering and signal detection."""
|
||||
content = getattr(message, "content", "")
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
text_parts.append(part)
|
||||
elif isinstance(part, dict):
|
||||
text_val = part.get("text")
|
||||
if isinstance(text_val, str):
|
||||
text_parts.append(text_val)
|
||||
return " ".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _filter_messages_for_memory(messages: list[Any]) -> list[Any]:
|
||||
"""Filter messages to keep only user inputs and final assistant responses.
|
||||
|
||||
This filters out:
|
||||
- Tool messages (intermediate tool call results)
|
||||
- AI messages with tool_calls (intermediate steps, not final responses)
|
||||
- The <uploaded_files> block injected by UploadsMiddleware into human messages
|
||||
(file paths are session-scoped and must not persist in long-term memory).
|
||||
The user's actual question is preserved; only turns whose content is entirely
|
||||
the upload block (nothing remains after stripping) are dropped along with
|
||||
their paired assistant response.
|
||||
|
||||
Only keeps:
|
||||
- Human messages (with the ephemeral upload block removed)
|
||||
- AI messages without tool_calls (final assistant responses), unless the
|
||||
paired human turn was upload-only and had no real user text.
|
||||
|
||||
Args:
|
||||
messages: List of all conversation messages.
|
||||
|
||||
Returns:
|
||||
Filtered list containing only user inputs and final assistant responses.
|
||||
"""
|
||||
filtered = []
|
||||
skip_next_ai = False
|
||||
for msg in messages:
|
||||
msg_type = getattr(msg, "type", None)
|
||||
|
||||
if msg_type == "human":
|
||||
content_str = _extract_message_text(msg)
|
||||
if "<uploaded_files>" in content_str:
|
||||
# Strip the ephemeral upload block; keep the user's real question.
|
||||
stripped = _UPLOAD_BLOCK_RE.sub("", content_str).strip()
|
||||
if not stripped:
|
||||
# Nothing left — the entire turn was upload bookkeeping;
|
||||
# skip it and the paired assistant response.
|
||||
skip_next_ai = True
|
||||
continue
|
||||
# Rebuild the message with cleaned content so the user's question
|
||||
# is still available for memory summarisation.
|
||||
from copy import copy
|
||||
|
||||
clean_msg = copy(msg)
|
||||
clean_msg.content = stripped
|
||||
filtered.append(clean_msg)
|
||||
skip_next_ai = False
|
||||
else:
|
||||
filtered.append(msg)
|
||||
skip_next_ai = False
|
||||
elif msg_type == "ai":
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
if not tool_calls:
|
||||
if skip_next_ai:
|
||||
skip_next_ai = False
|
||||
continue
|
||||
filtered.append(msg)
|
||||
# Skip tool messages and AI messages with tool_calls
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
def detect_correction(messages: list[Any]) -> bool:
|
||||
"""Detect explicit user corrections in recent conversation turns.
|
||||
|
||||
The queue keeps only one pending context per thread, so callers pass the
|
||||
latest filtered message list. Checking only recent user turns keeps signal
|
||||
detection conservative while avoiding stale corrections from long histories.
|
||||
"""
|
||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||
|
||||
for msg in recent_user_msgs:
|
||||
content = _extract_message_text(msg).strip()
|
||||
if not content:
|
||||
continue
|
||||
if any(pattern.search(content) for pattern in _CORRECTION_PATTERNS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def detect_reinforcement(messages: list[Any]) -> bool:
|
||||
"""Detect explicit positive reinforcement signals in recent conversation turns.
|
||||
|
||||
Complements detect_correction() by identifying when the user confirms the
|
||||
agent's approach was correct. This allows the memory system to record what
|
||||
worked well, not just what went wrong.
|
||||
|
||||
The queue keeps only one pending context per thread, so callers pass the
|
||||
latest filtered message list. Checking only recent user turns keeps signal
|
||||
detection conservative while avoiding stale signals from long histories.
|
||||
"""
|
||||
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
|
||||
|
||||
for msg in recent_user_msgs:
|
||||
content = _extract_message_text(msg).strip()
|
||||
if not content:
|
||||
continue
|
||||
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
"""Middleware that queues conversation for memory update after agent execution.
|
||||
|
||||
This middleware:
|
||||
1. After each agent execution, queues the conversation for memory update
|
||||
2. Only includes user inputs and final assistant responses (ignores tool calls)
|
||||
3. The queue uses debouncing to batch multiple updates together
|
||||
4. Memory is updated asynchronously via LLM summarization
|
||||
"""
|
||||
|
||||
state_schema = MemoryMiddlewareState
|
||||
|
||||
def __init__(self, agent_name: str | None = None):
|
||||
"""Initialize the MemoryMiddleware.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
"""
|
||||
super().__init__()
|
||||
self._agent_name = agent_name
|
||||
|
||||
@override
|
||||
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
"""Queue conversation for memory update after agent completes.
|
||||
|
||||
Args:
|
||||
state: The current agent state.
|
||||
runtime: The runtime context.
|
||||
|
||||
Returns:
|
||||
None (no state changes needed from this middleware).
|
||||
"""
|
||||
config = get_memory_config()
|
||||
if not config.enabled:
|
||||
return None
|
||||
|
||||
# Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id is None:
|
||||
config_data = get_config()
|
||||
thread_id = config_data.get("configurable", {}).get("thread_id")
|
||||
if not thread_id:
|
||||
logger.debug("No thread_id in context, skipping memory update")
|
||||
return None
|
||||
|
||||
# Get messages from state
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
logger.debug("No messages in state, skipping memory update")
|
||||
return None
|
||||
|
||||
# Filter to only keep user inputs and final assistant responses
|
||||
filtered_messages = _filter_messages_for_memory(messages)
|
||||
|
||||
# Only queue if there's meaningful conversation
|
||||
# At minimum need one user message and one assistant response
|
||||
user_messages = [m for m in filtered_messages if getattr(m, "type", None) == "human"]
|
||||
assistant_messages = [m for m in filtered_messages if getattr(m, "type", None) == "ai"]
|
||||
|
||||
if not user_messages or not assistant_messages:
|
||||
return None
|
||||
|
||||
# Queue the filtered conversation for memory update
|
||||
correction_detected = detect_correction(filtered_messages)
|
||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||
queue = get_memory_queue()
|
||||
queue.add(
|
||||
thread_id=thread_id,
|
||||
messages=filtered_messages,
|
||||
agent_name=self._agent_name,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,363 @@
|
||||
"""SandboxAuditMiddleware - bash command security auditing."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import shlex
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime
|
||||
from typing import override
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command classification rules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Each pattern is compiled once at import time.
|
||||
_HIGH_RISK_PATTERNS: list[re.Pattern[str]] = [
|
||||
# --- original rules (retained) ---
|
||||
re.compile(r"rm\s+-[^\s]*r[^\s]*\s+(/\*?|~/?\*?|/home\b|/root\b)\s*$"),
|
||||
re.compile(r"dd\s+if="),
|
||||
re.compile(r"mkfs"),
|
||||
re.compile(r"cat\s+/etc/shadow"),
|
||||
re.compile(r">+\s*/etc/"),
|
||||
# --- pipe to sh/bash (generalised, replaces old curl|sh rule) ---
|
||||
re.compile(r"\|\s*(ba)?sh\b"),
|
||||
# --- command substitution (targeted – only dangerous executables) ---
|
||||
re.compile(r"[`$]\(?\s*(curl|wget|bash|sh|python|ruby|perl|base64)"),
|
||||
# --- base64 decode piped to execution ---
|
||||
re.compile(r"base64\s+.*-d.*\|"),
|
||||
# --- overwrite system binaries ---
|
||||
re.compile(r">+\s*(/usr/bin/|/bin/|/sbin/)"),
|
||||
# --- overwrite shell startup files ---
|
||||
re.compile(r">+\s*~/?\.(bashrc|profile|zshrc|bash_profile)"),
|
||||
# --- process environment leakage ---
|
||||
re.compile(r"/proc/[^/]+/environ"),
|
||||
# --- dynamic linker hijack (one-step escalation) ---
|
||||
re.compile(r"\b(LD_PRELOAD|LD_LIBRARY_PATH)\s*="),
|
||||
# --- bash built-in networking (bypasses tool allowlists) ---
|
||||
re.compile(r"/dev/tcp/"),
|
||||
# --- fork bomb ---
|
||||
re.compile(r"\S+\(\)\s*\{[^}]*\|\s*\S+\s*&"), # :(){ :|:& };:
|
||||
re.compile(r"while\s+true.*&\s*done"), # while true; do bash & done
|
||||
]
|
||||
|
||||
_MEDIUM_RISK_PATTERNS: list[re.Pattern[str]] = [
|
||||
re.compile(r"chmod\s+777"),
|
||||
re.compile(r"pip3?\s+install"),
|
||||
re.compile(r"apt(-get)?\s+install"),
|
||||
# sudo/su: no-op under Docker root; warn so LLM is aware
|
||||
re.compile(r"\b(sudo|su)\b"),
|
||||
# PATH modification: long attack chain, warn rather than block
|
||||
re.compile(r"\bPATH\s*="),
|
||||
]
|
||||
|
||||
|
||||
def _split_compound_command(command: str) -> list[str]:
|
||||
"""Split a compound command into sub-commands (quote-aware).
|
||||
|
||||
Scans the raw command string so unquoted shell control operators are
|
||||
recognised even when they are not surrounded by whitespace
|
||||
(e.g. ``safe;rm -rf /`` or ``rm -rf /&&echo ok``). Operators inside
|
||||
quotes are ignored. If the command ends with an unclosed quote or a
|
||||
dangling escape, return the whole command unchanged (fail-closed —
|
||||
safer to classify the unsplit string than silently drop parts).
|
||||
"""
|
||||
parts: list[str] = []
|
||||
current: list[str] = []
|
||||
in_single_quote = False
|
||||
in_double_quote = False
|
||||
escaping = False
|
||||
index = 0
|
||||
|
||||
while index < len(command):
|
||||
char = command[index]
|
||||
|
||||
if escaping:
|
||||
current.append(char)
|
||||
escaping = False
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if char == "\\" and not in_single_quote:
|
||||
current.append(char)
|
||||
escaping = True
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if char == "'" and not in_double_quote:
|
||||
in_single_quote = not in_single_quote
|
||||
current.append(char)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if char == '"' and not in_single_quote:
|
||||
in_double_quote = not in_double_quote
|
||||
current.append(char)
|
||||
index += 1
|
||||
continue
|
||||
|
||||
if not in_single_quote and not in_double_quote:
|
||||
if command.startswith("&&", index) or command.startswith("||", index):
|
||||
part = "".join(current).strip()
|
||||
if part:
|
||||
parts.append(part)
|
||||
current = []
|
||||
index += 2
|
||||
continue
|
||||
if char == ";":
|
||||
part = "".join(current).strip()
|
||||
if part:
|
||||
parts.append(part)
|
||||
current = []
|
||||
index += 1
|
||||
continue
|
||||
|
||||
current.append(char)
|
||||
index += 1
|
||||
|
||||
# Unclosed quote or dangling escape → fail-closed, return whole command
|
||||
if in_single_quote or in_double_quote or escaping:
|
||||
return [command]
|
||||
|
||||
part = "".join(current).strip()
|
||||
if part:
|
||||
parts.append(part)
|
||||
return parts if parts else [command]
|
||||
|
||||
|
||||
def _classify_single_command(command: str) -> str:
|
||||
"""Classify a single (non-compound) command. Return 'block', 'warn', or 'pass'."""
|
||||
normalized = " ".join(command.split())
|
||||
|
||||
for pattern in _HIGH_RISK_PATTERNS:
|
||||
if pattern.search(normalized):
|
||||
return "block"
|
||||
|
||||
# Also try shlex-parsed tokens for high-risk detection
|
||||
try:
|
||||
tokens = shlex.split(command)
|
||||
joined = " ".join(tokens)
|
||||
for pattern in _HIGH_RISK_PATTERNS:
|
||||
if pattern.search(joined):
|
||||
return "block"
|
||||
except ValueError:
|
||||
# shlex.split fails on unclosed quotes — treat as suspicious
|
||||
return "block"
|
||||
|
||||
for pattern in _MEDIUM_RISK_PATTERNS:
|
||||
if pattern.search(normalized):
|
||||
return "warn"
|
||||
|
||||
return "pass"
|
||||
|
||||
|
||||
def _classify_command(command: str) -> str:
|
||||
"""Return 'block', 'warn', or 'pass'.
|
||||
|
||||
Strategy:
|
||||
1. First scan the *whole* raw command against high-risk patterns. This
|
||||
catches structural attacks like ``while true; do bash & done`` or
|
||||
``:(){ :|:& };:`` that span multiple shell statements — splitting them
|
||||
on ``;`` would destroy the pattern context.
|
||||
2. Then split compound commands (e.g. ``cmd1 && cmd2 ; cmd3``) and
|
||||
classify each sub-command independently. The most severe verdict wins.
|
||||
"""
|
||||
# Pass 1: whole-command high-risk scan (catches multi-statement patterns)
|
||||
normalized = " ".join(command.split())
|
||||
for pattern in _HIGH_RISK_PATTERNS:
|
||||
if pattern.search(normalized):
|
||||
return "block"
|
||||
|
||||
# Pass 2: per-sub-command classification
|
||||
sub_commands = _split_compound_command(command)
|
||||
worst = "pass"
|
||||
for sub in sub_commands:
|
||||
verdict = _classify_single_command(sub)
|
||||
if verdict == "block":
|
||||
return "block" # short-circuit: can't get worse
|
||||
if verdict == "warn":
|
||||
worst = "warn"
|
||||
return worst
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SandboxAuditMiddleware(AgentMiddleware[ThreadState]):
|
||||
"""Bash command security auditing middleware.
|
||||
|
||||
For every ``bash`` tool call:
|
||||
1. **Command classification**: regex + shlex analysis grades commands as
|
||||
high-risk (block), medium-risk (warn), or safe (pass).
|
||||
2. **Audit log**: every bash call is recorded as a structured JSON entry
|
||||
via the standard logger (visible in langgraph.log).
|
||||
|
||||
High-risk commands (e.g. ``rm -rf /``, ``curl url | bash``) are blocked:
|
||||
the handler is not called and an error ``ToolMessage`` is returned so the
|
||||
agent loop can continue gracefully.
|
||||
|
||||
Medium-risk commands (e.g. ``pip install``, ``chmod 777``) are executed
|
||||
normally; a warning is appended to the tool result so the LLM is aware.
|
||||
"""
|
||||
|
||||
state_schema = ThreadState
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_thread_id(self, request: ToolCallRequest) -> str | None:
|
||||
runtime = request.runtime # ToolRuntime; may be None-like in tests
|
||||
if runtime is None:
|
||||
return None
|
||||
ctx = getattr(runtime, "context", None) or {}
|
||||
thread_id = ctx.get("thread_id") if isinstance(ctx, dict) else None
|
||||
if thread_id is None:
|
||||
cfg = getattr(runtime, "config", None) or {}
|
||||
thread_id = cfg.get("configurable", {}).get("thread_id")
|
||||
return thread_id
|
||||
|
||||
_AUDIT_COMMAND_LIMIT = 200
|
||||
|
||||
def _write_audit(self, thread_id: str | None, command: str, verdict: str, *, truncate: bool = False) -> None:
|
||||
audited_command = command
|
||||
if truncate and len(command) > self._AUDIT_COMMAND_LIMIT:
|
||||
audited_command = f"{command[: self._AUDIT_COMMAND_LIMIT]}... ({len(command)} chars)"
|
||||
record = {
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"thread_id": thread_id or "unknown",
|
||||
"command": audited_command,
|
||||
"verdict": verdict,
|
||||
}
|
||||
logger.info("[SandboxAudit] %s", json.dumps(record, ensure_ascii=False))
|
||||
|
||||
def _build_block_message(self, request: ToolCallRequest, reason: str) -> ToolMessage:
|
||||
tool_call_id = str(request.tool_call.get("id") or "missing_id")
|
||||
return ToolMessage(
|
||||
content=f"Command blocked: {reason}. Please use a safer alternative approach.",
|
||||
tool_call_id=tool_call_id,
|
||||
name="bash",
|
||||
status="error",
|
||||
)
|
||||
|
||||
def _append_warn_to_result(self, result: ToolMessage | Command, command: str) -> ToolMessage | Command:
|
||||
"""Append a warning note to the tool result for medium-risk commands."""
|
||||
if not isinstance(result, ToolMessage):
|
||||
return result
|
||||
warning = f"\n\n⚠️ Warning: `{command}` is a medium-risk command that may modify the runtime environment."
|
||||
if isinstance(result.content, list):
|
||||
new_content = list(result.content) + [{"type": "text", "text": warning}]
|
||||
else:
|
||||
new_content = str(result.content) + warning
|
||||
return ToolMessage(
|
||||
content=new_content,
|
||||
tool_call_id=result.tool_call_id,
|
||||
name=result.name,
|
||||
status=result.status,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Input sanitisation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# Normal bash commands rarely exceed a few hundred characters. 10 000 is
|
||||
# well above any legitimate use case yet a tiny fraction of Linux ARG_MAX.
|
||||
# Anything longer is almost certainly a payload injection or base64-encoded
|
||||
# attack string.
|
||||
_MAX_COMMAND_LENGTH = 10_000
|
||||
|
||||
def _validate_input(self, command: str) -> str | None:
|
||||
"""Return ``None`` if *command* is acceptable, else a rejection reason."""
|
||||
if not command.strip():
|
||||
return "empty command"
|
||||
if len(command) > self._MAX_COMMAND_LENGTH:
|
||||
return "command too long"
|
||||
if "\x00" in command:
|
||||
return "null byte detected"
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core logic (shared between sync and async paths)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _pre_process(self, request: ToolCallRequest) -> tuple[str, str | None, str, str | None]:
|
||||
"""
|
||||
Returns (command, thread_id, verdict, reject_reason).
|
||||
verdict is 'block', 'warn', or 'pass'.
|
||||
reject_reason is non-None only for input sanitisation rejections.
|
||||
"""
|
||||
args = request.tool_call.get("args", {})
|
||||
raw_command = args.get("command")
|
||||
command = raw_command if isinstance(raw_command, str) else ""
|
||||
thread_id = self._get_thread_id(request)
|
||||
|
||||
# ① input sanitisation — reject malformed input before regex analysis
|
||||
reject_reason = self._validate_input(command)
|
||||
if reject_reason:
|
||||
self._write_audit(thread_id, command, "block", truncate=True)
|
||||
logger.warning("[SandboxAudit] INVALID INPUT thread=%s reason=%s", thread_id, reject_reason)
|
||||
return command, thread_id, "block", reject_reason
|
||||
|
||||
# ② classify command
|
||||
verdict = _classify_command(command)
|
||||
|
||||
# ③ audit log
|
||||
self._write_audit(thread_id, command, verdict)
|
||||
|
||||
if verdict == "block":
|
||||
logger.warning("[SandboxAudit] BLOCKED thread=%s cmd=%r", thread_id, command)
|
||||
elif verdict == "warn":
|
||||
logger.warning("[SandboxAudit] WARN (medium-risk) thread=%s cmd=%r", thread_id, command)
|
||||
|
||||
return command, thread_id, verdict, None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# wrap_tool_call hooks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@override
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
if request.tool_call.get("name") != "bash":
|
||||
return handler(request)
|
||||
|
||||
command, _, verdict, reject_reason = self._pre_process(request)
|
||||
if verdict == "block":
|
||||
reason = reject_reason or "security violation detected"
|
||||
return self._build_block_message(request, reason)
|
||||
result = handler(request)
|
||||
if verdict == "warn":
|
||||
result = self._append_warn_to_result(result, command)
|
||||
return result
|
||||
|
||||
@override
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
if request.tool_call.get("name") != "bash":
|
||||
return await handler(request)
|
||||
|
||||
command, _, verdict, reject_reason = self._pre_process(request)
|
||||
if verdict == "block":
|
||||
reason = reject_reason or "security violation detected"
|
||||
return self._build_block_message(request, reason)
|
||||
result = await handler(request)
|
||||
if verdict == "warn":
|
||||
result = self._append_warn_to_result(result, command)
|
||||
return result
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Middleware to enforce maximum concurrent subagent tool calls per model response."""
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Valid range for max_concurrent_subagents
|
||||
MIN_SUBAGENT_LIMIT = 2
|
||||
MAX_SUBAGENT_LIMIT = 4
|
||||
|
||||
|
||||
def _clamp_subagent_limit(value: int) -> int:
|
||||
"""Clamp subagent limit to valid range [2, 4]."""
|
||||
return max(MIN_SUBAGENT_LIMIT, min(MAX_SUBAGENT_LIMIT, value))
|
||||
|
||||
|
||||
class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Truncates excess 'task' tool calls from a single model response.
|
||||
|
||||
When an LLM generates more than max_concurrent parallel task tool calls
|
||||
in one response, this middleware keeps only the first max_concurrent and
|
||||
discards the rest. This is more reliable than prompt-based limits.
|
||||
|
||||
Args:
|
||||
max_concurrent: Maximum number of concurrent subagent calls allowed.
|
||||
Defaults to MAX_CONCURRENT_SUBAGENTS (3). Clamped to [2, 4].
|
||||
"""
|
||||
|
||||
def __init__(self, max_concurrent: int = MAX_CONCURRENT_SUBAGENTS):
|
||||
super().__init__()
|
||||
self.max_concurrent = _clamp_subagent_limit(max_concurrent)
|
||||
|
||||
def _truncate_task_calls(self, state: AgentState) -> dict | None:
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_msg = messages[-1]
|
||||
if getattr(last_msg, "type", None) != "ai":
|
||||
return None
|
||||
|
||||
tool_calls = getattr(last_msg, "tool_calls", None)
|
||||
if not tool_calls:
|
||||
return None
|
||||
|
||||
# Count task tool calls
|
||||
task_indices = [i for i, tc in enumerate(tool_calls) if tc.get("name") == "task"]
|
||||
if len(task_indices) <= self.max_concurrent:
|
||||
return None
|
||||
|
||||
# Build set of indices to drop (excess task calls beyond the limit)
|
||||
indices_to_drop = set(task_indices[self.max_concurrent :])
|
||||
truncated_tool_calls = [tc for i, tc in enumerate(tool_calls) if i not in indices_to_drop]
|
||||
|
||||
dropped_count = len(indices_to_drop)
|
||||
logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})")
|
||||
|
||||
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
|
||||
updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
|
||||
return {"messages": [updated_msg]}
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._truncate_task_calls(state)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._truncate_task_calls(state)
|
||||
@@ -0,0 +1,99 @@
|
||||
import logging
|
||||
from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThreadDataMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
thread_data: NotRequired[ThreadDataState | None]
|
||||
|
||||
|
||||
class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
"""Create thread data directories for each thread execution.
|
||||
|
||||
Creates the following directory structure:
|
||||
- {base_dir}/threads/{thread_id}/user-data/workspace
|
||||
- {base_dir}/threads/{thread_id}/user-data/uploads
|
||||
- {base_dir}/threads/{thread_id}/user-data/outputs
|
||||
|
||||
Lifecycle Management:
|
||||
- With lazy_init=True (default): Only compute paths, directories created on-demand
|
||||
- With lazy_init=False: Eagerly create directories in before_agent()
|
||||
"""
|
||||
|
||||
state_schema = ThreadDataMiddlewareState
|
||||
|
||||
def __init__(self, base_dir: str | None = None, lazy_init: bool = True):
|
||||
"""Initialize the middleware.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory for thread data. Defaults to Paths resolution.
|
||||
lazy_init: If True, defer directory creation until needed.
|
||||
If False, create directories eagerly in before_agent().
|
||||
Default is True for optimal performance.
|
||||
"""
|
||||
super().__init__()
|
||||
self._paths = Paths(base_dir) if base_dir else get_paths()
|
||||
self._lazy_init = lazy_init
|
||||
|
||||
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
|
||||
"""Get the paths for a thread's data directories.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
|
||||
Returns:
|
||||
Dictionary with workspace_path, uploads_path, and outputs_path.
|
||||
"""
|
||||
return {
|
||||
"workspace_path": str(self._paths.sandbox_work_dir(thread_id)),
|
||||
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)),
|
||||
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)),
|
||||
}
|
||||
|
||||
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
|
||||
"""Create the thread data directories.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
|
||||
Returns:
|
||||
Dictionary with the created directory paths.
|
||||
"""
|
||||
self._paths.ensure_thread_dirs(thread_id)
|
||||
return self._get_thread_paths(thread_id)
|
||||
|
||||
@override
|
||||
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
context = runtime.context or {}
|
||||
thread_id = context.get("thread_id")
|
||||
if thread_id is None:
|
||||
config = get_config()
|
||||
thread_id = config.get("configurable", {}).get("thread_id")
|
||||
|
||||
if thread_id is None:
|
||||
raise ValueError("Thread ID is required in runtime context or config.configurable")
|
||||
|
||||
if self._lazy_init:
|
||||
# Lazy initialization: only compute paths, don't create directories
|
||||
paths = self._get_thread_paths(thread_id)
|
||||
else:
|
||||
# Eager initialization: create directories immediately
|
||||
paths = self._create_thread_directories(thread_id)
|
||||
logger.debug("Created thread data directories for thread %s", thread_id)
|
||||
|
||||
return {
|
||||
"thread_data": {
|
||||
**paths,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
"""Middleware for automatic thread title generation."""
|
||||
|
||||
import logging
|
||||
from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.title_config import get_title_config
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TitleMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
title: NotRequired[str | None]
|
||||
|
||||
|
||||
class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
"""Automatically generate a title for the thread after the first user message."""
|
||||
|
||||
state_schema = TitleMiddlewareState
|
||||
|
||||
def _normalize_content(self, content: object) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
if isinstance(content, list):
|
||||
parts = [self._normalize_content(item) for item in content]
|
||||
return "\n".join(part for part in parts if part)
|
||||
|
||||
if isinstance(content, dict):
|
||||
text_value = content.get("text")
|
||||
if isinstance(text_value, str):
|
||||
return text_value
|
||||
|
||||
nested_content = content.get("content")
|
||||
if nested_content is not None:
|
||||
return self._normalize_content(nested_content)
|
||||
|
||||
return ""
|
||||
|
||||
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
|
||||
"""Check if we should generate a title for this thread."""
|
||||
config = get_title_config()
|
||||
if not config.enabled:
|
||||
return False
|
||||
|
||||
# Check if thread already has a title in state
|
||||
if state.get("title"):
|
||||
return False
|
||||
|
||||
# Check if this is the first turn (has at least one user message and one assistant response)
|
||||
messages = state.get("messages", [])
|
||||
if len(messages) < 2:
|
||||
return False
|
||||
|
||||
# Count user and assistant messages
|
||||
user_messages = [m for m in messages if m.type == "human"]
|
||||
assistant_messages = [m for m in messages if m.type == "ai"]
|
||||
|
||||
# Generate title after first complete exchange
|
||||
return len(user_messages) == 1 and len(assistant_messages) >= 1
|
||||
|
||||
def _build_title_prompt(self, state: TitleMiddlewareState) -> tuple[str, str]:
|
||||
"""Extract user/assistant messages and build the title prompt.
|
||||
|
||||
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
|
||||
"""
|
||||
config = get_title_config()
|
||||
messages = state.get("messages", [])
|
||||
|
||||
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
|
||||
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
|
||||
|
||||
user_msg = self._normalize_content(user_msg_content)
|
||||
assistant_msg = self._normalize_content(assistant_msg_content)
|
||||
|
||||
prompt = config.prompt_template.format(
|
||||
max_words=config.max_words,
|
||||
user_msg=user_msg[:500],
|
||||
assistant_msg=assistant_msg[:500],
|
||||
)
|
||||
return prompt, user_msg
|
||||
|
||||
def _parse_title(self, content: object) -> str:
|
||||
"""Normalize model output into a clean title string."""
|
||||
config = get_title_config()
|
||||
title_content = self._normalize_content(content)
|
||||
title = title_content.strip().strip('"').strip("'")
|
||||
return title[: config.max_chars] if len(title) > config.max_chars else title
|
||||
|
||||
def _fallback_title(self, user_msg: str) -> str:
|
||||
config = get_title_config()
|
||||
fallback_chars = min(config.max_chars, 50)
|
||||
if len(user_msg) > fallback_chars:
|
||||
return user_msg[:fallback_chars].rstrip() + "..."
|
||||
return user_msg if user_msg else "New Conversation"
|
||||
|
||||
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
"""Generate a local fallback title without blocking on an LLM call."""
|
||||
if not self._should_generate_title(state):
|
||||
return None
|
||||
|
||||
_, user_msg = self._build_title_prompt(state)
|
||||
return {"title": self._fallback_title(user_msg)}
|
||||
|
||||
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
"""Generate a title asynchronously and fall back locally on failure."""
|
||||
if not self._should_generate_title(state):
|
||||
return None
|
||||
|
||||
config = get_title_config()
|
||||
prompt, user_msg = self._build_title_prompt(state)
|
||||
|
||||
try:
|
||||
if config.model_name:
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
else:
|
||||
model = create_chat_model(thinking_enabled=False)
|
||||
response = await model.ainvoke(prompt)
|
||||
title = self._parse_title(response.content)
|
||||
if title:
|
||||
return {"title": title}
|
||||
except Exception:
|
||||
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
|
||||
return {"title": self._fallback_title(user_msg)}
|
||||
|
||||
@override
|
||||
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
return self._generate_title_result(state)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
return await self._agenerate_title_result(state)
|
||||
@@ -0,0 +1,100 @@
|
||||
"""Middleware that extends TodoListMiddleware with context-loss detection.
|
||||
|
||||
When the message history is truncated (e.g., by SummarizationMiddleware), the
|
||||
original `write_todos` tool call and its ToolMessage can be scrolled out of the
|
||||
active context window. This middleware detects that situation and injects a
|
||||
reminder message so the model still knows about the outstanding todo list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents.middleware import TodoListMiddleware
|
||||
from langchain.agents.middleware.todo import PlanningState, Todo
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
|
||||
def _todos_in_messages(messages: list[Any]) -> bool:
|
||||
"""Return True if any AIMessage in *messages* contains a write_todos tool call."""
|
||||
for msg in messages:
|
||||
if isinstance(msg, AIMessage) and msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
if tc.get("name") == "write_todos":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _reminder_in_messages(messages: list[Any]) -> bool:
|
||||
"""Return True if a todo_reminder HumanMessage is already present in *messages*."""
|
||||
for msg in messages:
|
||||
if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_reminder":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _format_todos(todos: list[Todo]) -> str:
|
||||
"""Format a list of Todo items into a human-readable string."""
|
||||
lines: list[str] = []
|
||||
for todo in todos:
|
||||
status = todo.get("status", "pending")
|
||||
content = todo.get("content", "")
|
||||
lines.append(f"- [{status}] {content}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class TodoMiddleware(TodoListMiddleware):
|
||||
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
|
||||
|
||||
When the original `write_todos` tool call has been truncated from the message
|
||||
history (e.g., after summarization), the model loses awareness of the current
|
||||
todo list. This middleware detects that gap in `before_model` / `abefore_model`
|
||||
and injects a reminder message so the model can continue tracking progress.
|
||||
"""
|
||||
|
||||
@override
|
||||
def before_model(
|
||||
self,
|
||||
state: PlanningState,
|
||||
runtime: Runtime, # noqa: ARG002
|
||||
) -> dict[str, Any] | None:
|
||||
"""Inject a todo-list reminder when write_todos has left the context window."""
|
||||
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
|
||||
if not todos:
|
||||
return None
|
||||
|
||||
messages = state.get("messages") or []
|
||||
if _todos_in_messages(messages):
|
||||
# write_todos is still visible in context — nothing to do.
|
||||
return None
|
||||
|
||||
if _reminder_in_messages(messages):
|
||||
# A reminder was already injected and hasn't been truncated yet.
|
||||
return None
|
||||
|
||||
# The todo list exists in state but the original write_todos call is gone.
|
||||
# Inject a reminder as a HumanMessage so the model stays aware.
|
||||
formatted = _format_todos(todos)
|
||||
reminder = HumanMessage(
|
||||
name="todo_reminder",
|
||||
content=(
|
||||
"<system_reminder>\n"
|
||||
"Your todo list from earlier is no longer visible in the current context window, "
|
||||
"but it is still active. Here is the current state:\n\n"
|
||||
f"{formatted}\n\n"
|
||||
"Continue tracking and updating this todo list as you work. "
|
||||
"Call `write_todos` whenever the status of any item changes.\n"
|
||||
"</system_reminder>"
|
||||
),
|
||||
)
|
||||
return {"messages": [reminder]}
|
||||
|
||||
@override
|
||||
async def abefore_model(
|
||||
self,
|
||||
state: PlanningState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async version of before_model."""
|
||||
return self.before_model(state, runtime)
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Middleware for logging LLM token usage."""
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenUsageMiddleware(AgentMiddleware):
|
||||
"""Logs token usage from model response usage_metadata."""
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._log_usage(state)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._log_usage(state)
|
||||
|
||||
def _log_usage(self, state: AgentState) -> None:
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
last = messages[-1]
|
||||
usage = getattr(last, "usage_metadata", None)
|
||||
if usage:
|
||||
logger.info(
|
||||
"LLM token usage: input=%s output=%s total=%s",
|
||||
usage.get("input_tokens", "?"),
|
||||
usage.get("output_tokens", "?"),
|
||||
usage.get("total_tokens", "?"),
|
||||
)
|
||||
return None
|
||||
@@ -0,0 +1,143 @@
|
||||
"""Tool error handling middleware and shared runtime middleware builders."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
|
||||
|
||||
|
||||
class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Convert tool exceptions into error ToolMessages so the run can continue."""
|
||||
|
||||
def _build_error_message(self, request: ToolCallRequest, exc: Exception) -> ToolMessage:
|
||||
tool_name = str(request.tool_call.get("name") or "unknown_tool")
|
||||
tool_call_id = str(request.tool_call.get("id") or _MISSING_TOOL_CALL_ID)
|
||||
detail = str(exc).strip() or exc.__class__.__name__
|
||||
if len(detail) > 500:
|
||||
detail = detail[:497] + "..."
|
||||
|
||||
content = f"Error: Tool '{tool_name}' failed with {exc.__class__.__name__}: {detail}. Continue with available context, or choose an alternative tool."
|
||||
return ToolMessage(
|
||||
content=content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
|
||||
@override
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
try:
|
||||
return handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("Tool execution failed (sync): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
|
||||
return self._build_error_message(request, exc)
|
||||
|
||||
@override
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
try:
|
||||
return await handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("Tool execution failed (async): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
|
||||
return self._build_error_message(request, exc)
|
||||
|
||||
|
||||
def _build_runtime_middlewares(
|
||||
*,
|
||||
include_uploads: bool,
|
||||
include_dangling_tool_call_patch: bool,
|
||||
lazy_init: bool = True,
|
||||
) -> list[AgentMiddleware]:
|
||||
"""Build shared base middlewares for agent execution."""
|
||||
from deerflow.agents.middlewares.llm_error_handling_middleware import LLMErrorHandlingMiddleware
|
||||
from deerflow.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from deerflow.sandbox.middleware import SandboxMiddleware
|
||||
|
||||
middlewares: list[AgentMiddleware] = [
|
||||
ThreadDataMiddleware(lazy_init=lazy_init),
|
||||
SandboxMiddleware(lazy_init=lazy_init),
|
||||
]
|
||||
|
||||
if include_uploads:
|
||||
from deerflow.agents.middlewares.uploads_middleware import UploadsMiddleware
|
||||
|
||||
middlewares.insert(1, UploadsMiddleware())
|
||||
|
||||
if include_dangling_tool_call_patch:
|
||||
from deerflow.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
|
||||
|
||||
middlewares.append(DanglingToolCallMiddleware())
|
||||
|
||||
middlewares.append(LLMErrorHandlingMiddleware())
|
||||
|
||||
# Guardrail middleware (if configured)
|
||||
from deerflow.config.guardrails_config import get_guardrails_config
|
||||
|
||||
guardrails_config = get_guardrails_config()
|
||||
if guardrails_config.enabled and guardrails_config.provider:
|
||||
import inspect
|
||||
|
||||
from deerflow.guardrails.middleware import GuardrailMiddleware
|
||||
from deerflow.reflection import resolve_variable
|
||||
|
||||
provider_cls = resolve_variable(guardrails_config.provider.use)
|
||||
provider_kwargs = dict(guardrails_config.provider.config) if guardrails_config.provider.config else {}
|
||||
# Pass framework hint if the provider accepts it (e.g. for config discovery).
|
||||
# Built-in providers like AllowlistProvider don't need it, so only inject
|
||||
# when the constructor accepts 'framework' or '**kwargs'.
|
||||
if "framework" not in provider_kwargs:
|
||||
try:
|
||||
sig = inspect.signature(provider_cls.__init__)
|
||||
if "framework" in sig.parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
|
||||
provider_kwargs["framework"] = "deerflow"
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
provider = provider_cls(**provider_kwargs)
|
||||
middlewares.append(GuardrailMiddleware(provider, fail_closed=guardrails_config.fail_closed, passport=guardrails_config.passport))
|
||||
|
||||
from deerflow.agents.middlewares.sandbox_audit_middleware import SandboxAuditMiddleware
|
||||
|
||||
middlewares.append(SandboxAuditMiddleware())
|
||||
middlewares.append(ToolErrorHandlingMiddleware())
|
||||
return middlewares
|
||||
|
||||
|
||||
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
|
||||
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
|
||||
return _build_runtime_middlewares(
|
||||
include_uploads=True,
|
||||
include_dangling_tool_call_patch=True,
|
||||
lazy_init=lazy_init,
|
||||
)
|
||||
|
||||
|
||||
def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
|
||||
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
||||
return _build_runtime_middlewares(
|
||||
include_uploads=False,
|
||||
include_dangling_tool_call_patch=True,
|
||||
lazy_init=lazy_init,
|
||||
)
|
||||
@@ -0,0 +1,293 @@
|
||||
"""Middleware to inject uploaded files information into agent context."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.utils.file_conversion import extract_outline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_OUTLINE_PREVIEW_LINES = 5
|
||||
|
||||
|
||||
def _extract_outline_for_file(file_path: Path) -> tuple[list[dict], list[str]]:
|
||||
"""Return the document outline and fallback preview for *file_path*.
|
||||
|
||||
Looks for a sibling ``<stem>.md`` file produced by the upload conversion
|
||||
pipeline.
|
||||
|
||||
Returns:
|
||||
(outline, preview) where:
|
||||
- outline: list of ``{title, line}`` dicts (plus optional sentinel).
|
||||
Empty when no headings are found or no .md exists.
|
||||
- preview: first few non-empty lines of the .md, used as a content
|
||||
anchor when outline is empty so the agent has some context.
|
||||
Empty when outline is non-empty (no fallback needed).
|
||||
"""
|
||||
md_path = file_path.with_suffix(".md")
|
||||
if not md_path.is_file():
|
||||
return [], []
|
||||
|
||||
outline = extract_outline(md_path)
|
||||
if outline:
|
||||
logger.debug("Extracted %d outline entries from %s", len(outline), file_path.name)
|
||||
return outline, []
|
||||
|
||||
# outline is empty — read the first few non-empty lines as a content preview
|
||||
preview: list[str] = []
|
||||
try:
|
||||
with md_path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
stripped = line.strip()
|
||||
if stripped:
|
||||
preview.append(stripped)
|
||||
if len(preview) >= _OUTLINE_PREVIEW_LINES:
|
||||
break
|
||||
except Exception:
|
||||
logger.debug("Failed to read preview lines from %s", md_path, exc_info=True)
|
||||
return [], preview
|
||||
|
||||
|
||||
class UploadsMiddlewareState(AgentState):
|
||||
"""State schema for uploads middleware."""
|
||||
|
||||
uploaded_files: NotRequired[list[dict] | None]
|
||||
|
||||
|
||||
class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
"""Middleware to inject uploaded files information into the agent context.
|
||||
|
||||
Reads file metadata from the current message's additional_kwargs.files
|
||||
(set by the frontend after upload) and prepends an <uploaded_files> block
|
||||
to the last human message so the model knows which files are available.
|
||||
"""
|
||||
|
||||
state_schema = UploadsMiddlewareState
|
||||
|
||||
def __init__(self, base_dir: str | None = None):
|
||||
"""Initialize the middleware.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory for thread data. Defaults to Paths resolution.
|
||||
"""
|
||||
super().__init__()
|
||||
self._paths = Paths(base_dir) if base_dir else get_paths()
|
||||
|
||||
def _format_file_entry(self, file: dict, lines: list[str]) -> None:
|
||||
"""Append a single file entry (name, size, path, optional outline) to lines."""
|
||||
size_kb = file["size"] / 1024
|
||||
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
|
||||
lines.append(f"- {file['filename']} ({size_str})")
|
||||
lines.append(f" Path: {file['path']}")
|
||||
outline = file.get("outline") or []
|
||||
if outline:
|
||||
truncated = outline[-1].get("truncated", False)
|
||||
visible = [e for e in outline if not e.get("truncated")]
|
||||
lines.append(" Document outline (use `read_file` with line ranges to read sections):")
|
||||
for entry in visible:
|
||||
lines.append(f" L{entry['line']}: {entry['title']}")
|
||||
if truncated:
|
||||
lines.append(f" ... (showing first {len(visible)} headings; use `read_file` to explore further)")
|
||||
else:
|
||||
preview = file.get("outline_preview") or []
|
||||
if preview:
|
||||
lines.append(" No structural headings detected. Document begins with:")
|
||||
for text in preview:
|
||||
lines.append(f" > {text}")
|
||||
lines.append(" Use `grep` to search for keywords (e.g. `grep(pattern='keyword', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("")
|
||||
|
||||
def _create_files_message(self, new_files: list[dict], historical_files: list[dict]) -> str:
|
||||
"""Create a formatted message listing uploaded files.
|
||||
|
||||
Args:
|
||||
new_files: Files uploaded in the current message.
|
||||
historical_files: Files uploaded in previous messages.
|
||||
Each file dict may contain an optional ``outline`` key — a list of
|
||||
``{title, line}`` dicts extracted from the converted Markdown file.
|
||||
|
||||
Returns:
|
||||
Formatted string inside <uploaded_files> tags.
|
||||
"""
|
||||
lines = ["<uploaded_files>"]
|
||||
|
||||
lines.append("The following files were uploaded in this message:")
|
||||
lines.append("")
|
||||
if new_files:
|
||||
for file in new_files:
|
||||
self._format_file_entry(file, lines)
|
||||
else:
|
||||
lines.append("(empty)")
|
||||
lines.append("")
|
||||
|
||||
if historical_files:
|
||||
lines.append("The following files were uploaded in previous messages and are still available:")
|
||||
lines.append("")
|
||||
for file in historical_files:
|
||||
self._format_file_entry(file, lines)
|
||||
|
||||
lines.append("To work with these files:")
|
||||
lines.append("- Read from the file first — use the outline line numbers and `read_file` to locate relevant sections.")
|
||||
lines.append("- Use `grep` to search for keywords when you are not sure which section to look at")
|
||||
lines.append(" (e.g. `grep(pattern='revenue', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("- Use `glob` to find files by name pattern")
|
||||
lines.append(" (e.g. `glob(pattern='**/*.md', path='/mnt/user-data/uploads/')`).")
|
||||
lines.append("- Only fall back to web search if the file content is clearly insufficient to answer the question.")
|
||||
lines.append("</uploaded_files>")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _files_from_kwargs(self, message: HumanMessage, uploads_dir: Path | None = None) -> list[dict] | None:
|
||||
"""Extract file info from message additional_kwargs.files.
|
||||
|
||||
The frontend sends uploaded file metadata in additional_kwargs.files
|
||||
after a successful upload. Each entry has: filename, size (bytes),
|
||||
path (virtual path), status.
|
||||
|
||||
Args:
|
||||
message: The human message to inspect.
|
||||
uploads_dir: Physical uploads directory used to verify file existence.
|
||||
When provided, entries whose files no longer exist are skipped.
|
||||
|
||||
Returns:
|
||||
List of file dicts with virtual paths, or None if the field is absent or empty.
|
||||
"""
|
||||
kwargs_files = (message.additional_kwargs or {}).get("files")
|
||||
if not isinstance(kwargs_files, list) or not kwargs_files:
|
||||
return None
|
||||
|
||||
files = []
|
||||
for f in kwargs_files:
|
||||
if not isinstance(f, dict):
|
||||
continue
|
||||
filename = f.get("filename") or ""
|
||||
if not filename or Path(filename).name != filename:
|
||||
continue
|
||||
if uploads_dir is not None and not (uploads_dir / filename).is_file():
|
||||
continue
|
||||
files.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"size": int(f.get("size") or 0),
|
||||
"path": f"/mnt/user-data/uploads/{filename}",
|
||||
"extension": Path(filename).suffix,
|
||||
}
|
||||
)
|
||||
return files if files else None
|
||||
|
||||
@override
|
||||
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
"""Inject uploaded files information before agent execution.
|
||||
|
||||
New files come from the current message's additional_kwargs.files.
|
||||
Historical files are scanned from the thread's uploads directory,
|
||||
excluding the new ones.
|
||||
|
||||
Prepends <uploaded_files> context to the last human message content.
|
||||
The original additional_kwargs (including files metadata) is preserved
|
||||
on the updated message so the frontend can read it from the stream.
|
||||
|
||||
Args:
|
||||
state: Current agent state.
|
||||
runtime: Runtime context containing thread_id.
|
||||
|
||||
Returns:
|
||||
State updates including uploaded files list.
|
||||
"""
|
||||
messages = list(state.get("messages", []))
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_message_index = len(messages) - 1
|
||||
last_message = messages[last_message_index]
|
||||
|
||||
if not isinstance(last_message, HumanMessage):
|
||||
return None
|
||||
|
||||
# Resolve uploads directory for existence checks
|
||||
thread_id = (runtime.context or {}).get("thread_id")
|
||||
if thread_id is None:
|
||||
try:
|
||||
from langgraph.config import get_config
|
||||
|
||||
thread_id = get_config().get("configurable", {}).get("thread_id")
|
||||
except RuntimeError:
|
||||
pass # get_config() raises outside a runnable context (e.g. unit tests)
|
||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
|
||||
|
||||
# Get newly uploaded files from the current message's additional_kwargs.files
|
||||
new_files = self._files_from_kwargs(last_message, uploads_dir) or []
|
||||
|
||||
# Collect historical files from the uploads directory (all except the new ones)
|
||||
new_filenames = {f["filename"] for f in new_files}
|
||||
historical_files: list[dict] = []
|
||||
if uploads_dir and uploads_dir.exists():
|
||||
for file_path in sorted(uploads_dir.iterdir()):
|
||||
if file_path.is_file() and file_path.name not in new_filenames:
|
||||
stat = file_path.stat()
|
||||
outline, preview = _extract_outline_for_file(file_path)
|
||||
historical_files.append(
|
||||
{
|
||||
"filename": file_path.name,
|
||||
"size": stat.st_size,
|
||||
"path": f"/mnt/user-data/uploads/{file_path.name}",
|
||||
"extension": file_path.suffix,
|
||||
"outline": outline,
|
||||
"outline_preview": preview,
|
||||
}
|
||||
)
|
||||
|
||||
# Attach outlines to new files as well
|
||||
if uploads_dir:
|
||||
for file in new_files:
|
||||
phys_path = uploads_dir / file["filename"]
|
||||
outline, preview = _extract_outline_for_file(phys_path)
|
||||
file["outline"] = outline
|
||||
file["outline_preview"] = preview
|
||||
|
||||
if not new_files and not historical_files:
|
||||
return None
|
||||
|
||||
logger.debug(f"New files: {[f['filename'] for f in new_files]}, historical: {[f['filename'] for f in historical_files]}")
|
||||
|
||||
# Create files message and prepend to the last human message content
|
||||
files_message = self._create_files_message(new_files, historical_files)
|
||||
|
||||
# Extract original content - handle both string and list formats
|
||||
original_content = last_message.content
|
||||
if isinstance(original_content, str):
|
||||
# Simple case: string content, just prepend files message
|
||||
updated_content = f"{files_message}\n\n{original_content}"
|
||||
elif isinstance(original_content, list):
|
||||
# Complex case: list content (multimodal), preserve all blocks
|
||||
# Prepend files message as the first text block
|
||||
files_block = {"type": "text", "text": f"{files_message}\n\n"}
|
||||
# Keep all original blocks (including images)
|
||||
updated_content = [files_block, *original_content]
|
||||
else:
|
||||
# Other types, preserve as-is
|
||||
updated_content = original_content
|
||||
|
||||
# Create new message with combined content.
|
||||
# Preserve additional_kwargs (including files metadata) so the frontend
|
||||
# can read structured file info from the streamed message.
|
||||
updated_message = HumanMessage(
|
||||
content=updated_content,
|
||||
id=last_message.id,
|
||||
additional_kwargs=last_message.additional_kwargs,
|
||||
)
|
||||
|
||||
messages[last_message_index] = updated_message
|
||||
|
||||
return {
|
||||
"uploaded_files": new_files,
|
||||
"messages": messages,
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
"""Middleware for injecting image details into conversation before LLM call."""
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ViewImageMiddlewareState(ThreadState):
|
||||
"""Reuse the thread state so reducer-backed keys keep their annotations."""
|
||||
|
||||
|
||||
class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]):
|
||||
"""Injects image details as a human message before LLM calls when view_image tools have completed.
|
||||
|
||||
This middleware:
|
||||
1. Runs before each LLM call
|
||||
2. Checks if the last assistant message contains view_image tool calls
|
||||
3. Verifies all tool calls in that message have been completed (have corresponding ToolMessages)
|
||||
4. If conditions are met, creates a human message with all viewed image details (including base64 data)
|
||||
5. Adds the message to state so the LLM can see and analyze the images
|
||||
|
||||
This enables the LLM to automatically receive and analyze images that were loaded via view_image tool,
|
||||
without requiring explicit user prompts to describe the images.
|
||||
"""
|
||||
|
||||
state_schema = ViewImageMiddlewareState
|
||||
|
||||
def _get_last_assistant_message(self, messages: list) -> AIMessage | None:
|
||||
"""Get the last assistant message from the message list.
|
||||
|
||||
Args:
|
||||
messages: List of messages
|
||||
|
||||
Returns:
|
||||
Last AIMessage or None if not found
|
||||
"""
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
return msg
|
||||
return None
|
||||
|
||||
def _has_view_image_tool(self, message: AIMessage) -> bool:
|
||||
"""Check if the assistant message contains view_image tool calls.
|
||||
|
||||
Args:
|
||||
message: Assistant message to check
|
||||
|
||||
Returns:
|
||||
True if message contains view_image tool calls
|
||||
"""
|
||||
if not hasattr(message, "tool_calls") or not message.tool_calls:
|
||||
return False
|
||||
|
||||
return any(tool_call.get("name") == "view_image" for tool_call in message.tool_calls)
|
||||
|
||||
def _all_tools_completed(self, messages: list, assistant_msg: AIMessage) -> bool:
|
||||
"""Check if all tool calls in the assistant message have been completed.
|
||||
|
||||
Args:
|
||||
messages: List of all messages
|
||||
assistant_msg: The assistant message containing tool calls
|
||||
|
||||
Returns:
|
||||
True if all tool calls have corresponding ToolMessages
|
||||
"""
|
||||
if not hasattr(assistant_msg, "tool_calls") or not assistant_msg.tool_calls:
|
||||
return False
|
||||
|
||||
# Get all tool call IDs from the assistant message
|
||||
tool_call_ids = {tool_call.get("id") for tool_call in assistant_msg.tool_calls if tool_call.get("id")}
|
||||
|
||||
# Find the index of the assistant message
|
||||
try:
|
||||
assistant_idx = messages.index(assistant_msg)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
# Get all ToolMessages after the assistant message
|
||||
completed_tool_ids = set()
|
||||
for msg in messages[assistant_idx + 1 :]:
|
||||
if isinstance(msg, ToolMessage) and msg.tool_call_id:
|
||||
completed_tool_ids.add(msg.tool_call_id)
|
||||
|
||||
# Check if all tool calls have been completed
|
||||
return tool_call_ids.issubset(completed_tool_ids)
|
||||
|
||||
def _create_image_details_message(self, state: ViewImageMiddlewareState) -> list[str | dict]:
|
||||
"""Create a formatted message with all viewed image details.
|
||||
|
||||
Args:
|
||||
state: Current state containing viewed_images
|
||||
|
||||
Returns:
|
||||
List of content blocks (text and images) for the HumanMessage
|
||||
"""
|
||||
viewed_images = state.get("viewed_images", {})
|
||||
if not viewed_images:
|
||||
# Return a properly formatted text block, not a plain string array
|
||||
return [{"type": "text", "text": "No images have been viewed."}]
|
||||
|
||||
# Build the message with image information
|
||||
content_blocks: list[str | dict] = [{"type": "text", "text": "Here are the images you've viewed:"}]
|
||||
|
||||
for image_path, image_data in viewed_images.items():
|
||||
mime_type = image_data.get("mime_type", "unknown")
|
||||
base64_data = image_data.get("base64", "")
|
||||
|
||||
# Add text description
|
||||
content_blocks.append({"type": "text", "text": f"\n- **{image_path}** ({mime_type})"})
|
||||
|
||||
# Add the actual image data so LLM can "see" it
|
||||
if base64_data:
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{mime_type};base64,{base64_data}"},
|
||||
}
|
||||
)
|
||||
|
||||
return content_blocks
|
||||
|
||||
def _should_inject_image_message(self, state: ViewImageMiddlewareState) -> bool:
|
||||
"""Determine if we should inject an image details message.
|
||||
|
||||
Args:
|
||||
state: Current state
|
||||
|
||||
Returns:
|
||||
True if we should inject the message
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return False
|
||||
|
||||
# Get the last assistant message
|
||||
last_assistant_msg = self._get_last_assistant_message(messages)
|
||||
if not last_assistant_msg:
|
||||
return False
|
||||
|
||||
# Check if it has view_image tool calls
|
||||
if not self._has_view_image_tool(last_assistant_msg):
|
||||
return False
|
||||
|
||||
# Check if all tools have been completed
|
||||
if not self._all_tools_completed(messages, last_assistant_msg):
|
||||
return False
|
||||
|
||||
# Check if we've already added an image details message
|
||||
# Look for a human message after the last assistant message that contains image details
|
||||
assistant_idx = messages.index(last_assistant_msg)
|
||||
for msg in messages[assistant_idx + 1 :]:
|
||||
if isinstance(msg, HumanMessage):
|
||||
content_str = str(msg.content)
|
||||
if "Here are the images you've viewed" in content_str or "Here are the details of the images you've viewed" in content_str:
|
||||
# Already added, don't add again
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _inject_image_message(self, state: ViewImageMiddlewareState) -> dict | None:
|
||||
"""Internal helper to inject image details message.
|
||||
|
||||
Args:
|
||||
state: Current state
|
||||
|
||||
Returns:
|
||||
State update with additional human message, or None if no update needed
|
||||
"""
|
||||
if not self._should_inject_image_message(state):
|
||||
return None
|
||||
|
||||
# Create the image details message with text and image content
|
||||
image_content = self._create_image_details_message(state)
|
||||
|
||||
# Create a new human message with mixed content (text + images)
|
||||
human_msg = HumanMessage(content=image_content)
|
||||
|
||||
logger.debug("Injecting image details message with images before LLM call")
|
||||
|
||||
# Return state update with the new message
|
||||
return {"messages": [human_msg]}
|
||||
|
||||
@override
|
||||
def before_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
"""Inject image details message before LLM call if view_image tools have completed (sync version).
|
||||
|
||||
This runs before each LLM call, checking if the previous turn included view_image
|
||||
tool calls that have all completed. If so, it injects a human message with the image
|
||||
details so the LLM can see and analyze the images.
|
||||
|
||||
Args:
|
||||
state: Current state
|
||||
runtime: Runtime context (unused but required by interface)
|
||||
|
||||
Returns:
|
||||
State update with additional human message, or None if no update needed
|
||||
"""
|
||||
return self._inject_image_message(state)
|
||||
|
||||
@override
|
||||
async def abefore_model(self, state: ViewImageMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
"""Inject image details message before LLM call if view_image tools have completed (async version).
|
||||
|
||||
This runs before each LLM call, checking if the previous turn included view_image
|
||||
tool calls that have all completed. If so, it injects a human message with the image
|
||||
details so the LLM can see and analyze the images.
|
||||
|
||||
Args:
|
||||
state: Current state
|
||||
runtime: Runtime context (unused but required by interface)
|
||||
|
||||
Returns:
|
||||
State update with additional human message, or None if no update needed
|
||||
"""
|
||||
return self._inject_image_message(state)
|
||||
Reference in New Issue
Block a user