Initial commit: hardened DeerFlow factory

Vendored deer-flow upstream (bytedance/deer-flow) plus prompt-injection
hardening:

- New deerflow.security package: content_delimiter, html_cleaner,
  sanitizer (8 layers — invisible chars, control chars, symbols, NFC,
  PUA, tag chars, horizontal whitespace collapse with newline/tab
  preservation, length cap)
- New deerflow.community.searx package: web_search, web_fetch,
  image_search backed by a private SearX instance, every external
  string sanitized and wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>>
  delimiters
- All native community web providers (ddg_search, tavily, exa,
  firecrawl, jina_ai, infoquest, image_search) replaced with hard-fail
  stubs that raise NativeWebToolDisabledError at import time, so a
  misconfigured tool.use path fails loudly rather than silently falling
  back to unsanitized output
- Native client back-doors (jina_client.py, infoquest_client.py)
  stubbed too
- Native-tool tests quarantined under tests/_disabled_native/
  (collect_ignore_glob via local conftest.py)
- Sanitizer Layer 7 fix: only collapse horizontal whitespace, preserve
  newlines and tabs so list/table structure survives
- Hardened runtime config.yaml references only the searx-backed tools
- Factory overlay (backend/) kept in sync with deer-flow tree as a
  reference / source

See HARDENING.md for the full audit trail and verification steps.
This commit is contained in:
2026-04-12 14:23:57 +02:00
commit 6de0bf9f5b
889 changed files with 173052 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
# Re-export the chat-model factory as this package's public API.
from .factory import create_chat_model

__all__ = ["create_chat_model"]

View File

@@ -0,0 +1,348 @@
"""Custom Claude provider with OAuth Bearer auth, prompt caching, and smart thinking.
Supports two authentication modes:
1. Standard API key (x-api-key header) — default ChatAnthropic behavior
2. Claude Code OAuth token (Authorization: Bearer header)
- Detected by sk-ant-oat prefix
- Requires anthropic-beta: oauth-2025-04-20,claude-code-20250219
- Requires billing header in system prompt for all OAuth requests
Auto-loads credentials from explicit runtime handoff:
- $ANTHROPIC_API_KEY environment variable
- $CLAUDE_CODE_OAUTH_TOKEN or $ANTHROPIC_AUTH_TOKEN
- $CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR
- $CLAUDE_CODE_CREDENTIALS_PATH
- ~/.claude/.credentials.json
"""
import hashlib
import json
import logging
import os
import socket
import time
import uuid
from typing import Any
import anthropic
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import BaseMessage
from pydantic import PrivateAttr
logger = logging.getLogger(__name__)

# Default retry budget for rate-limit / server-error retries.
MAX_RETRIES = 3
# Fraction of max_tokens allocated to extended thinking when no explicit budget is set.
THINKING_BUDGET_RATIO = 0.8

