Vendored deer-flow upstream (bytedance/deer-flow) plus prompt-injection hardening: - New deerflow.security package: content_delimiter, html_cleaner, sanitizer (8 layers — invisible chars, control chars, symbols, NFC, PUA, tag chars, horizontal whitespace collapse with newline/tab preservation, length cap) - New deerflow.community.searx package: web_search, web_fetch, image_search backed by a private SearX instance, every external string sanitized and wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>> delimiters - All native community web providers (ddg_search, tavily, exa, firecrawl, jina_ai, infoquest, image_search) replaced with hard-fail stubs that raise NativeWebToolDisabledError at import time, so a misconfigured tool.use path fails loud rather than silently falling back to unsanitized output - Native client back-doors (jina_client.py, infoquest_client.py) stubbed too - Native-tool tests quarantined under tests/_disabled_native/ (collect_ignore_glob via local conftest.py) - Sanitizer Layer 7 fix: only collapse horizontal whitespace, preserve newlines and tabs so list/table structure survives - Hardened runtime config.yaml references only the searx-backed tools - Factory overlay (backend/) kept in sync with deer-flow tree as a reference / source See HARDENING.md for the full audit trail and verification steps.
349 lines · 14 KiB · Python
"""Custom Claude provider with OAuth Bearer auth, prompt caching, and smart thinking.
|
|
|
|
Supports two authentication modes:
|
|
1. Standard API key (x-api-key header) — default ChatAnthropic behavior
|
|
2. Claude Code OAuth token (Authorization: Bearer header)
|
|
- Detected by sk-ant-oat prefix
|
|
- Requires anthropic-beta: oauth-2025-04-20,claude-code-20250219
|
|
- Requires billing header in system prompt for all OAuth requests
|
|
|
|
Auto-loads credentials from explicit runtime handoff:
|
|
- $ANTHROPIC_API_KEY environment variable
|
|
- $CLAUDE_CODE_OAUTH_TOKEN or $ANTHROPIC_AUTH_TOKEN
|
|
- $CLAUDE_CODE_OAUTH_TOKEN_FILE_DESCRIPTOR
|
|
- $CLAUDE_CODE_CREDENTIALS_PATH
|
|
- ~/.claude/.credentials.json
|
|
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
import socket
|
|
import time
|
|
import uuid
|
|
from typing import Any
|
|
|
|
import anthropic
|
|
from langchain_anthropic import ChatAnthropic
|
|
from langchain_core.messages import BaseMessage
|
|
from pydantic import PrivateAttr
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Default upper bound on attempts for rate-limit / server-error retries.
MAX_RETRIES = 3
# Fraction of max_tokens allocated to extended thinking when auto-budgeting
# (see ClaudeChatModel._apply_thinking_budget).
THINKING_BUDGET_RATIO = 0.8

# Billing header required by Anthropic API for OAuth token access.
# Must be the first system prompt block. Format mirrors Claude Code CLI.
# Override with ANTHROPIC_BILLING_HEADER env var if the hardcoded version drifts.
_DEFAULT_BILLING_HEADER = "x-anthropic-billing-header: cc_version=2.1.85.351; cc_entrypoint=cli; cch=6c6d5;"
OAUTH_BILLING_HEADER = os.environ.get("ANTHROPIC_BILLING_HEADER", _DEFAULT_BILLING_HEADER)
|
|
|
|
|
|
class ClaudeChatModel(ChatAnthropic):
    """ChatAnthropic with OAuth Bearer auth, prompt caching, and smart thinking.

    Config example:
    - name: claude-sonnet-4.6
      use: deerflow.models.claude_provider:ClaudeChatModel
      model: claude-sonnet-4-6
      max_tokens: 16384
      enable_prompt_caching: true
    """

    # Custom fields
    # Inject ephemeral cache_control markers into outgoing payloads.
    # Forced off for OAuth tokens in model_post_init (4-block API limit).
    enable_prompt_caching: bool = True
    # How many of the most recent messages receive a cache_control marker.
    prompt_cache_size: int = 3
    # Auto-allocate THINKING_BUDGET_RATIO of max_tokens as thinking budget
    # when thinking is enabled without an explicit budget_tokens.
    auto_thinking_budget: bool = True
    # Upper bound on attempts for retryable API errors; must be >= 1.
    retry_max_attempts: int = MAX_RETRIES
    # Set during model_post_init when the resolved credential is an OAuth token.
    _is_oauth: bool = PrivateAttr(default=False)
    # Raw OAuth access token; swapped onto SDK clients as auth_token (Bearer).
    _oauth_access_token: str = PrivateAttr(default="")

    model_config = {"arbitrary_types_allowed": True}
|
|
|
|
def _validate_retry_config(self) -> None:
|
|
if self.retry_max_attempts < 1:
|
|
raise ValueError("retry_max_attempts must be >= 1")
|
|
|
|
    def model_post_init(self, __context: Any) -> None:
        """Auto-load credentials and configure OAuth if needed.

        Order is load-bearing: resolve a usable key, detect whether it is an
        OAuth token (and adjust headers/caching), normalize to SecretStr, call
        super() so the SDK clients are built, then patch those clients for
        Bearer auth.
        """
        # Imported locally to avoid import-time cycles with the models package.
        from pydantic import SecretStr

        from deerflow.models.credential_loader import (
            OAUTH_ANTHROPIC_BETAS,
            is_oauth_token,
            load_claude_code_credential,
        )

        self._validate_retry_config()

        # Extract actual key value (SecretStr.str() returns '**********')
        current_key = ""
        if self.anthropic_api_key:
            if hasattr(self.anthropic_api_key, "get_secret_value"):
                current_key = self.anthropic_api_key.get_secret_value()
            else:
                current_key = str(self.anthropic_api_key)

        # Try the explicit Claude Code OAuth handoff sources if no valid key.
        # "your-anthropic-api-key" is a config-template placeholder, not a key.
        if not current_key or current_key in ("your-anthropic-api-key",):
            cred = load_claude_code_credential()
            if cred:
                current_key = cred.access_token
                logger.info(f"Using Claude Code CLI credential (source: {cred.source})")
            else:
                logger.warning("No Anthropic API key or explicit Claude Code OAuth credential found.")

        # Detect OAuth token and configure Bearer auth
        if is_oauth_token(current_key):
            self._is_oauth = True
            self._oauth_access_token = current_key
            # Set the token as api_key temporarily (will be swapped to auth_token on client)
            self.anthropic_api_key = SecretStr(current_key)
            # Add required beta headers for OAuth
            self.default_headers = {
                **(self.default_headers or {}),
                "anthropic-beta": OAUTH_ANTHROPIC_BETAS,
            }
            # OAuth tokens have a limit of 4 cache_control blocks — disable prompt caching
            self.enable_prompt_caching = False
            logger.info("OAuth token detected — will use Authorization: Bearer header")
        else:
            if current_key:
                self.anthropic_api_key = SecretStr(current_key)

        # Ensure api_key is SecretStr
        if isinstance(self.anthropic_api_key, str):
            self.anthropic_api_key = SecretStr(self.anthropic_api_key)

        super().model_post_init(__context)

        # Patch clients immediately after creation for OAuth Bearer auth.
        # This must happen after super() because clients are lazily created.
        if self._is_oauth:
            self._patch_client_oauth(self._client)
            self._patch_client_oauth(self._async_client)
|
|
|
|
def _patch_client_oauth(self, client: Any) -> None:
|
|
"""Swap api_key → auth_token on an Anthropic SDK client for OAuth Bearer auth."""
|
|
if hasattr(client, "api_key") and hasattr(client, "auth_token"):
|
|
client.api_key = None
|
|
client.auth_token = self._oauth_access_token
|
|
|
|
def _get_request_payload(
|
|
self,
|
|
input_: Any,
|
|
*,
|
|
stop: list[str] | None = None,
|
|
**kwargs: Any,
|
|
) -> dict:
|
|
"""Override to inject prompt caching, thinking budget, and OAuth billing."""
|
|
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
|
|
|
if self._is_oauth:
|
|
self._apply_oauth_billing(payload)
|
|
|
|
if self.enable_prompt_caching:
|
|
self._apply_prompt_caching(payload)
|
|
|
|
if self.auto_thinking_budget:
|
|
self._apply_thinking_budget(payload)
|
|
|
|
return payload
|
|
|
|
def _apply_oauth_billing(self, payload: dict) -> None:
|
|
"""Inject the billing header block required for all OAuth requests.
|
|
|
|
The billing block is always placed first in the system list, removing any
|
|
existing occurrence to avoid duplication or out-of-order positioning.
|
|
"""
|
|
billing_block = {"type": "text", "text": OAUTH_BILLING_HEADER}
|
|
|
|
system = payload.get("system")
|
|
if isinstance(system, list):
|
|
# Remove any existing billing blocks, then insert a single one at index 0.
|
|
filtered = [b for b in system if not (isinstance(b, dict) and OAUTH_BILLING_HEADER in b.get("text", ""))]
|
|
payload["system"] = [billing_block] + filtered
|
|
elif isinstance(system, str):
|
|
if OAUTH_BILLING_HEADER in system:
|
|
payload["system"] = [billing_block]
|
|
else:
|
|
payload["system"] = [billing_block, {"type": "text", "text": system}]
|
|
else:
|
|
payload["system"] = [billing_block]
|
|
|
|
# Add metadata.user_id required by the API for OAuth billing validation
|
|
if not isinstance(payload.get("metadata"), dict):
|
|
payload["metadata"] = {}
|
|
if "user_id" not in payload["metadata"]:
|
|
# Generate a stable device_id from the machine's hostname
|
|
hostname = socket.gethostname()
|
|
device_id = hashlib.sha256(f"deerflow-{hostname}".encode()).hexdigest()
|
|
session_id = str(uuid.uuid4())
|
|
payload["metadata"]["user_id"] = json.dumps(
|
|
{
|
|
"device_id": device_id,
|
|
"account_uuid": "deerflow",
|
|
"session_id": session_id,
|
|
}
|
|
)
|
|
|
|
def _apply_prompt_caching(self, payload: dict) -> None:
|
|
"""Apply ephemeral cache_control to system and recent messages."""
|
|
# Cache system messages
|
|
system = payload.get("system")
|
|
if system and isinstance(system, list):
|
|
for block in system:
|
|
if isinstance(block, dict) and block.get("type") == "text":
|
|
block["cache_control"] = {"type": "ephemeral"}
|
|
elif system and isinstance(system, str):
|
|
payload["system"] = [
|
|
{
|
|
"type": "text",
|
|
"text": system,
|
|
"cache_control": {"type": "ephemeral"},
|
|
}
|
|
]
|
|
|
|
# Cache recent messages
|
|
messages = payload.get("messages", [])
|
|
cache_start = max(0, len(messages) - self.prompt_cache_size)
|
|
for i in range(cache_start, len(messages)):
|
|
msg = messages[i]
|
|
if not isinstance(msg, dict):
|
|
continue
|
|
content = msg.get("content")
|
|
if isinstance(content, list):
|
|
for block in content:
|
|
if isinstance(block, dict):
|
|
block["cache_control"] = {"type": "ephemeral"}
|
|
elif isinstance(content, str) and content:
|
|
msg["content"] = [
|
|
{
|
|
"type": "text",
|
|
"text": content,
|
|
"cache_control": {"type": "ephemeral"},
|
|
}
|
|
]
|
|
|
|
# Cache the last tool definition
|
|
tools = payload.get("tools", [])
|
|
if tools and isinstance(tools[-1], dict):
|
|
tools[-1]["cache_control"] = {"type": "ephemeral"}
|
|
|
|
def _apply_thinking_budget(self, payload: dict) -> None:
|
|
"""Auto-allocate thinking budget (80% of max_tokens)."""
|
|
thinking = payload.get("thinking")
|
|
if not thinking or not isinstance(thinking, dict):
|
|
return
|
|
if thinking.get("type") != "enabled":
|
|
return
|
|
if thinking.get("budget_tokens"):
|
|
return
|
|
|
|
max_tokens = payload.get("max_tokens", 8192)
|
|
thinking["budget_tokens"] = int(max_tokens * THINKING_BUDGET_RATIO)
|
|
|
|
@staticmethod
|
|
def _strip_cache_control(payload: dict) -> None:
|
|
"""Remove cache_control markers before OAuth requests reach Anthropic."""
|
|
for section in ("system", "messages"):
|
|
items = payload.get(section)
|
|
if not isinstance(items, list):
|
|
continue
|
|
for item in items:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
item.pop("cache_control", None)
|
|
content = item.get("content")
|
|
if isinstance(content, list):
|
|
for block in content:
|
|
if isinstance(block, dict):
|
|
block.pop("cache_control", None)
|
|
|
|
tools = payload.get("tools")
|
|
if isinstance(tools, list):
|
|
for tool in tools:
|
|
if isinstance(tool, dict):
|
|
tool.pop("cache_control", None)
|
|
|
|
def _create(self, payload: dict) -> Any:
|
|
if self._is_oauth:
|
|
self._strip_cache_control(payload)
|
|
return super()._create(payload)
|
|
|
|
async def _acreate(self, payload: dict) -> Any:
|
|
if self._is_oauth:
|
|
self._strip_cache_control(payload)
|
|
return await super()._acreate(payload)
|
|
|
|
def _generate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
|
|
"""Override with OAuth patching and retry logic."""
|
|
if self._is_oauth:
|
|
self._patch_client_oauth(self._client)
|
|
|
|
last_error = None
|
|
for attempt in range(1, self.retry_max_attempts + 1):
|
|
try:
|
|
return super()._generate(messages, stop=stop, **kwargs)
|
|
except anthropic.RateLimitError as e:
|
|
last_error = e
|
|
if attempt >= self.retry_max_attempts:
|
|
raise
|
|
wait_ms = self._calc_backoff_ms(attempt, e)
|
|
logger.warning(f"Rate limited, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
|
|
time.sleep(wait_ms / 1000)
|
|
except anthropic.InternalServerError as e:
|
|
last_error = e
|
|
if attempt >= self.retry_max_attempts:
|
|
raise
|
|
wait_ms = self._calc_backoff_ms(attempt, e)
|
|
logger.warning(f"Server error, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
|
|
time.sleep(wait_ms / 1000)
|
|
raise last_error
|
|
|
|
async def _agenerate(self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any) -> Any:
|
|
"""Async override with OAuth patching and retry logic."""
|
|
import asyncio
|
|
|
|
if self._is_oauth:
|
|
self._patch_client_oauth(self._async_client)
|
|
|
|
last_error = None
|
|
for attempt in range(1, self.retry_max_attempts + 1):
|
|
try:
|
|
return await super()._agenerate(messages, stop=stop, **kwargs)
|
|
except anthropic.RateLimitError as e:
|
|
last_error = e
|
|
if attempt >= self.retry_max_attempts:
|
|
raise
|
|
wait_ms = self._calc_backoff_ms(attempt, e)
|
|
logger.warning(f"Rate limited, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
|
|
await asyncio.sleep(wait_ms / 1000)
|
|
except anthropic.InternalServerError as e:
|
|
last_error = e
|
|
if attempt >= self.retry_max_attempts:
|
|
raise
|
|
wait_ms = self._calc_backoff_ms(attempt, e)
|
|
logger.warning(f"Server error, retrying attempt {attempt}/{self.retry_max_attempts} after {wait_ms}ms")
|
|
await asyncio.sleep(wait_ms / 1000)
|
|
raise last_error
|
|
|
|
@staticmethod
|
|
def _calc_backoff_ms(attempt: int, error: Exception) -> int:
|
|
"""Exponential backoff with a fixed 20% buffer."""
|
|
backoff_ms = 2000 * (1 << (attempt - 1))
|
|
jitter_ms = int(backoff_ms * 0.2)
|
|
total_ms = backoff_ms + jitter_ms
|
|
|
|
if hasattr(error, "response") and error.response is not None:
|
|
retry_after = error.response.headers.get("Retry-After")
|
|
if retry_after:
|
|
try:
|
|
total_ms = int(retry_after) * 1000
|
|
except (ValueError, TypeError):
|
|
pass
|
|
|
|
return total_ms
|