Initial commit: hardened DeerFlow factory
Vendored deer-flow upstream (bytedance/deer-flow) plus prompt-injection hardening: - New deerflow.security package: content_delimiter, html_cleaner, sanitizer (8 layers — invisible chars, control chars, symbols, NFC, PUA, tag chars, horizontal whitespace collapse with newline/tab preservation, length cap) - New deerflow.community.searx package: web_search, web_fetch, image_search backed by a private SearX instance, every external string sanitized and wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>> delimiters - All native community web providers (ddg_search, tavily, exa, firecrawl, jina_ai, infoquest, image_search) replaced with hard-fail stubs that raise NativeWebToolDisabledError at import time, so a misconfigured tool.use path fails loud rather than silently falling back to unsanitized output - Native client back-doors (jina_client.py, infoquest_client.py) stubbed too - Native-tool tests quarantined under tests/_disabled_native/ (collect_ignore_glob via local conftest.py) - Sanitizer Layer 7 fix: only collapse horizontal whitespace, preserve newlines and tabs so list/table structure survives - Hardened runtime config.yaml references only the searx-backed tools - Factory overlay (backend/) kept in sync with deer-flow tree as a reference / source See HARDENING.md for the full audit trail and verification steps.
This commit is contained in:
@@ -0,0 +1,348 @@
|
||||
"""Custom Claude provider with OAuth Bearer auth, prompt caching, and smart thinking.
|
||||
|
||||
Supports two authentication modes:
|
||||
1. Standard API key (x-api-key header) — default ChatAnthropic behavior
|
||||
2. Claude Code OAuth token (Authorization: Bearer header)
|
||||
- Detected by sk-ant-oat prefix
|
||||
- Requires anthropic-beta: oauth-2025-04-20,claude-code-20250219
|
||||
- Requires billing header in system prompt for all OAuth requests
|
||||
|
||||
Auto-loads credentials from explicit runtime handoff:
|
||||
- $ANTHROPIC_API_KEY environment variable
|
||||
- $CLAUDE_CODE_OAUTH_TOKEN or $ANTHROPIC_AUTH_TOKEN
|
||||
- $CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR
|
||||
- $CLAUDE_CODE_CREDENTIALS_PATH
|
||||
- ~/.claude/.credentials.json
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import anthropic
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
# Module-level logger for credential, OAuth, and retry diagnostics.
logger = logging.getLogger(__name__)

# Default number of attempts for retryable API errors (rate limit / 5xx).
MAX_RETRIES = 3
# Fraction of max_tokens handed to extended thinking when auto-budgeting.
THINKING_BUDGET_RATIO = 0.8

# Billing header required by Anthropic API for OAuth token access.
# Must be the first system prompt block. Format mirrors Claude Code CLI.
# Override with ANTHROPIC_BILLING_HEADER env var if the hardcoded version drifts.
_DEFAULT_BILLING_HEADER = "x-anthropic-billing-header: cc_version=2.1.85.351; cc_entrypoint=cli; cch=6c6d5;"
OAUTH_BILLING_HEADER = os.environ.get("ANTHROPIC_BILLING_HEADER", _DEFAULT_BILLING_HEADER)
|
||||
|
||||
|
||||
class ClaudeChatModel(ChatAnthropic):
    """ChatAnthropic with OAuth Bearer auth, prompt caching, and smart thinking.

    Config example:
    - name: claude-sonnet-4.6
      use: deerflow.models.claude_provider:ClaudeChatModel
      model: claude-sonnet-4-6
      max_tokens: 16384
      enable_prompt_caching: true
    """

    # Custom fields
    # Inject ephemeral cache_control blocks into system/messages/tools.
    # Forced to False in model_post_init when an OAuth token is detected.
    enable_prompt_caching: bool = True
    # How many of the most recent messages receive cache_control markers.
    prompt_cache_size: int = 3
    # Auto-fill thinking.budget_tokens (THINKING_BUDGET_RATIO * max_tokens)
    # when thinking is enabled but no budget was configured.
    auto_thinking_budget: bool = True
    # Max attempts for retryable errors; validated >= 1 before use.
    retry_max_attempts: int = MAX_RETRIES
    # Private runtime state (not config): set during model_post_init.
    _is_oauth: bool = PrivateAttr(default=False)
    _oauth_access_token: str = PrivateAttr(default="")

    model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
def _validate_retry_config(self) -> None:
    """Reject a non-positive retry count before any request is attempted."""
    attempts = self.retry_max_attempts
    if attempts >= 1:
        return
    raise ValueError("retry_max_attempts must be >= 1")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
    """Auto-load credentials and configure OAuth if needed.

    Runs after pydantic field initialization. Key resolution order:
    the configured ``anthropic_api_key`` field first, then the explicit
    Claude Code credential handoff sources via ``load_claude_code_credential``.
    If the resolved key is an OAuth token (``is_oauth_token``), the model
    switches to Bearer auth: beta headers are added, prompt caching is
    disabled, and both SDK clients are patched after ``super()`` creates them.
    """
    from pydantic import SecretStr

    from deerflow.models.credential_loader import (
        OAUTH_ANTHROPIC_BETAS,
        is_oauth_token,
        load_claude_code_credential,
    )

    # Fail fast on invalid retry configuration before any network setup.
    self._validate_retry_config()

    # Extract actual key value (SecretStr.str() returns '**********')
    current_key = ""
    if self.anthropic_api_key:
        if hasattr(self.anthropic_api_key, "get_secret_value"):
            current_key = self.anthropic_api_key.get_secret_value()
        else:
            current_key = str(self.anthropic_api_key)

    # Try the explicit Claude Code OAuth handoff sources if no valid key.
    # "your-anthropic-api-key" is treated as a placeholder, not a real key.
    if not current_key or current_key in ("your-anthropic-api-key",):
        cred = load_claude_code_credential()
        if cred:
            current_key = cred.access_token
            logger.info(f"Using Claude Code CLI credential (source: {cred.source})")
        else:
            logger.warning("No Anthropic API key or explicit Claude Code OAuth credential found.")

    # Detect OAuth token and configure Bearer auth
    if is_oauth_token(current_key):
        self._is_oauth = True
        self._oauth_access_token = current_key
        # Set the token as api_key temporarily (will be swapped to auth_token on client)
        self.anthropic_api_key = SecretStr(current_key)
        # Add required beta headers for OAuth
        self.default_headers = {
            **(self.default_headers or {}),
            "anthropic-beta": OAUTH_ANTHROPIC_BETAS,
        }
        # OAuth tokens have a limit of 4 cache_control blocks — disable prompt caching
        self.enable_prompt_caching = False
        logger.info("OAuth token detected — will use Authorization: Bearer header")
    else:
        if current_key:
            self.anthropic_api_key = SecretStr(current_key)

    # Ensure api_key is SecretStr
    if isinstance(self.anthropic_api_key, str):
        self.anthropic_api_key = SecretStr(self.anthropic_api_key)

    super().model_post_init(__context)

    # Patch clients immediately after creation for OAuth Bearer auth.
    # This must happen after super() because clients are lazily created.
    if self._is_oauth:
        self._patch_client_oauth(self._client)
        self._patch_client_oauth(self._async_client)
|
||||
|
||||
def _patch_client_oauth(self, client: Any) -> None:
    """Swap api_key → auth_token on an Anthropic SDK client for OAuth Bearer auth."""
    supports_bearer = hasattr(client, "api_key") and hasattr(client, "auth_token")
    if not supports_bearer:
        return
    # Clearing api_key makes the SDK send Authorization: Bearer instead of x-api-key.
    client.api_key = None
    client.auth_token = self._oauth_access_token
|
||||
|
||||
def _get_request_payload(
    self,
    input_: Any,
    *,
    stop: list[str] | None = None,
    **kwargs: Any,
) -> dict:
    """Override to inject prompt caching, thinking budget, and OAuth billing.

    Builds the base payload via the parent class, then applies each enabled
    transform in order; every transform mutates the payload in place.
    """
    payload = super()._get_request_payload(input_, stop=stop, **kwargs)

    transforms = (
        (self._is_oauth, self._apply_oauth_billing),
        (self.enable_prompt_caching, self._apply_prompt_caching),
        (self.auto_thinking_budget, self._apply_thinking_budget),
    )
    for enabled, apply in transforms:
        if enabled:
            apply(payload)

    return payload
|
||||
|
||||
def _apply_oauth_billing(self, payload: dict) -> None:
    """Inject the billing header block required for all OAuth requests.

    The billing block is always placed first in the system list, removing any
    existing occurrence to avoid duplication or out-of-order positioning.
    """
    billing_block = {"type": "text", "text": OAUTH_BILLING_HEADER}

    system = payload.get("system")
    if isinstance(system, list):
        # Drop prior billing blocks, then prepend exactly one.
        kept = []
        for block in system:
            if isinstance(block, dict) and OAUTH_BILLING_HEADER in block.get("text", ""):
                continue
            kept.append(block)
        payload["system"] = [billing_block, *kept]
    elif isinstance(system, str):
        # NOTE(review): when the header substring appears inside a longer
        # system string, only the billing block is kept — confirm intended.
        blocks = [billing_block]
        if OAUTH_BILLING_HEADER not in system:
            blocks.append({"type": "text", "text": system})
        payload["system"] = blocks
    else:
        payload["system"] = [billing_block]

    # Add metadata.user_id required by the API for OAuth billing validation.
    metadata = payload.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}
        payload["metadata"] = metadata
    if "user_id" not in metadata:
        # Stable device_id derived from the machine's hostname; a fresh
        # session_id is generated per request that lacked one.
        digest = hashlib.sha256(f"deerflow-{socket.gethostname()}".encode()).hexdigest()
        metadata["user_id"] = json.dumps(
            {
                "device_id": digest,
                "account_uuid": "deerflow",
                "session_id": str(uuid.uuid4()),
            }
        )
|
||||
|
||||
def _apply_prompt_caching(self, payload: dict) -> None:
    """Apply ephemeral cache_control to system and recent messages.

    Marks system text blocks (wrapping a bare string), the last
    ``prompt_cache_size`` messages, and the final tool definition.
    """
    # System prompt: mark every text block, or wrap a non-empty string.
    system = payload.get("system")
    if isinstance(system, list) and system:
        for block in system:
            if isinstance(block, dict) and block.get("type") == "text":
                block["cache_control"] = {"type": "ephemeral"}
    elif isinstance(system, str) and system:
        payload["system"] = [
            {
                "type": "text",
                "text": system,
                "cache_control": {"type": "ephemeral"},
            }
        ]

    # Recent messages: only the trailing prompt_cache_size entries are marked.
    messages = payload.get("messages", [])
    recent = messages[max(0, len(messages) - self.prompt_cache_size):]
    for msg in recent:
        if not isinstance(msg, dict):
            continue
        content = msg.get("content")
        if isinstance(content, list):
            for block in content:
                if isinstance(block, dict):
                    block["cache_control"] = {"type": "ephemeral"}
        elif isinstance(content, str) and content:
            msg["content"] = [
                {
                    "type": "text",
                    "text": content,
                    "cache_control": {"type": "ephemeral"},
                }
            ]

    # Tools: cache only the last tool definition.
    tools = payload.get("tools", [])
    if tools and isinstance(tools[-1], dict):
        tools[-1]["cache_control"] = {"type": "ephemeral"}
|
||||
|
||||
def _apply_thinking_budget(self, payload: dict) -> None:
    """Auto-allocate thinking budget (80% of max_tokens).

    Only fills budget_tokens when thinking is enabled and no explicit
    budget was configured; an existing (truthy) budget is never overridden.
    """
    thinking = payload.get("thinking")
    eligible = (
        isinstance(thinking, dict)
        and thinking.get("type") == "enabled"
        and not thinking.get("budget_tokens")
    )
    if eligible:
        max_tokens = payload.get("max_tokens", 8192)
        thinking["budget_tokens"] = int(max_tokens * THINKING_BUDGET_RATIO)
|
||||
|
||||
@staticmethod
def _strip_cache_control(payload: dict) -> None:
    """Remove cache_control markers before OAuth requests reach Anthropic."""

    def _scrub(candidate: Any) -> None:
        # Silently skip anything that isn't a dict block.
        if isinstance(candidate, dict):
            candidate.pop("cache_control", None)

    for section in ("system", "messages"):
        items = payload.get(section)
        if not isinstance(items, list):
            continue
        for item in items:
            if not isinstance(item, dict):
                continue
            item.pop("cache_control", None)
            # Message content may itself be a list of blocks.
            content = item.get("content")
            if isinstance(content, list):
                for block in content:
                    _scrub(block)

    tools = payload.get("tools")
    if isinstance(tools, list):
        for tool in tools:
            _scrub(tool)
|
||||
|
||||
def _create(self, payload: dict) -> Any:
    """Synchronous create; OAuth payloads are scrubbed of cache markers first."""
    needs_scrub = self._is_oauth
    if needs_scrub:
        self._strip_cache_control(payload)
    return super()._create(payload)
|
||||
|
||||
async def _acreate(self, payload: dict) -> Any:
    """Asynchronous create; OAuth payloads are scrubbed of cache markers first."""
    needs_scrub = self._is_oauth
    if needs_scrub:
        self._strip_cache_control(payload)
    return await super()._acreate(payload)
|
||||
|
||||
def _generate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
    """Override with OAuth patching and retry logic.

    Retries ``anthropic.RateLimitError`` and ``anthropic.InternalServerError``
    with exponential backoff (honoring Retry-After via ``_calc_backoff_ms``);
    the final attempt re-raises. Other exceptions propagate immediately.

    Args:
        messages: Conversation history passed through to the parent class.
        stop: Optional stop sequences.
        **kwargs: Forwarded to ``ChatAnthropic._generate``.
    """
    if self._is_oauth:
        # Re-apply Bearer auth in case the base class re-created the client.
        self._patch_client_oauth(self._client)

    last_error = None
    for attempt in range(1, self.retry_max_attempts + 1):
        try:
            return super()._generate(messages, stop=stop, **kwargs)
        # Single handler for both retryable error classes (was duplicated).
        except (anthropic.RateLimitError, anthropic.InternalServerError) as e:
            last_error = e
            if attempt >= self.retry_max_attempts:
                raise
            label = "Rate limited" if isinstance(e, anthropic.RateLimitError) else "Server error"
            wait_ms = self._calc_backoff_ms(attempt, e)
            logger.warning(f"{label}, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
            time.sleep(wait_ms / 1000)
    # Unreachable when retry_max_attempts >= 1 (validated); kept as a guard.
    raise last_error
|
||||
|
||||
async def _agenerate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
    """Async override with OAuth patching and retry logic.

    Retries ``anthropic.RateLimitError`` and ``anthropic.InternalServerError``
    with exponential backoff (honoring Retry-After via ``_calc_backoff_ms``);
    the final attempt re-raises. Other exceptions propagate immediately.

    Args:
        messages: Conversation history passed through to the parent class.
        stop: Optional stop sequences.
        **kwargs: Forwarded to ``ChatAnthropic._agenerate``.
    """
    import asyncio

    if self._is_oauth:
        # Re-apply Bearer auth in case the base class re-created the client.
        self._patch_client_oauth(self._async_client)

    last_error = None
    for attempt in range(1, self.retry_max_attempts + 1):
        try:
            return await super()._agenerate(messages, stop=stop, **kwargs)
        # Single handler for both retryable error classes (was duplicated).
        except (anthropic.RateLimitError, anthropic.InternalServerError) as e:
            last_error = e
            if attempt >= self.retry_max_attempts:
                raise
            label = "Rate limited" if isinstance(e, anthropic.RateLimitError) else "Server error"
            wait_ms = self._calc_backoff_ms(attempt, e)
            logger.warning(f"{label}, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
            await asyncio.sleep(wait_ms / 1000)
    # Unreachable when retry_max_attempts >= 1 (validated); kept as a guard.
    raise last_error
|
||||
|
||||
@staticmethod
def _calc_backoff_ms(attempt: int, error: Exception) -> int:
    """Exponential backoff with a fixed 20% buffer.

    Base delay doubles per attempt (2000ms, 4000ms, ...), plus 20%.
    A server-supplied Retry-After header (integer seconds) overrides the
    computed delay when present and parseable.
    """
    base_ms = 2000 * (1 << (attempt - 1))
    buffer_ms = int(base_ms * 0.2)  # fixed buffer, not random jitter
    total_ms = base_ms + buffer_ms

    response = getattr(error, "response", None)
    if response is not None:
        retry_after = response.headers.get("Retry-After")
        if retry_after:
            try:
                total_ms = int(retry_after) * 1000
            except (ValueError, TypeError):
                # Non-integer Retry-After (e.g. HTTP-date): keep computed delay.
                pass

    return total_ms
|
||||
Reference in New Issue
Block a user