# Billing header required by Anthropic API for OAuth token access.
# Must be the first system prompt block. Format mirrors Claude Code CLI.
# Override with ANTHROPIC_BILLING_HEADER env var if the hardcoded version drifts.
_DEFAULT_BILLING_HEADER = "x-anthropic-billing-header: cc_version=2.1.85.351; cc_entrypoint=cli; cch=6c6d5;"
OAUTH_BILLING_HEADER = os.environ.get("ANTHROPIC_BILLING_HEADER", _DEFAULT_BILLING_HEADER)
class ClaudeChatModel(ChatAnthropic):
    """ChatAnthropic with OAuth Bearer auth, prompt caching, and smart thinking.

    Config example:
        - name: claude-sonnet-4.6
          use: deerflow.models.claude_provider:ClaudeChatModel
          model: claude-sonnet-4-6
          max_tokens: 16384
          enable_prompt_caching: true
    """

    # Custom fields
    enable_prompt_caching: bool = True  # add ephemeral cache_control blocks (forced off for OAuth)
    prompt_cache_size: int = 3  # number of most-recent messages marked cacheable
    auto_thinking_budget: bool = True  # auto-fill thinking.budget_tokens when thinking is enabled
    retry_max_attempts: int = MAX_RETRIES

    # OAuth state, populated by model_post_init; not part of the pydantic schema.
    _is_oauth: bool = PrivateAttr(default=False)
    _oauth_access_token: str = PrivateAttr(default="")

    model_config = {"arbitrary_types_allowed": True}

    def _validate_retry_config(self) -> None:
        """Reject a non-positive retry budget before any request is made."""
        if self.retry_max_attempts < 1:
            raise ValueError("retry_max_attempts must be >= 1")

    def model_post_init(self, __context: Any) -> None:
        """Auto-load credentials and configure OAuth if needed."""
        from pydantic import SecretStr
        from deerflow.models.credential_loader import (
            OAUTH_ANTHROPIC_BETAS,
            is_oauth_token,
            load_claude_code_credential,
        )
        self._validate_retry_config()
        # Extract actual key value (SecretStr.str() returns '**********')
        current_key = ""
        if self.anthropic_api_key:
            if hasattr(self.anthropic_api_key, "get_secret_value"):
                current_key = self.anthropic_api_key.get_secret_value()
            else:
                current_key = str(self.anthropic_api_key)
        # Try the explicit Claude Code OAuth handoff sources if no valid key.
        if not current_key or current_key in ("your-anthropic-api-key",):
            cred = load_claude_code_credential()
            if cred:
                current_key = cred.access_token
                logger.info(f"Using Claude Code CLI credential (source: {cred.source})")
            else:
                logger.warning("No Anthropic API key or explicit Claude Code OAuth credential found.")
        # Detect OAuth token and configure Bearer auth
        if is_oauth_token(current_key):
            self._is_oauth = True
            self._oauth_access_token = current_key
            # Set the token as api_key temporarily (will be swapped to auth_token on client)
            self.anthropic_api_key = SecretStr(current_key)
            # Add required beta headers for OAuth
            self.default_headers = {
                **(self.default_headers or {}),
                "anthropic-beta": OAUTH_ANTHROPIC_BETAS,
            }
            # OAuth tokens have a limit of 4 cache_control blocks — disable prompt caching
            self.enable_prompt_caching = False
            logger.info("OAuth token detected — will use Authorization: Bearer header")
        else:
            if current_key:
                self.anthropic_api_key = SecretStr(current_key)
        # Ensure api_key is SecretStr
        if isinstance(self.anthropic_api_key, str):
            self.anthropic_api_key = SecretStr(self.anthropic_api_key)
        super().model_post_init(__context)
        # Patch clients immediately after creation for OAuth Bearer auth.
        # This must happen after super() because clients are lazily created.
        if self._is_oauth:
            self._patch_client_oauth(self._client)
            self._patch_client_oauth(self._async_client)

    def _patch_client_oauth(self, client: Any) -> None:
        """Swap api_key → auth_token on an Anthropic SDK client for OAuth Bearer auth."""
        # hasattr guard keeps this a no-op on client objects without these attributes.
        if hasattr(client, "api_key") and hasattr(client, "auth_token"):
            client.api_key = None
            client.auth_token = self._oauth_access_token

    def _get_request_payload(
        self,
        input_: Any,
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> dict:
        """Override to inject prompt caching, thinking budget, and OAuth billing."""
        payload = super()._get_request_payload(input_, stop=stop, **kwargs)
        if self._is_oauth:
            self._apply_oauth_billing(payload)
        if self.enable_prompt_caching:
            self._apply_prompt_caching(payload)
        if self.auto_thinking_budget:
            self._apply_thinking_budget(payload)
        return payload

    def _apply_oauth_billing(self, payload: dict) -> None:
        """Inject the billing header block required for all OAuth requests.

        The billing block is always placed first in the system list, removing any
        existing occurrence to avoid duplication or out-of-order positioning.
        """
        billing_block = {"type": "text", "text": OAUTH_BILLING_HEADER}
        system = payload.get("system")
        if isinstance(system, list):
            # Remove any existing billing blocks, then insert a single one at index 0.
            filtered = [b for b in system if not (isinstance(b, dict) and OAUTH_BILLING_HEADER in b.get("text", ""))]
            payload["system"] = [billing_block] + filtered
        elif isinstance(system, str):
            if OAUTH_BILLING_HEADER in system:
                payload["system"] = [billing_block]
            else:
                payload["system"] = [billing_block, {"type": "text", "text": system}]
        else:
            payload["system"] = [billing_block]
        # Add metadata.user_id required by the API for OAuth billing validation
        if not isinstance(payload.get("metadata"), dict):
            payload["metadata"] = {}
        if "user_id" not in payload["metadata"]:
            # Generate a stable device_id from the machine's hostname
            hostname = socket.gethostname()
            device_id = hashlib.sha256(f"deerflow-{hostname}".encode()).hexdigest()
            session_id = str(uuid.uuid4())
            payload["metadata"]["user_id"] = json.dumps(
                {
                    "device_id": device_id,
                    "account_uuid": "deerflow",
                    "session_id": session_id,
                }
            )

    def _apply_prompt_caching(self, payload: dict) -> None:
        """Apply ephemeral cache_control to system and recent messages."""
        # Cache system messages
        system = payload.get("system")
        if system and isinstance(system, list):
            for block in system:
                if isinstance(block, dict) and block.get("type") == "text":
                    block["cache_control"] = {"type": "ephemeral"}
        elif system and isinstance(system, str):
            payload["system"] = [
                {
                    "type": "text",
                    "text": system,
                    "cache_control": {"type": "ephemeral"},
                }
            ]
        # Cache recent messages (the last prompt_cache_size entries)
        messages = payload.get("messages", [])
        cache_start = max(0, len(messages) - self.prompt_cache_size)
        for i in range(cache_start, len(messages)):
            msg = messages[i]
            if not isinstance(msg, dict):
                continue
            content = msg.get("content")
            if isinstance(content, list):
                for block in content:
                    if isinstance(block, dict):
                        block["cache_control"] = {"type": "ephemeral"}
            elif isinstance(content, str) and content:
                msg["content"] = [
                    {
                        "type": "text",
                        "text": content,
                        "cache_control": {"type": "ephemeral"},
                    }
                ]
        # Cache the last tool definition
        tools = payload.get("tools", [])
        if tools and isinstance(tools[-1], dict):
            tools[-1]["cache_control"] = {"type": "ephemeral"}

    def _apply_thinking_budget(self, payload: dict) -> None:
        """Auto-allocate thinking budget (80% of max_tokens)."""
        thinking = payload.get("thinking")
        if not thinking or not isinstance(thinking, dict):
            return
        if thinking.get("type") != "enabled":
            return
        # Respect an explicitly configured budget.
        if thinking.get("budget_tokens"):
            return
        max_tokens = payload.get("max_tokens", 8192)
        thinking["budget_tokens"] = int(max_tokens * THINKING_BUDGET_RATIO)

    @staticmethod
    def _strip_cache_control(payload: dict) -> None:
        """Remove cache_control markers before OAuth requests reach Anthropic."""
        for section in ("system", "messages"):
            items = payload.get(section)
            if not isinstance(items, list):
                continue
            for item in items:
                if not isinstance(item, dict):
                    continue
                item.pop("cache_control", None)
                # Message entries may nest content blocks that also carry markers.
                content = item.get("content")
                if isinstance(content, list):
                    for block in content:
                        if isinstance(block, dict):
                            block.pop("cache_control", None)
        tools = payload.get("tools")
        if isinstance(tools, list):
            for tool in tools:
                if isinstance(tool, dict):
                    tool.pop("cache_control", None)

    def _create(self, payload: dict) -> Any:
        """Sync request hook: strip cache markers for OAuth, then delegate."""
        if self._is_oauth:
            self._strip_cache_control(payload)
        return super()._create(payload)

    async def _acreate(self, payload: dict) -> Any:
        """Async request hook: strip cache markers for OAuth, then delegate."""
        if self._is_oauth:
            self._strip_cache_control(payload)
        return await super()._acreate(payload)

    def _generate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
        """Override with OAuth patching and retry logic."""
        if self._is_oauth:
            # Re-assert Bearer auth on the sync client before each call.
            self._patch_client_oauth(self._client)
        last_error = None
        for attempt in range(1, self.retry_max_attempts + 1):
            try:
                return super()._generate(messages, stop=stop, **kwargs)
            except anthropic.RateLimitError as e:
                last_error = e
                if attempt >= self.retry_max_attempts:
                    raise
                wait_ms = self._calc_backoff_ms(attempt, e)
                logger.warning(f"Rate limited, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
                time.sleep(wait_ms / 1000)
            except anthropic.InternalServerError as e:
                last_error = e
                if attempt >= self.retry_max_attempts:
                    raise
                wait_ms = self._calc_backoff_ms(attempt, e)
                logger.warning(f"Server error, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
                time.sleep(wait_ms / 1000)
        raise last_error

    async def _agenerate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
        """Async override with OAuth patching and retry logic."""
        import asyncio
        if self._is_oauth:
            # Re-assert Bearer auth on the async client before each call.
            self._patch_client_oauth(self._async_client)
        last_error = None
        for attempt in range(1, self.retry_max_attempts + 1):
            try:
                return await super()._agenerate(messages, stop=stop, **kwargs)
            except anthropic.RateLimitError as e:
                last_error = e
                if attempt >= self.retry_max_attempts:
                    raise
                wait_ms = self._calc_backoff_ms(attempt, e)
                logger.warning(f"Rate limited, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
                await asyncio.sleep(wait_ms / 1000)
            except anthropic.InternalServerError as e:
                last_error = e
                if attempt >= self.retry_max_attempts:
                    raise
                wait_ms = self._calc_backoff_ms(attempt, e)
                logger.warning(f"Server error, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
                await asyncio.sleep(wait_ms / 1000)
        raise last_error

    @staticmethod
    def _calc_backoff_ms(attempt: int, error: Exception) -> int:
        """Exponential backoff with a fixed 20% buffer."""
        backoff_ms = 2000 * (1 << (attempt - 1))
        jitter_ms = int(backoff_ms * 0.2)
        total_ms = backoff_ms + jitter_ms
        # Prefer the server-provided Retry-After (assumed integer seconds) when present.
        if hasattr(error, "response") and error.response is not None:
            retry_after = error.response.headers.get("Retry-After")
            if retry_after:
                try:
                    total_ms = int(retry_after) * 1000
                except (ValueError, TypeError):
                    # Non-numeric Retry-After (e.g. HTTP-date) — keep computed backoff.
                    pass
        return total_ms

View File

@@ -0,0 +1,219 @@
"""Auto-load credentials from Claude Code CLI and Codex CLI.
Implements two credential strategies:
1. Claude Code OAuth token from explicit env vars or an exported credentials file
- Uses Authorization: Bearer header (NOT x-api-key)
- Requires anthropic-beta: oauth-2025-04-20,claude-code-20250219
- Supports $CLAUDE_CODE_OAUTH_TOKEN, $CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR, and $ANTHROPIC_AUTH_TOKEN
- Override path with $CLAUDE_CODE_CREDENTIALS_PATH
2. Codex CLI token from ~/.codex/auth.json
- Uses chatgpt.com/backend-api/codex/responses endpoint
- Supports both legacy top-level tokens and current nested tokens shape
- Override path with $CODEX_AUTH_PATH
"""
import json
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)

# Required beta headers for Claude Code OAuth tokens
# (sent verbatim as the `anthropic-beta` header value by the Claude provider).
OAUTH_ANTHROPIC_BETAS = "oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14"
def is_oauth_token(token: str) -> bool:
    """Return True when *token* looks like a Claude Code OAuth token (not a standard API key)."""
    if not isinstance(token, str):
        return False
    return "sk-ant-oat" in token
@dataclass
class ClaudeCodeCredential:
    """Claude Code CLI OAuth credential."""

    access_token: str
    refresh_token: str = ""
    expires_at: int = 0
    source: str = ""

    @property
    def is_expired(self) -> bool:
        """True when the expiry timestamp (epoch ms) is within 60s of now."""
        expiry_ms = self.expires_at
        if expiry_ms <= 0:
            # Zero/unset expiry means "no known expiry": treat as still valid.
            return False
        now_ms = time.time() * 1000
        return now_ms > expiry_ms - 60_000
@dataclass
class CodexCliCredential:
    """Codex CLI credential."""

    # Bearer token used against the chatgpt.com/backend-api/codex endpoint.
    access_token: str
    # ChatGPT account id (sent as the ChatGPT-Account-ID header); may be empty.
    account_id: str = ""
    # Which loader produced this credential (e.g. "codex-cli").
    source: str = ""
def _resolve_credential_path(env_var: str, default_relative_path: str) -> Path:
configured_path = os.getenv(env_var)
if configured_path:
return Path(configured_path).expanduser()
return _home_dir() / default_relative_path
def _home_dir() -> Path:
home = os.getenv("HOME")
if home:
return Path(home).expanduser()
return Path.home()
def _load_json_file(path: Path, label: str) -> dict[str, Any] | None:
if not path.exists():
logger.debug(f"{label} not found: {path}")
return None
if path.is_dir():
logger.warning(f"{label} path is a directory, expected a file: {path}")
return None
try:
return json.loads(path.read_text())
except (json.JSONDecodeError, OSError) as e:
logger.warning(f"Failed to read {label}: {e}")
return None
def _read_secret_from_file_descriptor(env_var: str) -> str | None:
fd_value = os.getenv(env_var)
if not fd_value:
return None
try:
fd = int(fd_value)
except ValueError:
logger.warning(f"{env_var} must be an integer file descriptor, got: {fd_value}")
return None
try:
secret = os.read(fd, 1024 * 1024).decode().strip()
except OSError as e:
logger.warning(f"Failed to read {env_var}: {e}")
return None
return secret or None
def _credential_from_direct_token(access_token: str, source: str) -> ClaudeCodeCredential | None:
token = access_token.strip()
if not token:
return None
return ClaudeCodeCredential(access_token=token, source=source)
def _iter_claude_code_credential_paths() -> list[Path]:
    """List candidate credential file locations: the env override (if set), then the default."""
    candidates: list[Path] = []
    override = os.getenv("CLAUDE_CODE_CREDENTIALS_PATH")
    if override:
        candidates.append(Path(override).expanduser())
    default_path = _home_dir() / ".claude/.credentials.json"
    # Skip the default only when the override already points at it.
    if not candidates or candidates[-1] != default_path:
        candidates.append(default_path)
    return candidates
def _extract_claude_code_credential(data: dict[str, Any], source: str) -> ClaudeCodeCredential | None:
    """Build a credential from a parsed credentials-file payload; None when missing or expired."""
    oauth = data.get("claudeAiOauth", {})
    access_token = oauth.get("accessToken", "")
    if not access_token:
        logger.debug("Claude Code credentials container exists but no accessToken found")
        return None
    cred = ClaudeCodeCredential(
        access_token=access_token,
        refresh_token=oauth.get("refreshToken", ""),
        expires_at=oauth.get("expiresAt", 0),
        source=source,
    )
    if not cred.is_expired:
        return cred
    logger.warning("Claude Code OAuth token is expired. Run 'claude' to refresh.")
    return None
def load_claude_code_credential() -> ClaudeCodeCredential | None:
    """Load an OAuth credential from the explicit Claude Code handoff sources.

    Checked in order:
    1. $CLAUDE_CODE_OAUTH_TOKEN or $ANTHROPIC_AUTH_TOKEN (direct token value)
    2. $CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR (token read from an inherited fd)
    3. $CLAUDE_CODE_CREDENTIALS_PATH (exported credentials JSON file)
    4. ~/.claude/.credentials.json (default exported credentials location)

    Exported credentials files contain:
        {
          "claudeAiOauth": {
            "accessToken": "sk-ant-oat01-...",
            "refreshToken": "sk-ant-ort01-...",
            "expiresAt": 1773430695128,
            "scopes": ["user:inference", ...],
            ...
          }
        }
    """
    env_token = os.getenv("CLAUDE_CODE_OAUTH_TOKEN") or os.getenv("ANTHROPIC_AUTH_TOKEN")
    if env_token:
        env_cred = _credential_from_direct_token(env_token, "claude-cli-env")
        if env_cred:
            logger.info("Loaded Claude Code OAuth credential from environment")
            return env_cred
    fd_secret = _read_secret_from_file_descriptor("CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR")
    if fd_secret:
        fd_cred = _credential_from_direct_token(fd_secret, "claude-cli-fd")
        if fd_cred:
            logger.info("Loaded Claude Code OAuth credential from file descriptor")
            return fd_cred
    raw_override = os.getenv("CLAUDE_CODE_CREDENTIALS_PATH")
    override_path = Path(raw_override).expanduser() if raw_override else None
    for candidate in _iter_claude_code_credential_paths():
        payload = _load_json_file(candidate, "Claude Code credentials")
        if payload is None:
            continue
        file_cred = _extract_claude_code_credential(payload, "claude-cli-file")
        if file_cred:
            came_from_override = override_path is not None and candidate == override_path
            source_label = "override path" if came_from_override else "plaintext file"
            logger.info(f"Loaded Claude Code OAuth credential from {source_label} (expires_at={file_cred.expires_at})")
            return file_cred
    return None
def load_codex_cli_credential() -> CodexCliCredential | None:
    """Load a credential from the Codex CLI auth file (~/.codex/auth.json by default)."""
    auth_path = _resolve_credential_path("CODEX_AUTH_PATH", ".codex/auth.json")
    payload = _load_json_file(auth_path, "Codex CLI credentials")
    if payload is None:
        return None
    # Current auth files nest tokens under "tokens"; legacy files keep them top-level.
    nested = payload.get("tokens", {})
    if not isinstance(nested, dict):
        nested = {}
    token = payload.get("access_token") or payload.get("token") or nested.get("access_token", "")
    if not token:
        logger.debug("Codex CLI credentials file exists but no token found")
        return None
    account = payload.get("account_id") or nested.get("account_id", "")
    logger.info("Loaded Codex CLI credential")
    return CodexCliCredential(
        access_token=token,
        account_id=account,
        source="codex-cli",
    )

View File

@@ -0,0 +1,123 @@
import logging
from langchain.chat_models import BaseChatModel
from deerflow.config import get_app_config
from deerflow.reflection import resolve_class
from deerflow.tracing import build_tracing_callbacks
logger = logging.getLogger(__name__)
def _deep_merge_dicts(base: dict | None, override: dict) -> dict:
"""Recursively merge two dictionaries without mutating the inputs."""
merged = dict(base or {})
for key, value in override.items():
if isinstance(value, dict) and isinstance(merged.get(key), dict):
merged[key] = _deep_merge_dicts(merged[key], value)
else:
merged[key] = value
return merged
def _vllm_disable_chat_template_kwargs(chat_template_kwargs: dict) -> dict:
"""Build the disable payload for vLLM/Qwen chat template kwargs."""
disable_kwargs: dict[str, bool] = {}
if "thinking" in chat_template_kwargs:
disable_kwargs["thinking"] = False
if "enable_thinking" in chat_template_kwargs:
disable_kwargs["enable_thinking"] = False
return disable_kwargs
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel:
    """Create a chat model instance from the config.

    Args:
        name: The name of the model to create. If None, the first model in the config will be used.
        thinking_enabled: When True, merge the model's thinking-enabled settings into
            the constructor kwargs; when False, apply the matching disable settings
            for whichever provider style the config uses.
        **kwargs: Extra constructor overrides (e.g. reasoning_effort from the
            frontend); they take precedence over config-derived settings.

    Returns:
        A chat model instance.

    Raises:
        ValueError: If the model name is not in the config, or thinking is requested
            for a model without `supports_thinking`.
    """
    config = get_app_config()
    if name is None:
        name = config.models[0].name
    model_config = config.get_model_config(name)
    if model_config is None:
        raise ValueError(f"Model {name} not found in config") from None
    model_class = resolve_class(model_config.use, BaseChatModel)
    # Everything not excluded here is forwarded verbatim to the model constructor.
    model_settings_from_config = model_config.model_dump(
        exclude_none=True,
        exclude={
            "use",
            "name",
            "display_name",
            "description",
            "supports_thinking",
            "supports_reasoning_effort",
            "when_thinking_enabled",
            "when_thinking_disabled",
            "thinking",
            "supports_vision",
        },
    )
    # Compute effective when_thinking_enabled by merging in the `thinking` shortcut field.
    # The `thinking` shortcut is equivalent to setting when_thinking_enabled["thinking"].
    has_thinking_settings = (model_config.when_thinking_enabled is not None) or (model_config.thinking is not None)
    effective_wte: dict = dict(model_config.when_thinking_enabled) if model_config.when_thinking_enabled else {}
    if model_config.thinking is not None:
        merged_thinking = {**(effective_wte.get("thinking") or {}), **model_config.thinking}
        effective_wte = {**effective_wte, "thinking": merged_thinking}
    if thinking_enabled and has_thinking_settings:
        if not model_config.supports_thinking:
            raise ValueError(f"Model {name} does not support thinking. Set `supports_thinking` to true in the `config.yaml` to enable thinking.") from None
        if effective_wte:
            model_settings_from_config.update(effective_wte)
    if not thinking_enabled:
        if model_config.when_thinking_disabled is not None:
            # User-provided disable settings take full precedence
            model_settings_from_config.update(model_config.when_thinking_disabled)
        elif has_thinking_settings and effective_wte.get("extra_body", {}).get("thinking", {}).get("type"):
            # OpenAI-compatible gateway: thinking is nested under extra_body
            model_settings_from_config["extra_body"] = _deep_merge_dicts(
                model_settings_from_config.get("extra_body"),
                {"thinking": {"type": "disabled"}},
            )
            model_settings_from_config["reasoning_effort"] = "minimal"
        elif has_thinking_settings and (disable_chat_template_kwargs := _vllm_disable_chat_template_kwargs(effective_wte.get("extra_body", {}).get("chat_template_kwargs") or {})):
            # vLLM uses chat template kwargs to switch thinking on/off.
            model_settings_from_config["extra_body"] = _deep_merge_dicts(
                model_settings_from_config.get("extra_body"),
                {"chat_template_kwargs": disable_chat_template_kwargs},
            )
        elif has_thinking_settings and effective_wte.get("thinking", {}).get("type"):
            # Native langchain_anthropic: thinking is a direct constructor parameter
            model_settings_from_config["thinking"] = {"type": "disabled"}
    if not model_config.supports_reasoning_effort:
        kwargs.pop("reasoning_effort", None)
        model_settings_from_config.pop("reasoning_effort", None)
    # For Codex Responses API models: map thinking mode to reasoning_effort
    from deerflow.models.openai_codex_provider import CodexChatModel
    if issubclass(model_class, CodexChatModel):
        # The ChatGPT Codex endpoint currently rejects max_tokens/max_output_tokens.
        model_settings_from_config.pop("max_tokens", None)
        # Use explicit reasoning_effort from frontend if provided (low/medium/high)
        explicit_effort = kwargs.pop("reasoning_effort", None)
        if not thinking_enabled:
            model_settings_from_config["reasoning_effort"] = "none"
        elif explicit_effort and explicit_effort in ("low", "medium", "high", "xhigh"):
            model_settings_from_config["reasoning_effort"] = explicit_effort
        elif "reasoning_effort" not in model_settings_from_config:
            model_settings_from_config["reasoning_effort"] = "medium"
    # Explicit call-site kwargs win over config-derived settings.
    model_instance = model_class(**{**model_settings_from_config, **kwargs})
    callbacks = build_tracing_callbacks()
    if callbacks:
        existing_callbacks = model_instance.callbacks or []
        model_instance.callbacks = [*existing_callbacks, *callbacks]
        logger.debug(f"Tracing attached to model '{name}' with providers={len(callbacks)}")
    return model_instance

View File

@@ -0,0 +1,430 @@
"""Custom OpenAI Codex provider using ChatGPT Codex Responses API.
Uses Codex CLI OAuth tokens with chatgpt.com/backend-api/codex/responses endpoint.
This is the same endpoint that the Codex CLI uses internally.
Supports:
- Auto-load credentials from ~/.codex/auth.json
- Responses API format (not Chat Completions)
- Tool calling
- Streaming (required by the endpoint)
- Retry with exponential backoff
"""
import json
import logging
import time
from typing import Any
import httpx
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from deerflow.models.credential_loader import CodexCliCredential, load_codex_cli_credential
logger = logging.getLogger(__name__)

# Same backend endpoint the Codex CLI itself talks to.
CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
# Default retry budget for transient (429/5xx) API failures.
MAX_RETRIES = 3
class CodexChatModel(BaseChatModel):
    """LangChain chat model using ChatGPT Codex Responses API.

    Config example:
        - name: gpt-5.4
          use: deerflow.models.openai_codex_provider:CodexChatModel
          model: gpt-5.4
          reasoning_effort: medium
    """

    # Model name forwarded in the Responses API payload.
    model: str = "gpt-5.4"
    # Effort sent under payload["reasoning"]["effort"]; "none" disables reasoning summaries.
    reasoning_effort: str = "medium"
    # Retry budget for transient API failures.
    retry_max_attempts: int = MAX_RETRIES
    # Loaded from the Codex CLI auth file in model_post_init.
    _access_token: str = ""
    _account_id: str = ""

    model_config = {"arbitrary_types_allowed": True}
    @classmethod
    def is_lc_serializable(cls) -> bool:
        # Opt in to LangChain's serialization machinery.
        return True
    @property
    def _llm_type(self) -> str:
        # Identifier used by LangChain for this model type.
        return "codex-responses"
def _validate_retry_config(self) -> None:
if self.retry_max_attempts < 1:
raise ValueError("retry_max_attempts must be >= 1")
def model_post_init(self, __context: Any) -> None:
"""Auto-load Codex CLI credentials."""
self._validate_retry_config()
cred = self._load_codex_auth()
if cred:
self._access_token = cred.access_token
self._account_id = cred.account_id
logger.info(f"Using Codex CLI credential (account: {self._account_id[:8]}...)")
else:
raise ValueError("Codex CLI credential not found. Expected ~/.codex/auth.json or CODEX_AUTH_PATH.")
super().model_post_init(__context)
    def _load_codex_auth(self) -> CodexCliCredential | None:
        """Load access_token and account_id from Codex CLI auth."""
        # Thin indirection over the module-level loader.
        return load_codex_cli_credential()
@classmethod
def _normalize_content(cls, content: Any) -> str:
"""Flatten LangChain content blocks into plain text for Codex."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [cls._normalize_content(item) for item in content]
return "\n".join(part for part in parts if part)
if isinstance(content, dict):
for key in ("text", "output"):
value = content.get(key)
if isinstance(value, str):
return value
nested_content = content.get("content")
if nested_content is not None:
return cls._normalize_content(nested_content)
try:
return json.dumps(content, ensure_ascii=False)
except TypeError:
return str(content)
try:
return json.dumps(content, ensure_ascii=False)
except TypeError:
return str(content)
def _convert_messages(self, messages: list[BaseMessage]) -> tuple[str, list[dict]]:
"""Convert LangChain messages to Responses API format.
Returns (instructions, input_items).
"""
instructions_parts: list[str] = []
input_items = []
for msg in messages:
if isinstance(msg, SystemMessage):
content = self._normalize_content(msg.content)
if content:
instructions_parts.append(content)
elif isinstance(msg, HumanMessage):
content = self._normalize_content(msg.content)
input_items.append({"role": "user", "content": content})
elif isinstance(msg, AIMessage):
if msg.content:
content = self._normalize_content(msg.content)
input_items.append({"role": "assistant", "content": content})
if msg.tool_calls:
for tc in msg.tool_calls:
input_items.append(
{
"type": "function_call",
"name": tc["name"],
"arguments": json.dumps(tc["args"]) if isinstance(tc["args"], dict) else tc["args"],
"call_id": tc["id"],
}
)
elif isinstance(msg, ToolMessage):
input_items.append(
{
"type": "function_call_output",
"call_id": msg.tool_call_id,
"output": self._normalize_content(msg.content),
}
)
instructions = "\n\n".join(instructions_parts) or "You are a helpful assistant."
return instructions, input_items
def _convert_tools(self, tools: list[dict]) -> list[dict]:
"""Convert LangChain tool format to Responses API format."""
responses_tools = []
for tool in tools:
if tool.get("type") == "function" and "function" in tool:
fn = tool["function"]
responses_tools.append(
{
"type": "function",
"name": fn["name"],
"description": fn.get("description", ""),
"parameters": fn.get("parameters", {}),
}
)
elif "name" in tool:
responses_tools.append(
{
"type": "function",
"name": tool["name"],
"description": tool.get("description", ""),
"parameters": tool.get("parameters", {}),
}
)
return responses_tools
def _call_codex_api(self, messages: list[BaseMessage], tools: list[dict] | None = None) -> dict:
"""Call the Codex Responses API and return the completed response."""
instructions, input_items = self._convert_messages(messages)
payload = {
"model": self.model,
"instructions": instructions,
"input": input_items,
"store": False,
"stream": True,
"reasoning": {"effort": self.reasoning_effort, "summary": "detailed"} if self.reasoning_effort != "none" else {"effort": "none"},
}
if tools:
payload["tools"] = self._convert_tools(tools)
headers = {
"Authorization": f"Bearer {self._access_token}",
"ChatGPT-Account-ID": self._account_id,
"Content-Type": "application/json",
"Accept": "text/event-stream",
"originator": "codex_cli_rs",
}
last_error = None
for attempt in range(1, self.retry_max_attempts + 1):
try:
return self._stream_response(headers, payload)
except httpx.HTTPStatusError as e:
last_error = e
if e.response.status_code in (429, 500, 529):
if attempt >= self.retry_max_attempts:
raise
wait_ms = 2000 * (1 << (attempt - 1))
logger.warning(f"Codex API error {e.response.status_code}, retrying {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
time.sleep(wait_ms / 1000)
else:
raise
except Exception:
raise
raise last_error
    def _stream_response(self, headers: dict, payload: dict) -> dict:
        """Stream SSE from Codex API and collect the final response.

        Returns the `response.completed` payload, with any output items that only
        arrived via `response.output_item.done` events merged into its `output` list.
        """
        completed_response = None
        # output_index -> finished output item observed during streaming
        streamed_output_items: dict[int, dict[str, Any]] = {}
        with httpx.Client(timeout=300) as client:
            with client.stream("POST", f"{CODEX_BASE_URL}/responses", headers=headers, json=payload) as resp:
                resp.raise_for_status()
                for line in resp.iter_lines():
                    data = self._parse_sse_data_line(line)
                    if not data:
                        continue
                    event_type = data.get("type")
                    if event_type == "response.output_item.done":
                        output_index = data.get("output_index")
                        output_item = data.get("item")
                        if isinstance(output_index, int) and isinstance(output_item, dict):
                            streamed_output_items[output_index] = output_item
                    elif event_type == "response.completed":
                        completed_response = data["response"]
        if not completed_response:
            raise RuntimeError("Codex API stream ended without response.completed event")
        # ChatGPT Codex can emit the final assistant content only in stream events.
        # When response.completed arrives, response.output may still be empty.
        if streamed_output_items:
            merged_output = []
            response_output = completed_response.get("output")
            if isinstance(response_output, list):
                merged_output = list(response_output)
            # Pad the merged list so every streamed index is addressable.
            max_index = max(max(streamed_output_items), len(merged_output) - 1)
            if max_index >= 0 and len(merged_output) <= max_index:
                merged_output.extend([None] * (max_index + 1 - len(merged_output)))
            for output_index, output_item in streamed_output_items.items():
                existing_item = merged_output[output_index]
                # Only fill slots the completed payload left empty/non-dict.
                if not isinstance(existing_item, dict):
                    merged_output[output_index] = output_item
            # Copy before mutating so the raw completed payload is not altered in place.
            completed_response = dict(completed_response)
            completed_response["output"] = [item for item in merged_output if isinstance(item, dict)]
        return completed_response
@staticmethod
def _parse_sse_data_line(line: str) -> dict[str, Any] | None:
"""Parse a data line from the SSE stream, skipping terminal markers."""
if not line.startswith("data:"):
return None
raw_data = line[5:].strip()
if not raw_data or raw_data == "[DONE]":
return None
try:
data = json.loads(raw_data)
except json.JSONDecodeError:
logger.debug(f"Skipping non-JSON Codex SSE frame: {raw_data}")
return None
return data if isinstance(data, dict) else None
def _parse_tool_call_arguments(self, output_item: dict[str, Any]) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
"""Parse function-call arguments, surfacing malformed payloads safely."""
raw_arguments = output_item.get("arguments", "{}")
if isinstance(raw_arguments, dict):
return raw_arguments, None
normalized_arguments = raw_arguments or "{}"
try:
parsed_arguments = json.loads(normalized_arguments)
except (TypeError, json.JSONDecodeError) as exc:
return None, {
"type": "invalid_tool_call",
"name": output_item.get("name"),
"args": str(raw_arguments),
"id": output_item.get("call_id"),
"error": f"Failed to parse tool arguments: {exc}",
}
if not isinstance(parsed_arguments, dict):
return None, {
"type": "invalid_tool_call",
"name": output_item.get("name"),
"args": str(raw_arguments),
"id": output_item.get("call_id"),
"error": "Tool arguments must decode to a JSON object.",
}
return parsed_arguments, None
def _parse_response(self, response: dict) -> ChatResult:
    """Parse Codex Responses API response into LangChain ChatResult.

    Walks ``response["output"]`` and collects:
      - reasoning summaries  -> ``additional_kwargs["reasoning_content"]``
      - message text parts   -> the AIMessage content
      - function calls       -> ``tool_calls`` (or ``invalid_tool_calls``
        when the arguments cannot be parsed into a JSON object)

    Token usage is mapped into both ``response_metadata`` and ``llm_output``.
    """
    content = ""
    tool_calls = []
    invalid_tool_calls = []
    reasoning_content = ""
    for output_item in response.get("output", []):
        if output_item.get("type") == "reasoning":
            # Extract reasoning summary text
            for summary_item in output_item.get("summary", []):
                if isinstance(summary_item, dict) and summary_item.get("type") == "summary_text":
                    reasoning_content += summary_item.get("text", "")
                elif isinstance(summary_item, str):
                    reasoning_content += summary_item
        elif output_item.get("type") == "message":
            for part in output_item.get("content", []):
                if part.get("type") == "output_text":
                    content += part.get("text", "")
        elif output_item.get("type") == "function_call":
            parsed_arguments, invalid_tool_call = self._parse_tool_call_arguments(output_item)
            if invalid_tool_call:
                # Surface malformed calls instead of dropping them silently.
                invalid_tool_calls.append(invalid_tool_call)
                continue
            tool_calls.append(
                {
                    "name": output_item["name"],
                    "args": parsed_arguments or {},
                    "id": output_item.get("call_id", ""),
                    "type": "tool_call",
                }
            )
    usage = response.get("usage", {})
    additional_kwargs = {}
    if reasoning_content:
        additional_kwargs["reasoning_content"] = reasoning_content
    message = AIMessage(
        content=content,
        tool_calls=tool_calls if tool_calls else [],
        invalid_tool_calls=invalid_tool_calls,
        additional_kwargs=additional_kwargs,
        response_metadata={
            "model": response.get("model", self.model),
            "usage": usage,
        },
    )
    return ChatResult(
        generations=[ChatGeneration(message=message)],
        llm_output={
            "token_usage": {
                "prompt_tokens": usage.get("input_tokens", 0),
                "completion_tokens": usage.get("output_tokens", 0),
                "total_tokens": usage.get("total_tokens", 0),
            },
            "model_name": response.get("model", self.model),
        },
    )
def _generate(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: CallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> ChatResult:
    """Generate a response using the Codex Responses API.

    Note: ``stop`` and ``run_manager`` are accepted for interface
    compatibility but not forwarded; only bound tools are passed through.
    """
    bound_tools = kwargs.get("tools")
    raw_response = self._call_codex_api(messages, tools=bound_tools)
    return self._parse_response(raw_response)
def bind_tools(self, tools: list, **kwargs: Any) -> Any:
    """Bind tools for function calling.

    Converts LangChain ``BaseTool`` instances and OpenAI-style tool dicts
    into the flat Codex Responses tool schema
    (``{"type": "function", "name", "description", "parameters"}``) and
    returns a ``RunnableBinding`` that forwards them on every call.

    Args:
        tools: ``BaseTool`` instances and/or tool dicts (either the nested
            OpenAI chat-completions shape with a ``"function"`` key, or an
            already-flat Responses-style dict, which is passed through).
        **kwargs: Extra binding kwargs passed through to ``RunnableBinding``.
    """
    from langchain_core.runnables import RunnableBinding
    from langchain_core.tools import BaseTool
    from langchain_core.utils.function_calling import convert_to_openai_function

    formatted_tools = []
    for tool in tools:
        if isinstance(tool, BaseTool):
            try:
                fn = convert_to_openai_function(tool)
                formatted_tools.append(
                    {
                        "type": "function",
                        "name": fn["name"],
                        "description": fn.get("description", ""),
                        "parameters": fn.get("parameters", {}),
                    }
                )
            except Exception:
                # Keep the best-effort fallback (a parameterless schema), but
                # leave a trace: a silently degraded schema is very hard to
                # debug downstream when the model stops filling arguments.
                logger.warning(
                    f"Failed to convert tool {tool.name!r} to OpenAI function schema; binding it without parameters",
                    exc_info=True,
                )
                formatted_tools.append(
                    {
                        "type": "function",
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": {"type": "object", "properties": {}},
                    }
                )
        elif isinstance(tool, dict):
            if "function" in tool:
                fn = tool["function"]
                formatted_tools.append(
                    {
                        "type": "function",
                        "name": fn["name"],
                        "description": fn.get("description", ""),
                        "parameters": fn.get("parameters", {}),
                    }
                )
            else:
                # Assume it is already in the flat Responses tool shape.
                formatted_tools.append(tool)
    return RunnableBinding(bound=self, kwargs={"tools": formatted_tools}, **kwargs)

View File

@@ -0,0 +1,73 @@
"""Patched ChatDeepSeek that preserves reasoning_content in multi-turn conversations.
This module provides a patched version of ChatDeepSeek that properly handles
reasoning_content when sending messages back to the API. The original implementation
stores reasoning_content in additional_kwargs but doesn't include it when making
subsequent API calls, which causes errors with APIs that require reasoning_content
on all assistant messages when thinking mode is enabled.
"""
from typing import Any
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_deepseek import ChatDeepSeek
class PatchedChatDeepSeek(ChatDeepSeek):
    """ChatDeepSeek that keeps ``reasoning_content`` on outgoing assistant turns.

    Thinking-enabled DeepSeek APIs require ``reasoning_content`` on every
    assistant message of a multi-turn conversation. Stock ChatDeepSeek stores
    it in ``additional_kwargs`` on the way in but drops it on the way out;
    this subclass copies it back into the serialized request payload.
    """

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @property
    def lc_secrets(self) -> dict[str, str]:
        return {"api_key": "DEEPSEEK_API_KEY", "openai_api_key": "DEEPSEEK_API_KEY"}

    def _get_request_payload(
        self,
        input_: LanguageModelInput,
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> dict:
        """Build the request payload, re-attaching ``reasoning_content``.

        Overrides the parent method to inject ``reasoning_content`` from
        ``additional_kwargs`` into the serialized assistant messages.
        """
        # Capture the LangChain messages before serialization.
        source_messages = self._convert_input(input_).to_messages()
        payload = super()._get_request_payload(input_, stop=stop, **kwargs)
        serialized = payload.get("messages", [])

        def _restore(target: dict, message: AIMessage) -> None:
            # Copy reasoning_content back onto the serialized assistant entry.
            reasoning = message.additional_kwargs.get("reasoning_content")
            if reasoning is not None:
                target["reasoning_content"] = reasoning

        if len(serialized) == len(source_messages):
            # Same length: entries correspond positionally one-to-one.
            for target, message in zip(serialized, source_messages):
                if target.get("role") == "assistant" and isinstance(message, AIMessage):
                    _restore(target, message)
        else:
            # Length mismatch: pair assistant entries with AIMessages in order.
            assistant_targets = (m for m in serialized if m.get("role") == "assistant")
            ai_sources = (m for m in source_messages if isinstance(m, AIMessage))
            for target, message in zip(assistant_targets, ai_sources):
                _restore(target, message)
        return payload

View File

@@ -0,0 +1,220 @@
"""Patched ChatOpenAI adapter for MiniMax reasoning output.
MiniMax's OpenAI-compatible chat completions API can return structured
``reasoning_details`` when ``extra_body.reasoning_split=true`` is enabled.
``langchain_openai.ChatOpenAI`` currently ignores that field, so DeerFlow's
frontend never receives reasoning content in the shape it expects.
This adapter preserves ``reasoning_split`` in the request payload and maps the
provider-specific reasoning field into ``additional_kwargs.reasoning_content``,
which DeerFlow already understands.
"""
from __future__ import annotations
import re
from collections.abc import Mapping
from typing import Any
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import (
_convert_delta_to_message_chunk,
_create_usage_metadata,
)
# Matches an inline <think>...</think> block (non-greedy, spans newlines).
_THINK_TAG_RE = re.compile(r"<think>\s*(.*?)\s*</think>", re.DOTALL)
def _extract_reasoning_text(
reasoning_details: Any,
*,
strip_parts: bool = True,
) -> str | None:
if not isinstance(reasoning_details, list):
return None
parts: list[str] = []
for item in reasoning_details:
if not isinstance(item, Mapping):
continue
text = item.get("text")
if isinstance(text, str):
normalized = text.strip() if strip_parts else text
if normalized.strip():
parts.append(normalized)
return "\n\n".join(parts) if parts else None
def _strip_inline_think_tags(content: str) -> tuple[str, str | None]:
    """Remove inline ``<think>...</think>`` blocks from *content*.

    Returns ``(cleaned_content, reasoning)``: the content with all think
    blocks excised (and outer whitespace stripped), plus the non-empty think
    texts joined by blank lines, or None when no reasoning was found.
    """
    captured: list[str] = []

    def _collect(match: re.Match[str]) -> str:
        inner = match.group(1).strip()
        if inner:
            captured.append(inner)
        return ""

    cleaned = _THINK_TAG_RE.sub(_collect, content).strip()
    if captured:
        return cleaned, "\n\n".join(captured)
    return cleaned, None
def _merge_reasoning(*values: str | None) -> str | None:
merged: list[str] = []
for value in values:
if not value:
continue
normalized = value.strip()
if normalized and normalized not in merged:
merged.append(normalized)
return "\n\n".join(merged) if merged else None
def _with_reasoning_content(
    message: AIMessage | AIMessageChunk,
    reasoning: str | None,
    *,
    preserve_whitespace: bool = False,
):
    """Return a copy of *message* with *reasoning* folded into additional_kwargs.

    With ``preserve_whitespace`` the new text is appended verbatim to any
    existing ``reasoning_content`` string (streaming use); otherwise the
    values are deduplicated and joined via :func:`_merge_reasoning`.
    Returns the original message untouched when *reasoning* is falsy.
    """
    if not reasoning:
        return message
    updated_kwargs = dict(message.additional_kwargs)
    current = updated_kwargs.get("reasoning_content")
    if preserve_whitespace:
        updated_kwargs["reasoning_content"] = current + reasoning if isinstance(current, str) else reasoning
    else:
        updated_kwargs["reasoning_content"] = _merge_reasoning(current, reasoning)
    return message.model_copy(update={"additional_kwargs": updated_kwargs})
class PatchedChatMiniMax(ChatOpenAI):
    """ChatOpenAI adapter that preserves MiniMax reasoning output.

    Forces ``extra_body.reasoning_split=true`` on every request and maps the
    provider-specific ``reasoning_details`` field (plus any inline
    ``<think>`` tags) into ``additional_kwargs["reasoning_content"]``, the
    shape DeerFlow's frontend expects.
    """

    def _get_request_payload(
        self,
        input_: LanguageModelInput,
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> dict:
        """Build the payload, always requesting split reasoning.

        ``reasoning_split=true`` makes MiniMax return structured
        ``reasoning_details`` instead of inline <think> text.
        """
        payload = super()._get_request_payload(input_, stop=stop, **kwargs)
        extra_body = payload.get("extra_body")
        if isinstance(extra_body, dict):
            # Merge rather than overwrite any configured extra_body.
            payload["extra_body"] = {
                **extra_body,
                "reasoning_split": True,
            }
        else:
            payload["extra_body"] = {"reasoning_split": True}
        return payload

    def _convert_chunk_to_generation_chunk(
        self,
        chunk: dict,
        default_chunk_class: type,
        base_generation_info: dict | None,
    ) -> ChatGenerationChunk | None:
        """Convert a streaming chunk, keeping per-delta reasoning text.

        Mirrors the upstream ChatOpenAI implementation, with one addition:
        ``delta.reasoning_details`` is appended verbatim (whitespace
        preserved) to ``additional_kwargs["reasoning_content"]``.
        """
        if chunk.get("type") == "content.delta":
            return None
        token_usage = chunk.get("usage")
        choices = chunk.get("choices", []) or chunk.get("chunk", {}).get("choices", [])
        usage_metadata = _create_usage_metadata(token_usage, chunk.get("service_tier")) if token_usage else None
        if len(choices) == 0:
            # Usage-only frame: emit an empty chunk carrying the metadata.
            generation_chunk = ChatGenerationChunk(
                message=default_chunk_class(content="", usage_metadata=usage_metadata),
                generation_info=base_generation_info,
            )
            if self.output_version == "v1":
                generation_chunk.message.content = []
                generation_chunk.message.response_metadata["output_version"] = "v1"
            return generation_chunk
        choice = choices[0]
        delta = choice.get("delta")
        if delta is None:
            return None
        message_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
        generation_info = {**base_generation_info} if base_generation_info else {}
        if finish_reason := choice.get("finish_reason"):
            generation_info["finish_reason"] = finish_reason
        if model_name := chunk.get("model"):
            generation_info["model_name"] = model_name
        if system_fingerprint := chunk.get("system_fingerprint"):
            generation_info["system_fingerprint"] = system_fingerprint
        if service_tier := chunk.get("service_tier"):
            generation_info["service_tier"] = service_tier
        logprobs = choice.get("logprobs")
        if logprobs:
            generation_info["logprobs"] = logprobs
        # strip_parts=False: raw whitespace must survive so that streamed
        # reasoning fragments concatenate correctly.
        reasoning = _extract_reasoning_text(
            delta.get("reasoning_details"),
            strip_parts=False,
        )
        if isinstance(message_chunk, AIMessageChunk):
            if usage_metadata:
                message_chunk.usage_metadata = usage_metadata
            if reasoning:
                message_chunk = _with_reasoning_content(
                    message_chunk,
                    reasoning,
                    preserve_whitespace=True,
                )
            message_chunk.response_metadata["model_provider"] = "openai"
        return ChatGenerationChunk(
            message=message_chunk,
            generation_info=generation_info or None,
        )

    def _create_chat_result(
        self,
        response: dict | Any,
        generation_info: dict | None = None,
    ) -> ChatResult:
        """Build the final ChatResult, merging split and inline reasoning.

        Strips inline ``<think>`` blocks from the content and merges them
        with ``choices[i].message.reasoning_details`` into
        ``additional_kwargs["reasoning_content"]``.
        """
        result = super()._create_chat_result(response, generation_info)
        response_dict = response if isinstance(response, dict) else response.model_dump()
        choices = response_dict.get("choices", [])
        generations: list[ChatGeneration] = []
        for index, generation in enumerate(result.generations):
            choice = choices[index] if index < len(choices) else {}
            message = generation.message
            if isinstance(message, AIMessage):
                content = message.content if isinstance(message.content, str) else None
                cleaned_content = content
                inline_reasoning = None
                if isinstance(content, str):
                    cleaned_content, inline_reasoning = _strip_inline_think_tags(content)
                choice_message = choice.get("message", {}) if isinstance(choice, Mapping) else {}
                split_reasoning = _extract_reasoning_text(choice_message.get("reasoning_details"))
                merged_reasoning = _merge_reasoning(split_reasoning, inline_reasoning)
                updated_message = message
                if cleaned_content is not None and cleaned_content != message.content:
                    updated_message = updated_message.model_copy(update={"content": cleaned_content})
                if merged_reasoning:
                    updated_message = _with_reasoning_content(updated_message, merged_reasoning)
                generation = ChatGeneration(
                    message=updated_message,
                    generation_info=generation.generation_info,
                )
            generations.append(generation)
        return ChatResult(generations=generations, llm_output=result.llm_output)

View File

@@ -0,0 +1,132 @@
"""Patched ChatOpenAI that preserves thought_signature for Gemini thinking models.
When using Gemini with thinking enabled via an OpenAI-compatible gateway (e.g.
Vertex AI, Google AI Studio, or any proxy), the API requires that the
``thought_signature`` field on tool-call objects is echoed back verbatim in
every subsequent request.
The OpenAI-compatible gateway stores the raw tool-call dicts (including
``thought_signature``) in ``additional_kwargs["tool_calls"]``, but standard
``langchain_openai.ChatOpenAI`` only serialises the standard fields (``id``,
``type``, ``function``) into the outgoing payload, silently dropping the
signature. That causes an HTTP 400 ``INVALID_ARGUMENT`` error:
Unable to submit request because function call `<tool>` in the N. content
block is missing a `thought_signature`.
This module fixes the problem by overriding ``_get_request_payload`` to
re-inject tool-call signatures back into the outgoing payload for any assistant
message that originally carried them.
"""
from __future__ import annotations
from typing import Any
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
class PatchedChatOpenAI(ChatOpenAI):
    """ChatOpenAI that echoes Gemini ``thought_signature`` back on tool calls.

    Gemini thinking models served through an OpenAI-compatible gateway attach
    a ``thought_signature`` to each tool-call object and require it to be
    returned verbatim on later turns; otherwise the request fails with an
    HTTP 400 ``INVALID_ARGUMENT`` ("missing a `thought_signature`"). Stock
    ChatOpenAI serialises only the standard tool-call fields, so this
    subclass restores the signatures from
    ``AIMessage.additional_kwargs["tool_calls"]`` before the request is sent.

    Usage in ``config.yaml``::

        - name: gemini-2.5-pro-thinking
          display_name: Gemini 2.5 Pro (Thinking)
          use: deerflow.models.patched_openai:PatchedChatOpenAI
          model: google/gemini-2.5-pro-preview
          api_key: $GEMINI_API_KEY
          base_url: https://<your-openai-compat-gateway>/v1
          max_tokens: 16384
          supports_thinking: true
          supports_vision: true
          when_thinking_enabled:
            extra_body:
              thinking:
                type: enabled
    """

    def _get_request_payload(
        self,
        input_: LanguageModelInput,
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> dict:
        """Build the payload, restoring tool-call ``thought_signature`` fields."""
        # Capture the LangChain messages before serialisation so we can read
        # fields the serialiser drops.
        source_messages = self._convert_input(input_).to_messages()
        payload = super()._get_request_payload(input_, stop=stop, **kwargs)
        serialized = payload.get("messages", [])
        if len(serialized) == len(source_messages):
            # One-to-one positional correspondence between both lists.
            for target, message in zip(serialized, source_messages):
                if target.get("role") == "assistant" and isinstance(message, AIMessage):
                    _restore_tool_call_signatures(target, message)
            return payload
        # Length mismatch: pair assistant entries with AIMessages in order.
        ai_sources = [m for m in source_messages if isinstance(m, AIMessage)]
        assistant_targets = [m for m in serialized if m.get("role") == "assistant"]
        for target, message in zip(assistant_targets, ai_sources):
            _restore_tool_call_signatures(target, message)
        return payload
def _restore_tool_call_signatures(payload_msg: dict, orig_msg: AIMessage) -> None:
"""Re-inject ``thought_signature`` onto tool-call objects in *payload_msg*.
When the Gemini OpenAI-compatible gateway returns a response with function
calls, each tool-call object may carry a ``thought_signature``. LangChain
stores the raw tool-call dicts in ``additional_kwargs["tool_calls"]`` but
only serialises the standard fields (``id``, ``type``, ``function``) into
the outgoing payload, silently dropping the signature.
This function matches raw tool-call entries (by ``id``, falling back to
positional order) and copies the signature back onto the serialised
payload entries.
"""
raw_tool_calls: list[dict] = orig_msg.additional_kwargs.get("tool_calls") or []
payload_tool_calls: list[dict] = payload_msg.get("tool_calls") or []
if not raw_tool_calls or not payload_tool_calls:
return
# Build an id → raw_tc lookup for efficient matching.
raw_by_id: dict[str, dict] = {}
for raw_tc in raw_tool_calls:
tc_id = raw_tc.get("id")
if tc_id:
raw_by_id[tc_id] = raw_tc
for idx, payload_tc in enumerate(payload_tool_calls):
# Try matching by id first, then fall back to positional.
raw_tc = raw_by_id.get(payload_tc.get("id", ""))
if raw_tc is None and idx < len(raw_tool_calls):
raw_tc = raw_tool_calls[idx]
if raw_tc is None:
continue
# The gateway may use either snake_case or camelCase.
sig = raw_tc.get("thought_signature") or raw_tc.get("thoughtSignature")
if sig:
payload_tc["thought_signature"] = sig

View File

@@ -0,0 +1,258 @@
"""Custom vLLM provider built on top of LangChain ChatOpenAI.
vLLM 0.19.0 exposes reasoning models through an OpenAI-compatible API, but
LangChain's default OpenAI adapter drops the non-standard ``reasoning`` field
from assistant messages and streaming deltas. That breaks interleaved
thinking/tool-call flows because vLLM expects the assistant's prior reasoning to
be echoed back on subsequent turns.
This provider preserves ``reasoning`` on:
- non-streaming responses
- streaming deltas
- multi-turn request payloads
"""
from __future__ import annotations
import json
from collections.abc import Mapping
from typing import Any, cast
import openai
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
ToolMessageChunk,
)
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import _create_usage_metadata
def _normalize_vllm_chat_template_kwargs(payload: dict[str, Any]) -> None:
"""Map DeerFlow's legacy ``thinking`` toggle to vLLM/Qwen's ``enable_thinking``.
DeerFlow originally documented ``extra_body.chat_template_kwargs.thinking``
for vLLM, but vLLM 0.19.0's Qwen reasoning parser reads
``chat_template_kwargs.enable_thinking``. Normalize the payload just before
it is sent so existing configs keep working and flash mode can truly
disable reasoning.
"""
extra_body = payload.get("extra_body")
if not isinstance(extra_body, dict):
return
chat_template_kwargs = extra_body.get("chat_template_kwargs")
if not isinstance(chat_template_kwargs, dict):
return
if "thinking" not in chat_template_kwargs:
return
normalized_chat_template_kwargs = dict(chat_template_kwargs)
normalized_chat_template_kwargs.setdefault("enable_thinking", normalized_chat_template_kwargs["thinking"])
normalized_chat_template_kwargs.pop("thinking", None)
extra_body["chat_template_kwargs"] = normalized_chat_template_kwargs
def _reasoning_to_text(reasoning: Any) -> str:
"""Best-effort extraction of readable reasoning text from vLLM payloads."""
if isinstance(reasoning, str):
return reasoning
if isinstance(reasoning, list):
parts = [_reasoning_to_text(item) for item in reasoning]
return "".join(part for part in parts if part)
if isinstance(reasoning, dict):
for key in ("text", "content", "reasoning"):
value = reasoning.get(key)
if isinstance(value, str):
return value
if value is not None:
text = _reasoning_to_text(value)
if text:
return text
try:
return json.dumps(reasoning, ensure_ascii=False)
except TypeError:
return str(reasoning)
try:
return json.dumps(reasoning, ensure_ascii=False)
except TypeError:
return str(reasoning)
def _convert_delta_to_message_chunk_with_reasoning(_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]) -> BaseMessageChunk:
    """Convert a streaming delta to a LangChain message chunk while preserving reasoning.

    Mirrors langchain_openai's private ``_convert_delta_to_message_chunk``,
    with one addition for assistant deltas: a non-None ``reasoning`` field is
    stored raw under ``additional_kwargs["reasoning"]`` and flattened to text
    under ``additional_kwargs["reasoning_content"]``.
    """
    id_ = _dict.get("id")
    role = cast(str, _dict.get("role"))
    content = cast(str, _dict.get("content") or "")
    additional_kwargs: dict[str, Any] = {}
    if _dict.get("function_call"):
        # Legacy function_call deltas may stream the name as None at first.
        function_call = dict(_dict["function_call"])
        if "name" in function_call and function_call["name"] is None:
            function_call["name"] = ""
        additional_kwargs["function_call"] = function_call
    reasoning = _dict.get("reasoning")
    if reasoning is not None:
        # Keep the raw provider value and a human-readable flattening.
        additional_kwargs["reasoning"] = reasoning
        reasoning_text = _reasoning_to_text(reasoning)
        if reasoning_text:
            additional_kwargs["reasoning_content"] = reasoning_text
    tool_call_chunks = []
    if raw_tool_calls := _dict.get("tool_calls"):
        try:
            tool_call_chunks = [
                tool_call_chunk(
                    name=rtc["function"].get("name"),
                    args=rtc["function"].get("arguments"),
                    id=rtc.get("id"),
                    index=rtc["index"],
                )
                for rtc in raw_tool_calls
            ]
        except KeyError:
            # Malformed tool-call deltas are dropped, matching upstream.
            pass
    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content, id=id_)
    if role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(
            content=content,
            additional_kwargs=additional_kwargs,
            id=id_,
            tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
        )
    if role in ("system", "developer") or default_class == SystemMessageChunk:
        # OpenAI's "developer" role is carried via additional_kwargs.
        role_kwargs = {"__openai_role__": "developer"} if role == "developer" else {}
        return SystemMessageChunk(content=content, id=id_, additional_kwargs=role_kwargs)
    if role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
    if role == "tool" or default_class == ToolMessageChunk:
        return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"], id=id_)
    if role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role, id=id_)  # type: ignore[arg-type]
    return default_class(content=content, id=id_)  # type: ignore[call-arg]
def _restore_reasoning_field(payload_msg: dict[str, Any], orig_msg: AIMessage) -> None:
"""Re-inject vLLM reasoning onto outgoing assistant messages."""
reasoning = orig_msg.additional_kwargs.get("reasoning")
if reasoning is None:
reasoning = orig_msg.additional_kwargs.get("reasoning_content")
if reasoning is not None:
payload_msg["reasoning"] = reasoning
class VllmChatModel(ChatOpenAI):
    """ChatOpenAI variant that preserves vLLM reasoning fields across turns."""

    # Pydantic config: allow non-pydantic types on fields of this model.
    model_config = {"arbitrary_types_allowed": True}

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for serialization/telemetry."""
        return "vllm-openai-compatible"
def _get_request_payload(
    self,
    input_: LanguageModelInput,
    *,
    stop: list[str] | None = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Restore assistant reasoning in request payloads for interleaved thinking."""
    # Capture the LangChain messages before serialization.
    source_messages = self._convert_input(input_).to_messages()
    payload = super()._get_request_payload(input_, stop=stop, **kwargs)
    _normalize_vllm_chat_template_kwargs(payload)
    serialized = payload.get("messages", [])
    if len(serialized) == len(source_messages):
        # One-to-one positional correspondence between both lists.
        for target, message in zip(serialized, source_messages):
            if target.get("role") == "assistant" and isinstance(message, AIMessage):
                _restore_reasoning_field(target, message)
    else:
        # Length mismatch: pair assistant entries with AIMessages in order.
        ai_sources = (m for m in source_messages if isinstance(m, AIMessage))
        assistant_targets = (m for m in serialized if m.get("role") == "assistant")
        for target, message in zip(assistant_targets, ai_sources):
            _restore_reasoning_field(target, message)
    return payload
def _create_chat_result(self, response: dict | openai.BaseModel, generation_info: dict | None = None) -> ChatResult:
    """Preserve vLLM reasoning on non-streaming responses.

    Copies ``choices[i].message.reasoning`` from the raw response onto the
    corresponding AIMessage: the raw value under
    ``additional_kwargs["reasoning"]`` and a flattened text form under
    ``additional_kwargs["reasoning_content"]``. Messages without reasoning
    are left untouched.
    """
    result = super()._create_chat_result(response, generation_info=generation_info)
    response_dict = response if isinstance(response, dict) else response.model_dump()
    for generation, choice in zip(result.generations, response_dict.get("choices", [])):
        if not isinstance(generation, ChatGeneration):
            continue
        message = generation.message
        if not isinstance(message, AIMessage):
            continue
        reasoning = choice.get("message", {}).get("reasoning")
        if reasoning is None:
            continue
        # Mutates the message in place: keep the raw provider value plus a
        # human-readable flattening for DeerFlow's frontend.
        message.additional_kwargs["reasoning"] = reasoning
        reasoning_text = _reasoning_to_text(reasoning)
        if reasoning_text:
            message.additional_kwargs["reasoning_content"] = reasoning_text
    return result
def _convert_chunk_to_generation_chunk(
    self,
    chunk: dict,
    default_chunk_class: type,
    base_generation_info: dict | None,
) -> ChatGenerationChunk | None:
    """Preserve vLLM reasoning on streaming deltas.

    Mirrors the upstream ChatOpenAI implementation but routes the delta
    through ``_convert_delta_to_message_chunk_with_reasoning`` so the
    non-standard ``reasoning`` field survives.
    """
    if chunk.get("type") == "content.delta":
        return None
    token_usage = chunk.get("usage")
    choices = chunk.get("choices", []) or chunk.get("chunk", {}).get("choices", [])
    usage_metadata = _create_usage_metadata(token_usage, chunk.get("service_tier")) if token_usage else None
    if len(choices) == 0:
        # Usage-only frame: emit an empty chunk that carries the metadata.
        generation_chunk = ChatGenerationChunk(message=default_chunk_class(content="", usage_metadata=usage_metadata), generation_info=base_generation_info)
        if self.output_version == "v1":
            generation_chunk.message.content = []
            generation_chunk.message.response_metadata["output_version"] = "v1"
        return generation_chunk
    choice = choices[0]
    if choice["delta"] is None:
        return None
    message_chunk = _convert_delta_to_message_chunk_with_reasoning(choice["delta"], default_chunk_class)
    generation_info = {**base_generation_info} if base_generation_info else {}
    if finish_reason := choice.get("finish_reason"):
        generation_info["finish_reason"] = finish_reason
    if model_name := chunk.get("model"):
        generation_info["model_name"] = model_name
    if system_fingerprint := chunk.get("system_fingerprint"):
        generation_info["system_fingerprint"] = system_fingerprint
    if service_tier := chunk.get("service_tier"):
        generation_info["service_tier"] = service_tier
    if logprobs := choice.get("logprobs"):
        generation_info["logprobs"] = logprobs
    if usage_metadata and isinstance(message_chunk, AIMessageChunk):
        message_chunk.usage_metadata = usage_metadata
    message_chunk.response_metadata["model_provider"] = "openai"
    return ChatGenerationChunk(message=message_chunk, generation_info=generation_info or None)