Initial commit: hardened DeerFlow factory

Vendored deer-flow upstream (bytedance/deer-flow) plus prompt-injection
hardening:

- New deerflow.security package: content_delimiter, html_cleaner,
  sanitizer (8 layers — invisible chars, control chars, symbols, NFC,
  PUA, tag chars, horizontal whitespace collapse with newline/tab
  preservation, length cap)
- New deerflow.community.searx package: web_search, web_fetch,
  image_search backed by a private SearX instance, every external
  string sanitized and wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>>
  delimiters
- All native community web providers (ddg_search, tavily, exa,
  firecrawl, jina_ai, infoquest, image_search) replaced with hard-fail
  stubs that raise NativeWebToolDisabledError at import time, so a
  misconfigured tool.use path fails loud rather than silently falling
  back to unsanitized output
- Native client back-doors (jina_client.py, infoquest_client.py)
  stubbed too
- Native-tool tests quarantined under tests/_disabled_native/
  (collect_ignore_glob via local conftest.py)
- Sanitizer Layer 7 fix: only collapse horizontal whitespace, preserve
  newlines and tabs so list/table structure survives
- Hardened runtime config.yaml references only the searx-backed tools
- Factory overlay (backend/) kept in sync with deer-flow tree as a
  reference / source

See HARDENING.md for the full audit trail and verification steps.
This commit is contained in:
2026-04-12 14:23:57 +02:00
commit 6de0bf9f5b
889 changed files with 173052 additions and 0 deletions

View File

View File

@@ -0,0 +1,16 @@
"""IM Channel integration for DeerFlow.
Provides a pluggable channel system that connects external messaging platforms
(Feishu/Lark, Slack, Telegram) to the DeerFlow agent via the ChannelManager,
which uses ``langgraph-sdk`` to communicate with the underlying LangGraph Server.
"""
from app.channels.base import Channel
from app.channels.message_bus import InboundMessage, MessageBus, OutboundMessage
__all__ = [
"Channel",
"InboundMessage",
"MessageBus",
"OutboundMessage",
]

View File

@@ -0,0 +1,126 @@
"""Abstract base class for IM channels."""
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import Any
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
class Channel(ABC):
"""Base class for all IM channel implementations.
Each channel connects to an external messaging platform and:
1. Receives messages, wraps them as InboundMessage, publishes to the bus.
2. Subscribes to outbound messages and sends replies back to the platform.
Subclasses must implement ``start``, ``stop``, and ``send``.
"""
def __init__(self, name: str, bus: MessageBus, config: dict[str, Any]) -> None:
self.name = name
self.bus = bus
self.config = config
self._running = False
@property
def is_running(self) -> bool:
return self._running
# -- lifecycle ---------------------------------------------------------
@abstractmethod
async def start(self) -> None:
"""Start listening for messages from the external platform."""
@abstractmethod
async def stop(self) -> None:
"""Gracefully stop the channel."""
# -- outbound ----------------------------------------------------------
@abstractmethod
async def send(self, msg: OutboundMessage) -> None:
"""Send a message back to the external platform.
The implementation should use ``msg.chat_id`` and ``msg.thread_ts``
to route the reply to the correct conversation/thread.
"""
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
"""Upload a single file attachment to the platform.
Returns True if the upload succeeded, False otherwise.
Default implementation returns False (no file upload support).
"""
return False
# -- helpers -----------------------------------------------------------
def _make_inbound(
self,
chat_id: str,
user_id: str,
text: str,
*,
msg_type: InboundMessageType = InboundMessageType.CHAT,
thread_ts: str | None = None,
files: list[dict[str, Any]] | None = None,
metadata: dict[str, Any] | None = None,
) -> InboundMessage:
"""Convenience factory for creating InboundMessage instances."""
return InboundMessage(
channel_name=self.name,
chat_id=chat_id,
user_id=user_id,
text=text,
msg_type=msg_type,
thread_ts=thread_ts,
files=files or [],
metadata=metadata or {},
)
async def _on_outbound(self, msg: OutboundMessage) -> None:
"""Outbound callback registered with the bus.
Only forwards messages targeted at this channel.
Sends the text message first, then uploads any file attachments.
File uploads are skipped entirely when the text send fails to avoid
partial deliveries (files without accompanying text).
"""
if msg.channel_name == self.name:
try:
await self.send(msg)
except Exception:
logger.exception("Failed to send outbound message on channel %s", self.name)
return # Do not attempt file uploads when the text message failed
for attachment in msg.attachments:
try:
success = await self.send_file(msg, attachment)
if not success:
logger.warning("[%s] file upload skipped for %s", self.name, attachment.filename)
except Exception:
logger.exception("[%s] failed to upload file %s", self.name, attachment.filename)
async def receive_file(self, msg: InboundMessage, thread_id: str) -> InboundMessage:
"""
Optionally process and materialize inbound file attachments for this channel.
By default, this method does nothing and simply returns the original message.
Subclasses (e.g. FeishuChannel) may override this to download files (images, documents, etc)
referenced in msg.files, save them to the sandbox, and update msg.text to include
the sandbox file paths for downstream model consumption.
Args:
msg: The inbound message, possibly containing file metadata in msg.files.
thread_id: The resolved DeerFlow thread ID for sandbox path context.
Returns:
The (possibly modified) InboundMessage, with text and/or files updated as needed.
"""
return msg

View File

@@ -0,0 +1,20 @@
"""Shared command definitions used by all channel implementations.
Keeping the authoritative command set in one place ensures that channel
parsers (e.g. Feishu) and the ChannelManager dispatcher stay in sync
automatically — adding or removing a command here is the single edit
required.
"""
from __future__ import annotations
KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
{
"/bootstrap",
"/new",
"/status",
"/models",
"/memory",
"/help",
}
)

View File

@@ -0,0 +1,273 @@
"""Discord channel integration using discord.py."""
from __future__ import annotations
import asyncio
import logging
import threading
from typing import Any
from app.channels.base import Channel
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
_DISCORD_MAX_MESSAGE_LEN = 2000
class DiscordChannel(Channel):
"""Discord bot channel.
Configuration keys (in ``config.yaml`` under ``channels.discord``):
- ``bot_token``: Discord Bot token.
- ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all.
"""
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
super().__init__(name="discord", bus=bus, config=config)
self._bot_token = str(config.get("bot_token", "")).strip()
self._allowed_guilds: set[int] = set()
for guild_id in config.get("allowed_guilds", []):
try:
self._allowed_guilds.add(int(guild_id))
except (TypeError, ValueError):
continue
self._client = None
self._thread: threading.Thread | None = None
self._discord_loop: asyncio.AbstractEventLoop | None = None
self._main_loop: asyncio.AbstractEventLoop | None = None
self._discord_module = None
async def start(self) -> None:
if self._running:
return
try:
import discord
except ImportError:
logger.error("discord.py is not installed. Install it with: uv add discord.py")
return
if not self._bot_token:
logger.error("Discord channel requires bot_token")
return
intents = discord.Intents.default()
intents.messages = True
intents.guilds = True
intents.message_content = True
client = discord.Client(
intents=intents,
allowed_mentions=discord.AllowedMentions.none(),
)
self._client = client
self._discord_module = discord
self._main_loop = asyncio.get_event_loop()
@client.event
async def on_message(message) -> None:
await self._on_message(message)
self._running = True
self.bus.subscribe_outbound(self._on_outbound)
self._thread = threading.Thread(target=self._run_client, daemon=True)
self._thread.start()
logger.info("Discord channel started")
async def stop(self) -> None:
self._running = False
self.bus.unsubscribe_outbound(self._on_outbound)
if self._client and self._discord_loop and self._discord_loop.is_running():
close_future = asyncio.run_coroutine_threadsafe(self._client.close(), self._discord_loop)
try:
await asyncio.wait_for(asyncio.wrap_future(close_future), timeout=10)
except TimeoutError:
logger.warning("[Discord] client close timed out after 10s")
except Exception:
logger.exception("[Discord] error while closing client")
if self._thread:
self._thread.join(timeout=10)
self._thread = None
self._client = None
self._discord_loop = None
self._discord_module = None
logger.info("Discord channel stopped")
async def send(self, msg: OutboundMessage) -> None:
target = await self._resolve_target(msg)
if target is None:
logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
return
text = msg.text or ""
for chunk in self._split_text(text):
send_future = asyncio.run_coroutine_threadsafe(target.send(chunk), self._discord_loop)
await asyncio.wrap_future(send_future)
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
target = await self._resolve_target(msg)
if target is None:
logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
return False
if self._discord_module is None:
return False
try:
fp = open(str(attachment.actual_path), "rb") # noqa: SIM115
file = self._discord_module.File(fp, filename=attachment.filename)
send_future = asyncio.run_coroutine_threadsafe(target.send(file=file), self._discord_loop)
await asyncio.wrap_future(send_future)
logger.info("[Discord] file uploaded: %s", attachment.filename)
return True
except Exception:
logger.exception("[Discord] failed to upload file: %s", attachment.filename)
return False
async def _on_message(self, message) -> None:
if not self._running or not self._client:
return
if message.author.bot:
return
if self._client.user and message.author.id == self._client.user.id:
return
guild = message.guild
if self._allowed_guilds:
if guild is None or guild.id not in self._allowed_guilds:
return
text = (message.content or "").strip()
if not text:
return
if self._discord_module is None:
return
if isinstance(message.channel, self._discord_module.Thread):
chat_id = str(message.channel.parent_id or message.channel.id)
thread_id = str(message.channel.id)
else:
thread = await self._create_thread(message)
if thread is None:
return
chat_id = str(message.channel.id)
thread_id = str(thread.id)
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=chat_id,
user_id=str(message.author.id),
text=text,
msg_type=msg_type,
thread_ts=thread_id,
metadata={
"guild_id": str(guild.id) if guild else None,
"channel_id": str(message.channel.id),
"message_id": str(message.id),
},
)
inbound.topic_id = thread_id
if self._main_loop and self._main_loop.is_running():
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
def _run_client(self) -> None:
self._discord_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._discord_loop)
try:
self._discord_loop.run_until_complete(self._client.start(self._bot_token))
except Exception:
if self._running:
logger.exception("Discord client error")
finally:
try:
if self._client and not self._client.is_closed():
self._discord_loop.run_until_complete(self._client.close())
except Exception:
logger.exception("Error during Discord shutdown")
async def _create_thread(self, message):
try:
thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100]
return await message.create_thread(name=thread_name)
except Exception:
logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id)
try:
await message.channel.send("Could not create a thread for your message. Please check that threads are enabled in this channel.")
except Exception:
pass
return None
async def _resolve_target(self, msg: OutboundMessage):
if not self._client or not self._discord_loop:
return None
target_ids: list[str] = []
if msg.thread_ts:
target_ids.append(msg.thread_ts)
if msg.chat_id and msg.chat_id not in target_ids:
target_ids.append(msg.chat_id)
for raw_id in target_ids:
target = await self._get_channel_or_thread(raw_id)
if target is not None:
return target
return None
async def _get_channel_or_thread(self, raw_id: str):
if not self._client or not self._discord_loop:
return None
try:
target_id = int(raw_id)
except (TypeError, ValueError):
return None
get_future = asyncio.run_coroutine_threadsafe(self._fetch_channel(target_id), self._discord_loop)
try:
return await asyncio.wrap_future(get_future)
except Exception:
logger.exception("[Discord] failed to resolve target id=%s", raw_id)
return None
async def _fetch_channel(self, target_id: int):
if not self._client:
return None
channel = self._client.get_channel(target_id)
if channel is not None:
return channel
try:
return await self._client.fetch_channel(target_id)
except Exception:
return None
@staticmethod
def _split_text(text: str) -> list[str]:
if not text:
return [""]
chunks: list[str] = []
remaining = text
while len(remaining) > _DISCORD_MAX_MESSAGE_LEN:
split_at = remaining.rfind("\n", 0, _DISCORD_MAX_MESSAGE_LEN)
if split_at <= 0:
split_at = _DISCORD_MAX_MESSAGE_LEN
chunks.append(remaining[:split_at])
remaining = remaining[split_at:].lstrip("\n")
if remaining:
chunks.append(remaining)
return chunks

View File

@@ -0,0 +1,692 @@
"""Feishu/Lark channel — connects to Feishu via WebSocket (no public IP needed)."""
from __future__ import annotations
import asyncio
import json
import logging
import re
import threading
from typing import Any, Literal
from app.channels.base import Channel
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
logger = logging.getLogger(__name__)
def _is_feishu_command(text: str) -> bool:
if not text.startswith("/"):
return False
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
class FeishuChannel(Channel):
"""Feishu/Lark IM channel using the ``lark-oapi`` WebSocket client.
Configuration keys (in ``config.yaml`` under ``channels.feishu``):
- ``app_id``: Feishu app ID.
- ``app_secret``: Feishu app secret.
- ``verification_token``: (optional) Event verification token.
The channel uses WebSocket long-connection mode so no public IP is required.
Message flow:
1. User sends a message → bot adds "OK" emoji reaction
2. Bot replies in thread: "Working on it......"
3. Agent processes the message and returns a result
4. Bot replies in thread with the result
5. Bot adds "DONE" emoji reaction to the original message
"""
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
super().__init__(name="feishu", bus=bus, config=config)
self._thread: threading.Thread | None = None
self._main_loop: asyncio.AbstractEventLoop | None = None
self._api_client = None
self._CreateMessageReactionRequest = None
self._CreateMessageReactionRequestBody = None
self._Emoji = None
self._PatchMessageRequest = None
self._PatchMessageRequestBody = None
self._background_tasks: set[asyncio.Task] = set()
self._running_card_ids: dict[str, str] = {}
self._running_card_tasks: dict[str, asyncio.Task] = {}
self._CreateFileRequest = None
self._CreateFileRequestBody = None
self._CreateImageRequest = None
self._CreateImageRequestBody = None
self._GetMessageResourceRequest = None
self._thread_lock = threading.Lock()
async def start(self) -> None:
if self._running:
return
try:
import lark_oapi as lark
from lark_oapi.api.im.v1 import (
CreateFileRequest,
CreateFileRequestBody,
CreateImageRequest,
CreateImageRequestBody,
CreateMessageReactionRequest,
CreateMessageReactionRequestBody,
CreateMessageRequest,
CreateMessageRequestBody,
Emoji,
GetMessageResourceRequest,
PatchMessageRequest,
PatchMessageRequestBody,
ReplyMessageRequest,
ReplyMessageRequestBody,
)
except ImportError:
logger.error("lark-oapi is not installed. Install it with: uv add lark-oapi")
return
self._lark = lark
self._CreateMessageRequest = CreateMessageRequest
self._CreateMessageRequestBody = CreateMessageRequestBody
self._ReplyMessageRequest = ReplyMessageRequest
self._ReplyMessageRequestBody = ReplyMessageRequestBody
self._CreateMessageReactionRequest = CreateMessageReactionRequest
self._CreateMessageReactionRequestBody = CreateMessageReactionRequestBody
self._Emoji = Emoji
self._PatchMessageRequest = PatchMessageRequest
self._PatchMessageRequestBody = PatchMessageRequestBody
self._CreateFileRequest = CreateFileRequest
self._CreateFileRequestBody = CreateFileRequestBody
self._CreateImageRequest = CreateImageRequest
self._CreateImageRequestBody = CreateImageRequestBody
self._GetMessageResourceRequest = GetMessageResourceRequest
app_id = self.config.get("app_id", "")
app_secret = self.config.get("app_secret", "")
domain = self.config.get("domain", "https://open.feishu.cn")
if not app_id or not app_secret:
logger.error("Feishu channel requires app_id and app_secret")
return
self._api_client = lark.Client.builder().app_id(app_id).app_secret(app_secret).domain(domain).build()
logger.info("[Feishu] using domain: %s", domain)
self._main_loop = asyncio.get_event_loop()
self._running = True
self.bus.subscribe_outbound(self._on_outbound)
# Both ws.Client construction and start() must happen in a dedicated
# thread with its own event loop. lark-oapi caches the running loop
# at construction time and later calls loop.run_until_complete(),
# which conflicts with an already-running uvloop.
self._thread = threading.Thread(
target=self._run_ws,
args=(app_id, app_secret, domain),
daemon=True,
)
self._thread.start()
logger.info("Feishu channel started")
def _run_ws(self, app_id: str, app_secret: str, domain: str) -> None:
"""Construct and run the lark WS client in a thread with a fresh event loop.
The lark-oapi SDK captures a module-level event loop at import time
(``lark_oapi.ws.client.loop``). When uvicorn uses uvloop, that
captured loop is the *main* thread's uvloop — which is already
running, so ``loop.run_until_complete()`` inside ``Client.start()``
raises ``RuntimeError``.
We work around this by creating a plain asyncio event loop for this
thread and patching the SDK's module-level reference before calling
``start()``.
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
import lark_oapi as lark
import lark_oapi.ws.client as _ws_client_mod
# Replace the SDK's module-level loop so Client.start() uses
# this thread's (non-running) event loop instead of the main
# thread's uvloop.
_ws_client_mod.loop = loop
event_handler = lark.EventDispatcherHandler.builder("", "").register_p2_im_message_receive_v1(self._on_message).build()
ws_client = lark.ws.Client(
app_id=app_id,
app_secret=app_secret,
event_handler=event_handler,
log_level=lark.LogLevel.INFO,
domain=domain,
)
ws_client.start()
except Exception:
if self._running:
logger.exception("Feishu WebSocket error")
async def stop(self) -> None:
self._running = False
self.bus.unsubscribe_outbound(self._on_outbound)
for task in list(self._background_tasks):
task.cancel()
self._background_tasks.clear()
for task in list(self._running_card_tasks.values()):
task.cancel()
self._running_card_tasks.clear()
if self._thread:
self._thread.join(timeout=5)
self._thread = None
logger.info("Feishu channel stopped")
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if not self._api_client:
logger.warning("[Feishu] send called but no api_client available")
return
logger.info(
"[Feishu] sending reply: chat_id=%s, thread_ts=%s, text_len=%d",
msg.chat_id,
msg.thread_ts,
len(msg.text),
)
last_exc: Exception | None = None
for attempt in range(_max_retries):
try:
await self._send_card_message(msg)
return # success
except Exception as exc:
last_exc = exc
if attempt < _max_retries - 1:
delay = 2**attempt # 1s, 2s
logger.warning(
"[Feishu] send failed (attempt %d/%d), retrying in %ds: %s",
attempt + 1,
_max_retries,
delay,
exc,
)
await asyncio.sleep(delay)
logger.error("[Feishu] send failed after %d attempts: %s", _max_retries, last_exc)
if last_exc is None:
raise RuntimeError("Feishu send failed without an exception from any attempt")
raise last_exc
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if not self._api_client:
return False
# Check size limits (image: 10MB, file: 30MB)
if attachment.is_image and attachment.size > 10 * 1024 * 1024:
logger.warning("[Feishu] image too large (%d bytes), skipping: %s", attachment.size, attachment.filename)
return False
if not attachment.is_image and attachment.size > 30 * 1024 * 1024:
logger.warning("[Feishu] file too large (%d bytes), skipping: %s", attachment.size, attachment.filename)
return False
try:
if attachment.is_image:
file_key = await self._upload_image(attachment.actual_path)
msg_type = "image"
content = json.dumps({"image_key": file_key})
else:
file_key = await self._upload_file(attachment.actual_path, attachment.filename)
msg_type = "file"
content = json.dumps({"file_key": file_key})
if msg.thread_ts:
request = self._ReplyMessageRequest.builder().message_id(msg.thread_ts).request_body(self._ReplyMessageRequestBody.builder().msg_type(msg_type).content(content).reply_in_thread(True).build()).build()
await asyncio.to_thread(self._api_client.im.v1.message.reply, request)
else:
request = self._CreateMessageRequest.builder().receive_id_type("chat_id").request_body(self._CreateMessageRequestBody.builder().receive_id(msg.chat_id).msg_type(msg_type).content(content).build()).build()
await asyncio.to_thread(self._api_client.im.v1.message.create, request)
logger.info("[Feishu] file sent: %s (type=%s)", attachment.filename, msg_type)
return True
except Exception:
logger.exception("[Feishu] failed to upload/send file: %s", attachment.filename)
return False
async def _upload_image(self, path) -> str:
"""Upload an image to Feishu and return the image_key."""
with open(str(path), "rb") as f:
request = self._CreateImageRequest.builder().request_body(self._CreateImageRequestBody.builder().image_type("message").image(f).build()).build()
response = await asyncio.to_thread(self._api_client.im.v1.image.create, request)
if not response.success():
raise RuntimeError(f"Feishu image upload failed: code={response.code}, msg={response.msg}")
return response.data.image_key
async def _upload_file(self, path, filename: str) -> str:
"""Upload a file to Feishu and return the file_key."""
suffix = path.suffix.lower() if hasattr(path, "suffix") else ""
if suffix in (".xls", ".xlsx", ".csv"):
file_type = "xls"
elif suffix in (".ppt", ".pptx"):
file_type = "ppt"
elif suffix == ".pdf":
file_type = "pdf"
elif suffix in (".doc", ".docx"):
file_type = "doc"
else:
file_type = "stream"
with open(str(path), "rb") as f:
request = self._CreateFileRequest.builder().request_body(self._CreateFileRequestBody.builder().file_type(file_type).file_name(filename).file(f).build()).build()
response = await asyncio.to_thread(self._api_client.im.v1.file.create, request)
if not response.success():
raise RuntimeError(f"Feishu file upload failed: code={response.code}, msg={response.msg}")
return response.data.file_key
async def receive_file(self, msg: InboundMessage, thread_id: str) -> InboundMessage:
"""Download a Feishu file into the thread uploads directory.
Returns the sandbox virtual path when the image is persisted successfully.
"""
if not msg.thread_ts:
logger.warning("[Feishu] received file message without thread_ts, cannot associate with conversation: %s", msg)
return msg
files = msg.files
if not files:
logger.warning("[Feishu] received message with no files: %s", msg)
return msg
text = msg.text
for file in files:
if file.get("image_key"):
virtual_path = await self._receive_single_file(msg.thread_ts, file["image_key"], "image", thread_id)
text = text.replace("[image]", virtual_path, 1)
elif file.get("file_key"):
virtual_path = await self._receive_single_file(msg.thread_ts, file["file_key"], "file", thread_id)
text = text.replace("[file]", virtual_path, 1)
msg.text = text
return msg
async def _receive_single_file(self, message_id: str, file_key: str, type: Literal["image", "file"], thread_id: str) -> str:
request = self._GetMessageResourceRequest.builder().message_id(message_id).file_key(file_key).type(type).build()
def inner():
return self._api_client.im.v1.message_resource.get(request)
try:
response = await asyncio.to_thread(inner)
except Exception:
logger.exception("[Feishu] resource get request failed for resource_key=%s type=%s", file_key, type)
return f"Failed to obtain the [{type}]"
if not response.success():
logger.warning(
"[Feishu] resource get failed: resource_key=%s, type=%s, code=%s, msg=%s, log_id=%s ",
file_key,
type,
response.code,
response.msg,
response.get_log_id(),
)
return f"Failed to obtain the [{type}]"
image_stream = getattr(response, "file", None)
if image_stream is None:
logger.warning("[Feishu] resource get returned no file stream: resource_key=%s, type=%s", file_key, type)
return f"Failed to obtain the [{type}]"
try:
content: bytes = await asyncio.to_thread(image_stream.read)
except Exception:
logger.exception("[Feishu] failed to read resource stream: resource_key=%s, type=%s", file_key, type)
return f"Failed to obtain the [{type}]"
if not content:
logger.warning("[Feishu] empty resource content: resource_key=%s, type=%s", file_key, type)
return f"Failed to obtain the [{type}]"
paths = get_paths()
paths.ensure_thread_dirs(thread_id)
uploads_dir = paths.sandbox_uploads_dir(thread_id).resolve()
ext = "png" if type == "image" else "bin"
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"
# Sanitize filename: preserve extension, replace path chars in name part
if "." in raw_filename:
name_part, ext = raw_filename.rsplit(".", 1)
name_part = re.sub(r"[./\\]", "_", name_part)
filename = f"{name_part}.{ext}"
else:
filename = re.sub(r"[./\\]", "_", raw_filename)
resolved_target = uploads_dir / filename
def down_load():
# use thread_lock to avoid filename conflicts when writing
with self._thread_lock:
resolved_target.write_bytes(content)
try:
await asyncio.to_thread(down_load)
except Exception:
logger.exception("[Feishu] failed to persist downloaded resource: %s, type=%s", resolved_target, type)
return f"Failed to obtain the [{type}]"
virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"
try:
sandbox_provider = get_sandbox_provider()
sandbox_id = sandbox_provider.acquire(thread_id)
if sandbox_id != "local":
sandbox = sandbox_provider.get(sandbox_id)
if sandbox is None:
logger.warning("[Feishu] sandbox not found for thread_id=%s", thread_id)
return f"Failed to obtain the [{type}]"
sandbox.update_file(virtual_path, content)
except Exception:
logger.exception("[Feishu] failed to sync resource into non-local sandbox: %s", virtual_path)
return f"Failed to obtain the [{type}]"
logger.info("[Feishu] downloaded resource mapped: file_key=%s -> %s", file_key, virtual_path)
return virtual_path
# -- message formatting ------------------------------------------------
@staticmethod
def _build_card_content(text: str) -> str:
"""Build a Feishu interactive card with markdown content.
Feishu's interactive card format natively renders markdown, including
headers, bold/italic, code blocks, lists, and links.
"""
card = {
"config": {"wide_screen_mode": True, "update_multi": True},
"elements": [{"tag": "markdown", "content": text}],
}
return json.dumps(card)
# -- reaction helpers --------------------------------------------------
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None:
"""Add an emoji reaction to a message."""
if not self._api_client or not self._CreateMessageReactionRequest:
return
try:
request = self._CreateMessageReactionRequest.builder().message_id(message_id).request_body(self._CreateMessageReactionRequestBody.builder().reaction_type(self._Emoji.builder().emoji_type(emoji_type).build()).build()).build()
await asyncio.to_thread(self._api_client.im.v1.message_reaction.create, request)
logger.info("[Feishu] reaction '%s' added to message %s", emoji_type, message_id)
except Exception:
logger.exception("[Feishu] failed to add reaction '%s' to message %s", emoji_type, message_id)
async def _reply_card(self, message_id: str, text: str) -> str | None:
"""Reply with an interactive card and return the created card message ID."""
if not self._api_client:
return None
content = self._build_card_content(text)
request = self._ReplyMessageRequest.builder().message_id(message_id).request_body(self._ReplyMessageRequestBody.builder().msg_type("interactive").content(content).reply_in_thread(True).build()).build()
response = await asyncio.to_thread(self._api_client.im.v1.message.reply, request)
response_data = getattr(response, "data", None)
return getattr(response_data, "message_id", None)
async def _create_card(self, chat_id: str, text: str) -> None:
"""Create a new card message in the target chat."""
if not self._api_client:
return
content = self._build_card_content(text)
request = self._CreateMessageRequest.builder().receive_id_type("chat_id").request_body(self._CreateMessageRequestBody.builder().receive_id(chat_id).msg_type("interactive").content(content).build()).build()
await asyncio.to_thread(self._api_client.im.v1.message.create, request)
async def _update_card(self, message_id: str, text: str) -> None:
"""Patch an existing card message in place."""
if not self._api_client or not self._PatchMessageRequest:
return
content = self._build_card_content(text)
request = self._PatchMessageRequest.builder().message_id(message_id).request_body(self._PatchMessageRequestBody.builder().content(content).build()).build()
await asyncio.to_thread(self._api_client.im.v1.message.patch, request)
def _track_background_task(self, task: asyncio.Task, *, name: str, msg_id: str) -> None:
"""Keep a strong reference to fire-and-forget tasks and surface errors."""
self._background_tasks.add(task)
task.add_done_callback(lambda done_task, task_name=name, mid=msg_id: self._finalize_background_task(done_task, task_name, mid))
def _finalize_background_task(self, task: asyncio.Task, name: str, msg_id: str) -> None:
self._background_tasks.discard(task)
self._log_task_error(task, name, msg_id)
async def _create_running_card(self, source_message_id: str, text: str) -> str | None:
"""Create the running card and cache its message ID when available."""
running_card_id = await self._reply_card(source_message_id, text)
if running_card_id:
self._running_card_ids[source_message_id] = running_card_id
logger.info("[Feishu] running card created: source=%s card=%s", source_message_id, running_card_id)
else:
logger.warning("[Feishu] running card creation returned no message_id for source=%s, subsequent updates will fall back to new replies", source_message_id)
return running_card_id
def _ensure_running_card_started(self, source_message_id: str, text: str = "Working on it...") -> asyncio.Task | None:
"""Start running-card creation once per source message."""
running_card_id = self._running_card_ids.get(source_message_id)
if running_card_id:
return None
running_card_task = self._running_card_tasks.get(source_message_id)
if running_card_task:
return running_card_task
running_card_task = asyncio.create_task(self._create_running_card(source_message_id, text))
self._running_card_tasks[source_message_id] = running_card_task
running_card_task.add_done_callback(lambda done_task, mid=source_message_id: self._finalize_running_card_task(mid, done_task))
return running_card_task
def _finalize_running_card_task(self, source_message_id: str, task: asyncio.Task) -> None:
if self._running_card_tasks.get(source_message_id) is task:
self._running_card_tasks.pop(source_message_id, None)
self._log_task_error(task, "create_running_card", source_message_id)
async def _ensure_running_card(self, source_message_id: str, text: str = "Working on it...") -> str | None:
"""Ensure the in-thread running card exists and track its message ID."""
running_card_id = self._running_card_ids.get(source_message_id)
if running_card_id:
return running_card_id
running_card_task = self._ensure_running_card_started(source_message_id, text)
if running_card_task is None:
return self._running_card_ids.get(source_message_id)
return await running_card_task
async def _send_running_reply(self, message_id: str) -> None:
"""Reply to a message in-thread with a running card."""
try:
await self._ensure_running_card(message_id)
except Exception:
logger.exception("[Feishu] failed to send running reply for message %s", message_id)
async def _send_card_message(self, msg: OutboundMessage) -> None:
"""Send or update the Feishu card tied to the current request."""
source_message_id = msg.thread_ts
if source_message_id:
running_card_id = self._running_card_ids.get(source_message_id)
awaited_running_card_task = False
if not running_card_id:
running_card_task = self._running_card_tasks.get(source_message_id)
if running_card_task:
awaited_running_card_task = True
running_card_id = await running_card_task
if running_card_id:
try:
await self._update_card(running_card_id, msg.text)
except Exception:
if not msg.is_final:
raise
logger.exception(
"[Feishu] failed to patch running card %s, falling back to final reply",
running_card_id,
)
await self._reply_card(source_message_id, msg.text)
else:
logger.info("[Feishu] running card updated: source=%s card=%s", source_message_id, running_card_id)
elif msg.is_final:
await self._reply_card(source_message_id, msg.text)
elif awaited_running_card_task:
logger.warning(
"[Feishu] running card task finished without message_id for source=%s, skipping duplicate non-final creation",
source_message_id,
)
else:
await self._ensure_running_card(source_message_id, msg.text)
if msg.is_final:
self._running_card_ids.pop(source_message_id, None)
await self._add_reaction(source_message_id, "DONE")
return
await self._create_card(msg.chat_id, msg.text)
# -- internal ----------------------------------------------------------
@staticmethod
def _log_future_error(fut, name: str, msg_id: str) -> None:
"""Callback for run_coroutine_threadsafe futures to surface errors."""
try:
exc = fut.exception()
if exc:
logger.error("[Feishu] %s failed for msg_id=%s: %s", name, msg_id, exc)
except Exception:
pass
@staticmethod
def _log_task_error(task: asyncio.Task, name: str, msg_id: str) -> None:
"""Callback for background asyncio tasks to surface errors."""
try:
exc = task.exception()
if exc:
logger.error("[Feishu] %s failed for msg_id=%s: %s", name, msg_id, exc)
except asyncio.CancelledError:
logger.info("[Feishu] %s cancelled for msg_id=%s", name, msg_id)
except Exception:
pass
async def _prepare_inbound(self, msg_id: str, inbound) -> None:
"""Kick off Feishu side effects without delaying inbound dispatch."""
reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK"))
self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id)
self._ensure_running_card_started(msg_id)
await self.bus.publish_inbound(inbound)
def _on_message(self, event) -> None:
"""Called by lark-oapi when a message is received (runs in lark thread)."""
try:
logger.info("[Feishu] raw event received: type=%s", type(event).__name__)
message = event.event.message
chat_id = message.chat_id
msg_id = message.message_id
sender_id = event.event.sender.sender_id.open_id
# root_id is set when the message is a reply within a Feishu thread.
# Use it as topic_id so all replies share the same DeerFlow thread.
root_id = getattr(message, "root_id", None) or None
# Parse message content
content = json.loads(message.content)
# files_list store the any-file-key in feishu messages, which can be used to download the file content later
# In Feishu channel, image_keys are independent of file_keys.
# The file_key includes files, videos, and audio, but does not include stickers.
files_list = []
if "text" in content:
# Handle plain text messages
text = content["text"]
elif "file_key" in content:
file_key = content.get("file_key")
if isinstance(file_key, str) and file_key:
files_list.append({"file_key": file_key})
text = "[file]"
else:
text = ""
elif "image_key" in content:
image_key = content.get("image_key")
if isinstance(image_key, str) and image_key:
files_list.append({"image_key": image_key})
text = "[image]"
else:
text = ""
elif "content" in content and isinstance(content["content"], list):
# Handle rich-text messages with a top-level "content" list (e.g., topic groups/posts)
text_paragraphs: list[str] = []
for paragraph in content["content"]:
if isinstance(paragraph, list):
paragraph_text_parts: list[str] = []
for element in paragraph:
if isinstance(element, dict):
# Include both normal text and @ mentions
if element.get("tag") in ("text", "at"):
text_value = element.get("text", "")
if text_value:
paragraph_text_parts.append(text_value)
elif element.get("tag") == "img":
image_key = element.get("image_key")
if isinstance(image_key, str) and image_key:
files_list.append({"image_key": image_key})
paragraph_text_parts.append("[image]")
elif element.get("tag") in ("file", "media"):
file_key = element.get("file_key")
if isinstance(file_key, str) and file_key:
files_list.append({"file_key": file_key})
paragraph_text_parts.append("[file]")
if paragraph_text_parts:
# Join text segments within a paragraph with spaces to avoid "helloworld"
text_paragraphs.append(" ".join(paragraph_text_parts))
# Join paragraphs with blank lines to preserve paragraph boundaries
text = "\n\n".join(text_paragraphs)
else:
text = ""
text = text.strip()
logger.info(
"[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, sender=%s, text=%r",
chat_id,
msg_id,
root_id,
sender_id,
text[:100] if text else "",
)
if not (text or files_list):
logger.info("[Feishu] empty text, ignoring message")
return
# Only treat known slash commands as commands; absolute paths and
# other slash-prefixed text should be handled as normal chat.
if _is_feishu_command(text):
msg_type = InboundMessageType.COMMAND
else:
msg_type = InboundMessageType.CHAT
# topic_id: use root_id for replies (same topic), msg_id for new messages (new topic)
topic_id = root_id or msg_id
inbound = self._make_inbound(
chat_id=chat_id,
user_id=sender_id,
text=text,
msg_type=msg_type,
thread_ts=msg_id,
files=files_list,
metadata={"message_id": msg_id, "root_id": root_id},
)
inbound.topic_id = topic_id
# Schedule on the async event loop
if self._main_loop and self._main_loop.is_running():
logger.info("[Feishu] publishing inbound message to bus (type=%s, msg_id=%s)", msg_type.value, msg_id)
fut = asyncio.run_coroutine_threadsafe(self._prepare_inbound(msg_id, inbound), self._main_loop)
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "prepare_inbound", mid))
else:
logger.warning("[Feishu] main loop not running, cannot publish inbound message")
except Exception:
logger.exception("[Feishu] error processing message")

View File

@@ -0,0 +1,960 @@
"""ChannelManager — consumes inbound messages and dispatches them to the DeerFlow agent via LangGraph Server."""
from __future__ import annotations
import asyncio
import logging
import mimetypes
import re
import time
from collections.abc import Awaitable, Callable, Mapping
from pathlib import Path
from typing import Any
import httpx
from langgraph_sdk.errors import ConflictError
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.store import ChannelStore
logger = logging.getLogger(__name__)
DEFAULT_LANGGRAPH_URL = "http://localhost:2024"
DEFAULT_GATEWAY_URL = "http://localhost:8001"
DEFAULT_ASSISTANT_ID = "lead_agent"
CUSTOM_AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
DEFAULT_RUN_CONFIG: dict[str, Any] = {"recursion_limit": 100}
DEFAULT_RUN_CONTEXT: dict[str, Any] = {
"thinking_enabled": True,
"is_plan_mode": False,
"subagent_enabled": False,
}
STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
CHANNEL_CAPABILITIES = {
"discord": {"supports_streaming": False},
"feishu": {"supports_streaming": True},
"slack": {"supports_streaming": False},
"telegram": {"supports_streaming": False},
"wechat": {"supports_streaming": False},
"wecom": {"supports_streaming": True},
}
InboundFileReader = Callable[[dict[str, Any], httpx.AsyncClient], Awaitable[bytes | None]]
INBOUND_FILE_READERS: dict[str, InboundFileReader] = {}
def register_inbound_file_reader(channel_name: str, reader: InboundFileReader) -> None:
INBOUND_FILE_READERS[channel_name] = reader
async def _read_http_inbound_file(file_info: dict[str, Any], client: httpx.AsyncClient) -> bytes | None:
url = file_info.get("url")
if not isinstance(url, str) or not url:
return None
resp = await client.get(url)
resp.raise_for_status()
return resp.content
async def _read_wecom_inbound_file(file_info: dict[str, Any], client: httpx.AsyncClient) -> bytes | None:
data = await _read_http_inbound_file(file_info, client)
if data is None:
return None
aeskey = file_info.get("aeskey") if isinstance(file_info.get("aeskey"), str) else None
if not aeskey:
return data
try:
from aibot.crypto_utils import decrypt_file
except Exception:
logger.exception("[Manager] failed to import WeCom decrypt_file")
return None
return decrypt_file(data, aeskey)
async def _read_wechat_inbound_file(file_info: dict[str, Any], client: httpx.AsyncClient) -> bytes | None:
raw_path = file_info.get("path")
if isinstance(raw_path, str) and raw_path.strip():
try:
return await asyncio.to_thread(Path(raw_path).read_bytes)
except OSError:
logger.exception("[Manager] failed to read WeChat inbound file from local path: %s", raw_path)
return None
full_url = file_info.get("full_url")
if isinstance(full_url, str) and full_url.strip():
return await _read_http_inbound_file({"url": full_url}, client)
return None
register_inbound_file_reader("wecom", _read_wecom_inbound_file)
register_inbound_file_reader("wechat", _read_wechat_inbound_file)
class InvalidChannelSessionConfigError(ValueError):
"""Raised when IM channel session overrides contain invalid agent config."""
def _is_thread_busy_error(exc: BaseException | None) -> bool:
if exc is None:
return False
if isinstance(exc, ConflictError):
return True
return "already running a task" in str(exc)
def _as_dict(value: Any) -> dict[str, Any]:
return dict(value) if isinstance(value, Mapping) else {}
def _merge_dicts(*layers: Any) -> dict[str, Any]:
merged: dict[str, Any] = {}
for layer in layers:
if isinstance(layer, Mapping):
merged.update(layer)
return merged
def _normalize_custom_agent_name(raw_value: str) -> str:
"""Normalize legacy channel assistant IDs into valid custom agent names."""
normalized = raw_value.strip().lower().replace("_", "-")
if not normalized:
raise InvalidChannelSessionConfigError("Channel session assistant_id is empty. Use 'lead_agent' or a valid custom agent name.")
if not CUSTOM_AGENT_NAME_PATTERN.fullmatch(normalized):
raise InvalidChannelSessionConfigError(f"Invalid channel session assistant_id {raw_value!r}. Use 'lead_agent' or a custom agent name containing only letters, digits, and hyphens.")
return normalized
def _extract_response_text(result: dict | list) -> str:
"""Extract the last AI message text from a LangGraph runs.wait result.
``runs.wait`` returns the final state dict which contains a ``messages``
list. Each message is a dict with at least ``type`` and ``content``.
Handles special cases:
- Regular AI text responses
- Clarification interrupts (``ask_clarification`` tool messages)
- AI messages with tool_calls but no text content
"""
if isinstance(result, list):
messages = result
elif isinstance(result, dict):
messages = result.get("messages", [])
else:
return ""
# Walk backwards to find usable response text, but stop at the last
# human message to avoid returning text from a previous turn.
for msg in reversed(messages):
if not isinstance(msg, dict):
continue
msg_type = msg.get("type")
# Stop at the last human message — anything before it is a previous turn
if msg_type == "human":
break
# Check for tool messages from ask_clarification (interrupt case)
if msg_type == "tool" and msg.get("name") == "ask_clarification":
content = msg.get("content", "")
if isinstance(content, str) and content:
return content
# Regular AI message with text content
if msg_type == "ai":
content = msg.get("content", "")
if isinstance(content, str) and content:
return content
# content can be a list of content blocks
if isinstance(content, list):
parts = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
parts.append(block.get("text", ""))
elif isinstance(block, str):
parts.append(block)
text = "".join(parts)
if text:
return text
return ""
def _extract_text_content(content: Any) -> str:
"""Extract text from a streaming payload content field."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for block in content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, Mapping):
text = block.get("text")
if isinstance(text, str):
parts.append(text)
else:
nested = block.get("content")
if isinstance(nested, str):
parts.append(nested)
return "".join(parts)
if isinstance(content, Mapping):
for key in ("text", "content"):
value = content.get(key)
if isinstance(value, str):
return value
return ""
def _merge_stream_text(existing: str, chunk: str) -> str:
"""Merge either delta text or cumulative text into a single snapshot."""
if not chunk:
return existing
if not existing or chunk == existing:
return chunk or existing
if chunk.startswith(existing):
return chunk
if existing.endswith(chunk):
return existing
return existing + chunk
def _extract_stream_message_id(payload: Any, metadata: Any) -> str | None:
"""Best-effort extraction of the streamed AI message identifier."""
candidates = [payload, metadata]
if isinstance(payload, Mapping):
candidates.append(payload.get("kwargs"))
for candidate in candidates:
if not isinstance(candidate, Mapping):
continue
for key in ("id", "message_id"):
value = candidate.get(key)
if isinstance(value, str) and value:
return value
return None
def _accumulate_stream_text(
buffers: dict[str, str],
current_message_id: str | None,
event_data: Any,
) -> tuple[str | None, str | None]:
"""Convert a ``messages-tuple`` event into the latest displayable AI text."""
payload = event_data
metadata: Any = None
if isinstance(event_data, (list, tuple)):
if event_data:
payload = event_data[0]
if len(event_data) > 1:
metadata = event_data[1]
if isinstance(payload, str):
message_id = current_message_id or "__default__"
buffers[message_id] = _merge_stream_text(buffers.get(message_id, ""), payload)
return buffers[message_id], message_id
if not isinstance(payload, Mapping):
return None, current_message_id
payload_type = str(payload.get("type", "")).lower()
if "tool" in payload_type:
return None, current_message_id
text = _extract_text_content(payload.get("content"))
if not text and isinstance(payload.get("kwargs"), Mapping):
text = _extract_text_content(payload["kwargs"].get("content"))
if not text:
return None, current_message_id
message_id = _extract_stream_message_id(payload, metadata) or current_message_id or "__default__"
buffers[message_id] = _merge_stream_text(buffers.get(message_id, ""), text)
return buffers[message_id], message_id
def _extract_artifacts(result: dict | list) -> list[str]:
"""Extract artifact paths from the last AI response cycle only.
Instead of reading the full accumulated ``artifacts`` state (which contains
all artifacts ever produced in the thread), this inspects the messages after
the last human message and collects file paths from ``present_files`` tool
calls. This ensures only newly-produced artifacts are returned.
"""
if isinstance(result, list):
messages = result
elif isinstance(result, dict):
messages = result.get("messages", [])
else:
return []
artifacts: list[str] = []
for msg in reversed(messages):
if not isinstance(msg, dict):
continue
# Stop at the last human message — anything before it is a previous turn
if msg.get("type") == "human":
break
# Look for AI messages with present_files tool calls
if msg.get("type") == "ai":
for tc in msg.get("tool_calls", []):
if isinstance(tc, dict) and tc.get("name") == "present_files":
args = tc.get("args", {})
paths = args.get("filepaths", [])
if isinstance(paths, list):
artifacts.extend(p for p in paths if isinstance(p, str))
return artifacts
def _format_artifact_text(artifacts: list[str]) -> str:
"""Format artifact paths into a human-readable text block listing filenames."""
import posixpath
filenames = [posixpath.basename(p) for p in artifacts]
if len(filenames) == 1:
return f"Created File: 📎 {filenames[0]}"
return "Created Files: 📎 " + "".join(filenames)
_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/"
def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedAttachment]:
"""Resolve virtual artifact paths to host filesystem paths with metadata.
Only paths under ``/mnt/user-data/outputs/`` are accepted; any other
virtual path is rejected with a warning to prevent exfiltrating uploads
or workspace files via IM channels.
Skips artifacts that cannot be resolved (missing files, invalid paths)
and logs warnings for them.
"""
from deerflow.config.paths import get_paths
attachments: list[ResolvedAttachment] = []
paths = get_paths()
outputs_dir = paths.sandbox_outputs_dir(thread_id).resolve()
for virtual_path in artifacts:
# Security: only allow files from the agent outputs directory
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
continue
try:
actual = paths.resolve_virtual_path(thread_id, virtual_path)
# Verify the resolved path is actually under the outputs directory
# (guards against path-traversal even after prefix check)
try:
actual.resolve().relative_to(outputs_dir)
except ValueError:
logger.warning("[Manager] artifact path escapes outputs dir: %s -> %s", virtual_path, actual)
continue
if not actual.is_file():
logger.warning("[Manager] artifact not found on disk: %s -> %s", virtual_path, actual)
continue
mime, _ = mimetypes.guess_type(str(actual))
mime = mime or "application/octet-stream"
attachments.append(
ResolvedAttachment(
virtual_path=virtual_path,
actual_path=actual,
filename=actual.name,
mime_type=mime,
size=actual.stat().st_size,
is_image=mime.startswith("image/"),
)
)
except (ValueError, OSError) as exc:
logger.warning("[Manager] failed to resolve artifact %s: %s", virtual_path, exc)
return attachments
def _prepare_artifact_delivery(
thread_id: str,
response_text: str,
artifacts: list[str],
) -> tuple[str, list[ResolvedAttachment]]:
"""Resolve attachments and append filename fallbacks to the text response."""
attachments: list[ResolvedAttachment] = []
if not artifacts:
return response_text, attachments
attachments = _resolve_attachments(thread_id, artifacts)
resolved_virtuals = {attachment.virtual_path for attachment in attachments}
unresolved = [path for path in artifacts if path not in resolved_virtuals]
if unresolved:
artifact_text = _format_artifact_text(unresolved)
response_text = (response_text + "\n\n" + artifact_text) if response_text else artifact_text
# Always include resolved attachment filenames as a text fallback so files
# remain discoverable even when the upload is skipped or fails.
if attachments:
resolved_text = _format_artifact_text([attachment.virtual_path for attachment in attachments])
response_text = (response_text + "\n\n" + resolved_text) if response_text else resolved_text
return response_text, attachments
async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dict[str, Any]]:
if not msg.files:
return []
from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
uploads_dir = ensure_uploads_dir(thread_id)
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
created: list[dict[str, Any]] = []
file_reader = INBOUND_FILE_READERS.get(msg.channel_name, _read_http_inbound_file)
async with httpx.AsyncClient(timeout=httpx.Timeout(20.0)) as client:
for idx, f in enumerate(msg.files):
if not isinstance(f, dict):
continue
ftype = f.get("type") if isinstance(f.get("type"), str) else "file"
filename = f.get("filename") if isinstance(f.get("filename"), str) else ""
try:
data = await file_reader(f, client)
except Exception:
logger.exception(
"[Manager] failed to read inbound file: channel=%s, file=%s",
msg.channel_name,
f.get("url") or filename or idx,
)
continue
if data is None:
logger.warning(
"[Manager] inbound file reader returned no data: channel=%s, file=%s",
msg.channel_name,
f.get("url") or filename or idx,
)
continue
if not filename:
ext = ".bin"
if ftype == "image":
ext = ".png"
filename = f"{msg.thread_ts or 'msg'}_{idx}{ext}"
try:
safe_name = claim_unique_filename(normalize_filename(filename), seen_names)
except ValueError:
logger.warning(
"[Manager] skipping inbound file with unsafe filename: channel=%s, file=%r",
msg.channel_name,
filename,
)
continue
dest = uploads_dir / safe_name
try:
dest.write_bytes(data)
except Exception:
logger.exception("[Manager] failed to write inbound file: %s", dest)
continue
created.append(
{
"filename": safe_name,
"size": len(data),
"path": f"/mnt/user-data/uploads/{safe_name}",
"is_image": ftype == "image",
}
)
return created
def _format_uploaded_files_block(files: list[dict[str, Any]]) -> str:
lines = [
"<uploaded_files>",
"The following files were uploaded in this message:",
"",
]
if not files:
lines.append("(empty)")
else:
for f in files:
filename = f.get("filename", "")
size = int(f.get("size") or 0)
size_kb = size / 1024 if size else 0
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
path = f.get("path", "")
is_image = bool(f.get("is_image"))
file_kind = "image" if is_image else "file"
lines.append(f"- {filename} ({size_str})")
lines.append(f" Type: {file_kind}")
lines.append(f" Path: {path}")
lines.append("")
lines.append("Use `read_file` for text-based files and documents.")
lines.append("Use `view_image` for image files (jpg, jpeg, png, webp) so the model can inspect the image content.")
lines.append("</uploaded_files>")
return "\n".join(lines)
class ChannelManager:
"""Core dispatcher that bridges IM channels to the DeerFlow agent.
It reads from the MessageBus inbound queue, creates/reuses threads on
the LangGraph Server, sends messages via ``runs.wait``, and publishes
outbound responses back through the bus.
"""
def __init__(
self,
bus: MessageBus,
store: ChannelStore,
*,
max_concurrency: int = 5,
langgraph_url: str = DEFAULT_LANGGRAPH_URL,
gateway_url: str = DEFAULT_GATEWAY_URL,
assistant_id: str = DEFAULT_ASSISTANT_ID,
default_session: dict[str, Any] | None = None,
channel_sessions: dict[str, Any] | None = None,
) -> None:
self.bus = bus
self.store = store
self._max_concurrency = max_concurrency
self._langgraph_url = langgraph_url
self._gateway_url = gateway_url
self._assistant_id = assistant_id
self._default_session = _as_dict(default_session)
self._channel_sessions = dict(channel_sessions or {})
self._client = None # lazy init — langgraph_sdk async client
self._semaphore: asyncio.Semaphore | None = None
self._running = False
self._task: asyncio.Task | None = None
@staticmethod
def _channel_supports_streaming(channel_name: str) -> bool:
return CHANNEL_CAPABILITIES.get(channel_name, {}).get("supports_streaming", False)
def _resolve_session_layer(self, msg: InboundMessage) -> tuple[dict[str, Any], dict[str, Any]]:
channel_layer = _as_dict(self._channel_sessions.get(msg.channel_name))
users_layer = _as_dict(channel_layer.get("users"))
user_layer = _as_dict(users_layer.get(msg.user_id))
return channel_layer, user_layer
def _resolve_run_params(self, msg: InboundMessage, thread_id: str) -> tuple[str, dict[str, Any], dict[str, Any]]:
channel_layer, user_layer = self._resolve_session_layer(msg)
assistant_id = user_layer.get("assistant_id") or channel_layer.get("assistant_id") or self._default_session.get("assistant_id") or self._assistant_id
if not isinstance(assistant_id, str) or not assistant_id.strip():
assistant_id = self._assistant_id
run_config = _merge_dicts(
DEFAULT_RUN_CONFIG,
self._default_session.get("config"),
channel_layer.get("config"),
user_layer.get("config"),
)
run_context = _merge_dicts(
DEFAULT_RUN_CONTEXT,
self._default_session.get("context"),
channel_layer.get("context"),
user_layer.get("context"),
{"thread_id": thread_id},
)
# Custom agents are implemented as lead_agent + agent_name context.
# Keep backward compatibility for channel configs that set
# assistant_id: <custom-agent-name> by routing through lead_agent.
if assistant_id != DEFAULT_ASSISTANT_ID:
run_context.setdefault("agent_name", _normalize_custom_agent_name(assistant_id))
assistant_id = DEFAULT_ASSISTANT_ID
return assistant_id, run_config, run_context
# -- LangGraph SDK client (lazy) ----------------------------------------
def _get_client(self):
"""Return the ``langgraph_sdk`` async client, creating it on first use."""
if self._client is None:
from langgraph_sdk import get_client
self._client = get_client(url=self._langgraph_url)
return self._client
# -- lifecycle ---------------------------------------------------------
async def start(self) -> None:
"""Start the dispatch loop."""
if self._running:
return
self._running = True
self._semaphore = asyncio.Semaphore(self._max_concurrency)
self._task = asyncio.create_task(self._dispatch_loop())
logger.info("ChannelManager started (max_concurrency=%d)", self._max_concurrency)
async def stop(self) -> None:
"""Stop the dispatch loop."""
self._running = False
if self._task:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
self._task = None
logger.info("ChannelManager stopped")
# -- dispatch loop -----------------------------------------------------
async def _dispatch_loop(self) -> None:
logger.info("[Manager] dispatch loop started, waiting for inbound messages")
while self._running:
try:
msg = await asyncio.wait_for(self.bus.get_inbound(), timeout=1.0)
except TimeoutError:
continue
except asyncio.CancelledError:
break
logger.info(
"[Manager] received inbound: channel=%s, chat_id=%s, type=%s, text=%r",
msg.channel_name,
msg.chat_id,
msg.msg_type.value,
msg.text[:100] if msg.text else "",
)
task = asyncio.create_task(self._handle_message(msg))
task.add_done_callback(self._log_task_error)
@staticmethod
def _log_task_error(task: asyncio.Task) -> None:
"""Surface unhandled exceptions from background tasks."""
if task.cancelled():
return
exc = task.exception()
if exc:
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
async def _handle_message(self, msg: InboundMessage) -> None:
async with self._semaphore:
try:
if msg.msg_type == InboundMessageType.COMMAND:
await self._handle_command(msg)
else:
await self._handle_chat(msg)
except InvalidChannelSessionConfigError as exc:
logger.warning(
"Invalid channel session config for %s (chat=%s): %s",
msg.channel_name,
msg.chat_id,
exc,
)
await self._send_error(msg, str(exc))
except Exception:
logger.exception(
"Error handling message from %s (chat=%s)",
msg.channel_name,
msg.chat_id,
)
await self._send_error(msg, "An internal error occurred. Please try again.")
# -- chat handling -----------------------------------------------------
async def _create_thread(self, client, msg: InboundMessage) -> str:
"""Create a new thread on the LangGraph Server and store the mapping."""
thread = await client.threads.create()
thread_id = thread["thread_id"]
self.store.set_thread_id(
msg.channel_name,
msg.chat_id,
thread_id,
topic_id=msg.topic_id,
user_id=msg.user_id,
)
logger.info("[Manager] new thread created on LangGraph Server: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
return thread_id
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
client = self._get_client()
# Look up existing DeerFlow thread.
# topic_id may be None (e.g. Telegram private chats) — the store
# handles this by using the "channel:chat_id" key without a topic suffix.
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
if thread_id:
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
# No existing thread found — create a new one
if thread_id is None:
thread_id = await self._create_thread(client, msg)
assistant_id, run_config, run_context = self._resolve_run_params(msg, thread_id)
# If the inbound message contains file attachments, let the channel
# materialize (download) them and update msg.text to include sandbox file paths.
# This enables downstream models to access user-uploaded files by path.
# Channels that do not support file download will simply return the original message.
if msg.files:
from .service import get_channel_service
service = get_channel_service()
channel = service.get_channel(msg.channel_name) if service else None
logger.info("[Manager] preparing receive file context for %d attachments", len(msg.files))
msg = await channel.receive_file(msg, thread_id) if channel else msg
if extra_context:
run_context.update(extra_context)
uploaded = await _ingest_inbound_files(thread_id, msg)
if uploaded:
msg.text = f"{_format_uploaded_files_block(uploaded)}\n\n{msg.text}".strip()
if self._channel_supports_streaming(msg.channel_name):
await self._handle_streaming_chat(
client,
msg,
thread_id,
assistant_id,
run_config,
run_context,
)
return
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
result = await client.runs.wait(
thread_id,
assistant_id,
input={"messages": [{"role": "human", "content": msg.text}]},
config=run_config,
context=run_context,
)
response_text = _extract_response_text(result)
artifacts = _extract_artifacts(result)
logger.info(
"[Manager] agent response received: thread_id=%s, response_len=%d, artifacts=%d",
thread_id,
len(response_text) if response_text else 0,
len(artifacts),
)
response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts)
if not response_text:
if attachments:
response_text = _format_artifact_text([a.virtual_path for a in attachments])
else:
response_text = "(No response from agent)"
outbound = OutboundMessage(
channel_name=msg.channel_name,
chat_id=msg.chat_id,
thread_id=thread_id,
text=response_text,
artifacts=artifacts,
attachments=attachments,
thread_ts=msg.thread_ts,
)
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
await self.bus.publish_outbound(outbound)
async def _handle_streaming_chat(
self,
client,
msg: InboundMessage,
thread_id: str,
assistant_id: str,
run_config: dict[str, Any],
run_context: dict[str, Any],
) -> None:
logger.info("[Manager] invoking runs.stream(thread_id=%s, text=%r)", thread_id, msg.text[:100])
last_values: dict[str, Any] | list | None = None
streamed_buffers: dict[str, str] = {}
current_message_id: str | None = None
latest_text = ""
last_published_text = ""
last_publish_at = 0.0
stream_error: BaseException | None = None
try:
async for chunk in client.runs.stream(
thread_id,
assistant_id,
input={"messages": [{"role": "human", "content": msg.text}]},
config=run_config,
context=run_context,
stream_mode=["messages-tuple", "values"],
multitask_strategy="reject",
):
event = getattr(chunk, "event", "")
data = getattr(chunk, "data", None)
if event == "messages-tuple":
accumulated_text, current_message_id = _accumulate_stream_text(streamed_buffers, current_message_id, data)
if accumulated_text:
latest_text = accumulated_text
elif event == "values" and isinstance(data, (dict, list)):
last_values = data
snapshot_text = _extract_response_text(data)
if snapshot_text:
latest_text = snapshot_text
if not latest_text or latest_text == last_published_text:
continue
now = time.monotonic()
if last_published_text and now - last_publish_at < STREAM_UPDATE_MIN_INTERVAL_SECONDS:
continue
await self.bus.publish_outbound(
OutboundMessage(
channel_name=msg.channel_name,
chat_id=msg.chat_id,
thread_id=thread_id,
text=latest_text,
is_final=False,
thread_ts=msg.thread_ts,
)
)
last_published_text = latest_text
last_publish_at = now
except Exception as exc:
stream_error = exc
if _is_thread_busy_error(exc):
logger.warning("[Manager] thread busy (concurrent run rejected): thread_id=%s", thread_id)
else:
logger.exception("[Manager] streaming error: thread_id=%s", thread_id)
finally:
result = last_values if last_values is not None else {"messages": [{"type": "ai", "content": latest_text}]}
response_text = _extract_response_text(result)
artifacts = _extract_artifacts(result)
response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts)
if not response_text:
if attachments:
response_text = _format_artifact_text([attachment.virtual_path for attachment in attachments])
elif stream_error:
if _is_thread_busy_error(stream_error):
response_text = THREAD_BUSY_MESSAGE
else:
response_text = "An error occurred while processing your request. Please try again."
else:
response_text = latest_text or "(No response from agent)"
logger.info(
"[Manager] streaming response completed: thread_id=%s, response_len=%d, artifacts=%d, error=%s",
thread_id,
len(response_text),
len(artifacts),
stream_error,
)
await self.bus.publish_outbound(
OutboundMessage(
channel_name=msg.channel_name,
chat_id=msg.chat_id,
thread_id=thread_id,
text=response_text,
artifacts=artifacts,
attachments=attachments,
is_final=True,
thread_ts=msg.thread_ts,
)
)
# -- command handling --------------------------------------------------
async def _handle_command(self, msg: InboundMessage) -> None:
text = msg.text.strip()
parts = text.split(maxsplit=1)
command = parts[0].lower().lstrip("/")
if command == "bootstrap":
from dataclasses import replace as _dc_replace
chat_text = parts[1] if len(parts) > 1 else "Initialize workspace"
chat_msg = _dc_replace(msg, text=chat_text, msg_type=InboundMessageType.CHAT)
await self._handle_chat(chat_msg, extra_context={"is_bootstrap": True})
return
if command == "new":
# Create a new thread on the LangGraph Server
client = self._get_client()
thread = await client.threads.create()
new_thread_id = thread["thread_id"]
self.store.set_thread_id(
msg.channel_name,
msg.chat_id,
new_thread_id,
topic_id=msg.topic_id,
user_id=msg.user_id,
)
reply = "New conversation started."
elif command == "status":
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
elif command == "models":
reply = await self._fetch_gateway("/api/models", "models")
elif command == "memory":
reply = await self._fetch_gateway("/api/memory", "memory")
elif command == "help":
reply = (
"Available commands:\n"
"/bootstrap — Start a bootstrap session (enables agent setup)\n"
"/new — Start a new conversation\n"
"/status — Show current thread info\n"
"/models — List available models\n"
"/memory — Show memory status\n"
"/help — Show this help"
)
else:
available = " | ".join(sorted(KNOWN_CHANNEL_COMMANDS))
reply = f"Unknown command: /{command}. Available commands: {available}"
outbound = OutboundMessage(
channel_name=msg.channel_name,
chat_id=msg.chat_id,
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
text=reply,
thread_ts=msg.thread_ts,
)
await self.bus.publish_outbound(outbound)
async def _fetch_gateway(self, path: str, kind: str) -> str:
"""Fetch data from the Gateway API for command responses."""
import httpx
try:
async with httpx.AsyncClient() as http:
resp = await http.get(f"{self._gateway_url}{path}", timeout=10)
resp.raise_for_status()
data = resp.json()
except Exception:
logger.exception("Failed to fetch %s from gateway", kind)
return f"Failed to fetch {kind} information."
if kind == "models":
names = [m["name"] for m in data.get("models", [])]
return ("Available models:\n" + "\n".join(f"{n}" for n in names)) if names else "No models configured."
elif kind == "memory":
facts = data.get("facts", [])
return f"Memory contains {len(facts)} fact(s)."
return str(data)
# -- error helper ------------------------------------------------------
async def _send_error(self, msg: InboundMessage, error_text: str) -> None:
outbound = OutboundMessage(
channel_name=msg.channel_name,
chat_id=msg.chat_id,
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
text=error_text,
thread_ts=msg.thread_ts,
)
await self.bus.publish_outbound(outbound)

View File

@@ -0,0 +1,173 @@
"""MessageBus — async pub/sub hub that decouples channels from the agent dispatcher."""
from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from enum import StrEnum
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Message types
# ---------------------------------------------------------------------------
class InboundMessageType(StrEnum):
"""Types of messages arriving from IM channels."""
CHAT = "chat"
COMMAND = "command"
@dataclass
class InboundMessage:
"""A message arriving from an IM channel toward the agent dispatcher.
Attributes:
channel_name: Name of the source channel (e.g. "feishu", "slack").
chat_id: Platform-specific chat/conversation identifier.
user_id: Platform-specific user identifier.
text: The message text.
msg_type: Whether this is a regular chat message or a command.
thread_ts: Optional platform thread identifier (for threaded replies).
topic_id: Conversation topic identifier used to map to a DeerFlow thread.
Messages sharing the same ``topic_id`` within a ``chat_id`` will
reuse the same DeerFlow thread. When ``None``, each message
creates a new thread (one-shot Q&A).
files: Optional list of file attachments (platform-specific dicts).
metadata: Arbitrary extra data from the channel.
created_at: Unix timestamp when the message was created.
"""
channel_name: str
chat_id: str
user_id: str
text: str
msg_type: InboundMessageType = InboundMessageType.CHAT
thread_ts: str | None = None
topic_id: str | None = None
files: list[dict[str, Any]] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict)
created_at: float = field(default_factory=time.time)
@dataclass
class ResolvedAttachment:
"""A file attachment resolved to a host filesystem path, ready for upload.
Attributes:
virtual_path: Original virtual path (e.g. /mnt/user-data/outputs/report.pdf).
actual_path: Resolved host filesystem path.
filename: Basename of the file.
mime_type: MIME type (e.g. "application/pdf").
size: File size in bytes.
is_image: True for image/* MIME types (platforms may handle images differently).
"""
virtual_path: str
actual_path: Path
filename: str
mime_type: str
size: int
is_image: bool
@dataclass
class OutboundMessage:
"""A message from the agent dispatcher back to a channel.
Attributes:
channel_name: Target channel name (used for routing).
chat_id: Target chat/conversation identifier.
thread_id: DeerFlow thread ID that produced this response.
text: The response text.
artifacts: List of artifact paths produced by the agent.
is_final: Whether this is the final message in the response stream.
thread_ts: Optional platform thread identifier for threaded replies.
metadata: Arbitrary extra data.
created_at: Unix timestamp.
"""
channel_name: str
chat_id: str
thread_id: str
text: str
artifacts: list[str] = field(default_factory=list)
attachments: list[ResolvedAttachment] = field(default_factory=list)
is_final: bool = True
thread_ts: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
created_at: float = field(default_factory=time.time)
# ---------------------------------------------------------------------------
# MessageBus
# ---------------------------------------------------------------------------
OutboundCallback = Callable[[OutboundMessage], Coroutine[Any, Any, None]]
class MessageBus:
"""Async pub/sub hub connecting channels and the agent dispatcher.
Channels publish inbound messages; the dispatcher consumes them.
The dispatcher publishes outbound messages; channels receive them
via registered callbacks.
"""
def __init__(self) -> None:
self._inbound_queue: asyncio.Queue[InboundMessage] = asyncio.Queue()
self._outbound_listeners: list[OutboundCallback] = []
# -- inbound -----------------------------------------------------------
async def publish_inbound(self, msg: InboundMessage) -> None:
"""Enqueue an inbound message from a channel."""
await self._inbound_queue.put(msg)
logger.info(
"[Bus] inbound enqueued: channel=%s, chat_id=%s, type=%s, queue_size=%d",
msg.channel_name,
msg.chat_id,
msg.msg_type.value,
self._inbound_queue.qsize(),
)
async def get_inbound(self) -> InboundMessage:
"""Block until the next inbound message is available."""
return await self._inbound_queue.get()
@property
def inbound_queue(self) -> asyncio.Queue[InboundMessage]:
return self._inbound_queue
# -- outbound ----------------------------------------------------------
def subscribe_outbound(self, callback: OutboundCallback) -> None:
"""Register an async callback for outbound messages."""
self._outbound_listeners.append(callback)
def unsubscribe_outbound(self, callback: OutboundCallback) -> None:
"""Remove a previously registered outbound callback."""
self._outbound_listeners = [cb for cb in self._outbound_listeners if cb is not callback]
async def publish_outbound(self, msg: OutboundMessage) -> None:
"""Dispatch an outbound message to all registered listeners."""
logger.info(
"[Bus] outbound dispatching: channel=%s, chat_id=%s, listeners=%d, text_len=%d",
msg.channel_name,
msg.chat_id,
len(self._outbound_listeners),
len(msg.text),
)
for callback in self._outbound_listeners:
try:
await callback(msg)
except Exception:
logger.exception("Error in outbound callback for channel=%s", msg.channel_name)

View File

@@ -0,0 +1,200 @@
"""ChannelService — manages the lifecycle of all IM channels."""
from __future__ import annotations
import logging
import os
from typing import Any
from app.channels.base import Channel
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
from app.channels.message_bus import MessageBus
from app.channels.store import ChannelStore
logger = logging.getLogger(__name__)
# Channel name → import path for lazy loading
_CHANNEL_REGISTRY: dict[str, str] = {
"discord": "app.channels.discord:DiscordChannel",
"feishu": "app.channels.feishu:FeishuChannel",
"slack": "app.channels.slack:SlackChannel",
"telegram": "app.channels.telegram:TelegramChannel",
"wechat": "app.channels.wechat:WechatChannel",
"wecom": "app.channels.wecom:WeComChannel",
}
_CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str:
value = config.pop(config_key, None)
if isinstance(value, str) and value.strip():
return value
env_value = os.getenv(env_key, "").strip()
if env_value:
return env_value
return default
class ChannelService:
"""Manages the lifecycle of all configured IM channels.
Reads configuration from ``config.yaml`` under the ``channels`` key,
instantiates enabled channels, and starts the ChannelManager dispatcher.
"""
def __init__(self, channels_config: dict[str, Any] | None = None) -> None:
self.bus = MessageBus()
self.store = ChannelStore()
config = dict(channels_config or {})
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
default_session = config.pop("session", None)
channel_sessions = {name: channel_config.get("session") for name, channel_config in config.items() if isinstance(channel_config, dict)}
self.manager = ChannelManager(
bus=self.bus,
store=self.store,
langgraph_url=langgraph_url,
gateway_url=gateway_url,
default_session=default_session if isinstance(default_session, dict) else None,
channel_sessions=channel_sessions,
)
self._channels: dict[str, Any] = {} # name -> Channel instance
self._config = config
self._running = False
@classmethod
def from_app_config(cls) -> ChannelService:
"""Create a ChannelService from the application config."""
from deerflow.config.app_config import get_app_config
config = get_app_config()
channels_config = {}
# extra fields are allowed by AppConfig (extra="allow")
extra = config.model_extra or {}
if "channels" in extra:
channels_config = extra["channels"]
return cls(channels_config=channels_config)
async def start(self) -> None:
"""Start the manager and all enabled channels."""
if self._running:
return
await self.manager.start()
for name, channel_config in self._config.items():
if not isinstance(channel_config, dict):
continue
if not channel_config.get("enabled", False):
logger.info("Channel %s is disabled, skipping", name)
continue
await self._start_channel(name, channel_config)
self._running = True
logger.info("ChannelService started with channels: %s", list(self._channels.keys()))
async def stop(self) -> None:
"""Stop all channels and the manager."""
for name, channel in list(self._channels.items()):
try:
await channel.stop()
logger.info("Channel %s stopped", name)
except Exception:
logger.exception("Error stopping channel %s", name)
self._channels.clear()
await self.manager.stop()
self._running = False
logger.info("ChannelService stopped")
async def restart_channel(self, name: str) -> bool:
"""Restart a specific channel. Returns True if successful."""
if name in self._channels:
try:
await self._channels[name].stop()
except Exception:
logger.exception("Error stopping channel %s for restart", name)
del self._channels[name]
config = self._config.get(name)
if not config or not isinstance(config, dict):
logger.warning("No config for channel %s", name)
return False
return await self._start_channel(name, config)
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
"""Instantiate and start a single channel."""
import_path = _CHANNEL_REGISTRY.get(name)
if not import_path:
logger.warning("Unknown channel type: %s", name)
return False
try:
from deerflow.reflection import resolve_class
channel_cls = resolve_class(import_path, base_class=None)
except Exception:
logger.exception("Failed to import channel class for %s", name)
return False
try:
channel = channel_cls(bus=self.bus, config=config)
await channel.start()
self._channels[name] = channel
logger.info("Channel %s started", name)
return True
except Exception:
logger.exception("Failed to start channel %s", name)
return False
def get_status(self) -> dict[str, Any]:
"""Return status information for all channels."""
channels_status = {}
for name in _CHANNEL_REGISTRY:
config = self._config.get(name, {})
enabled = isinstance(config, dict) and config.get("enabled", False)
running = name in self._channels and self._channels[name].is_running
channels_status[name] = {
"enabled": enabled,
"running": running,
}
return {
"service_running": self._running,
"channels": channels_status,
}
def get_channel(self, name: str) -> Channel | None:
"""Return a running channel instance by name when available."""
return self._channels.get(name)
# -- singleton access -------------------------------------------------------
_channel_service: ChannelService | None = None
def get_channel_service() -> ChannelService | None:
"""Get the singleton ChannelService instance (if started)."""
return _channel_service
async def start_channel_service() -> ChannelService:
"""Create and start the global ChannelService from app config."""
global _channel_service
if _channel_service is not None:
return _channel_service
_channel_service = ChannelService.from_app_config()
await _channel_service.start()
return _channel_service
async def stop_channel_service() -> None:
"""Stop the global ChannelService."""
global _channel_service
if _channel_service is not None:
await _channel_service.stop()
_channel_service = None

View File

@@ -0,0 +1,246 @@
"""Slack channel — connects via Socket Mode (no public IP needed)."""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from markdown_to_mrkdwn import SlackMarkdownConverter
from app.channels.base import Channel
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
_slack_md_converter = SlackMarkdownConverter()
class SlackChannel(Channel):
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
Configuration keys (in ``config.yaml`` under ``channels.slack``):
- ``bot_token``: Slack Bot User OAuth Token (xoxb-...).
- ``app_token``: Slack App-Level Token (xapp-...) for Socket Mode.
- ``allowed_users``: (optional) List of allowed Slack user IDs. Empty = allow all.
"""
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
super().__init__(name="slack", bus=bus, config=config)
self._socket_client = None
self._web_client = None
self._loop: asyncio.AbstractEventLoop | None = None
self._allowed_users: set[str] = {str(user_id) for user_id in config.get("allowed_users", [])}
async def start(self) -> None:
if self._running:
return
try:
from slack_sdk import WebClient
from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.response import SocketModeResponse
except ImportError:
logger.error("slack-sdk is not installed. Install it with: uv add slack-sdk")
return
self._SocketModeResponse = SocketModeResponse
bot_token = self.config.get("bot_token", "")
app_token = self.config.get("app_token", "")
if not bot_token or not app_token:
logger.error("Slack channel requires bot_token and app_token")
return
self._web_client = WebClient(token=bot_token)
self._socket_client = SocketModeClient(
app_token=app_token,
web_client=self._web_client,
)
self._loop = asyncio.get_event_loop()
self._socket_client.socket_mode_request_listeners.append(self._on_socket_event)
self._running = True
self.bus.subscribe_outbound(self._on_outbound)
# Start socket mode in background thread
asyncio.get_event_loop().run_in_executor(None, self._socket_client.connect)
logger.info("Slack channel started")
async def stop(self) -> None:
self._running = False
self.bus.unsubscribe_outbound(self._on_outbound)
if self._socket_client:
self._socket_client.close()
self._socket_client = None
logger.info("Slack channel stopped")
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if not self._web_client:
return
kwargs: dict[str, Any] = {
"channel": msg.chat_id,
"text": _slack_md_converter.convert(msg.text),
}
if msg.thread_ts:
kwargs["thread_ts"] = msg.thread_ts
last_exc: Exception | None = None
for attempt in range(_max_retries):
try:
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
# Add a completion reaction to the thread root
if msg.thread_ts:
await asyncio.to_thread(
self._add_reaction,
msg.chat_id,
msg.thread_ts,
"white_check_mark",
)
return
except Exception as exc:
last_exc = exc
if attempt < _max_retries - 1:
delay = 2**attempt # 1s, 2s
logger.warning(
"[Slack] send failed (attempt %d/%d), retrying in %ds: %s",
attempt + 1,
_max_retries,
delay,
exc,
)
await asyncio.sleep(delay)
logger.error("[Slack] send failed after %d attempts: %s", _max_retries, last_exc)
# Add failure reaction on error
if msg.thread_ts:
try:
await asyncio.to_thread(
self._add_reaction,
msg.chat_id,
msg.thread_ts,
"x",
)
except Exception:
pass
if last_exc is None:
raise RuntimeError("Slack send failed without an exception from any attempt")
raise last_exc
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if not self._web_client:
return False
try:
kwargs: dict[str, Any] = {
"channel": msg.chat_id,
"file": str(attachment.actual_path),
"filename": attachment.filename,
"title": attachment.filename,
}
if msg.thread_ts:
kwargs["thread_ts"] = msg.thread_ts
await asyncio.to_thread(self._web_client.files_upload_v2, **kwargs)
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
return True
except Exception:
logger.exception("[Slack] failed to upload file: %s", attachment.filename)
return False
# -- internal ----------------------------------------------------------
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
if not self._web_client:
return
try:
self._web_client.reactions_add(
channel=channel_id,
timestamp=timestamp,
name=emoji,
)
except Exception as exc:
if "already_reacted" not in str(exc):
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
"""Send a 'Working on it......' reply in the thread (called from SDK thread)."""
if not self._web_client:
return
try:
self._web_client.chat_postMessage(
channel=channel_id,
text=":hourglass_flowing_sand: Working on it...",
thread_ts=thread_ts,
)
logger.info("[Slack] 'Working on it...' reply sent in channel=%s, thread_ts=%s", channel_id, thread_ts)
except Exception:
logger.exception("[Slack] failed to send running reply in channel=%s", channel_id)
def _on_socket_event(self, client, req) -> None:
"""Called by slack-sdk for each Socket Mode event."""
try:
# Acknowledge the event
response = self._SocketModeResponse(envelope_id=req.envelope_id)
client.send_socket_mode_response(response)
event_type = req.type
if event_type != "events_api":
return
event = req.payload.get("event", {})
etype = event.get("type", "")
# Handle message events (DM or @mention)
if etype in ("message", "app_mention"):
self._handle_message_event(event)
except Exception:
logger.exception("Error processing Slack event")
def _handle_message_event(self, event: dict) -> None:
# Ignore bot messages
if event.get("bot_id") or event.get("subtype"):
return
user_id = event.get("user", "")
# Check allowed users
if self._allowed_users and user_id not in self._allowed_users:
logger.debug("Ignoring message from non-allowed user: %s", user_id)
return
text = event.get("text", "").strip()
if not text:
return
channel_id = event.get("channel", "")
thread_ts = event.get("thread_ts") or event.get("ts", "")
if text.startswith("/"):
msg_type = InboundMessageType.COMMAND
else:
msg_type = InboundMessageType.CHAT
# topic_id: use thread_ts as the topic identifier.
# For threaded messages, thread_ts is the root message ts (shared topic).
# For non-threaded messages, thread_ts is the message's own ts (new topic).
inbound = self._make_inbound(
chat_id=channel_id,
user_id=user_id,
text=text,
msg_type=msg_type,
thread_ts=thread_ts,
)
inbound.topic_id = thread_ts
if self._loop and self._loop.is_running():
# Acknowledge with an eyes reaction
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
# Send "running" reply first (fire-and-forget from SDK thread)
self._send_running_reply(channel_id, thread_ts)
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)

View File

@@ -0,0 +1,153 @@
"""ChannelStore — persists IM chat-to-DeerFlow thread mappings."""
from __future__ import annotations
import json
import logging
import tempfile
import threading
import time
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
class ChannelStore:
"""JSON-file-backed store that maps IM conversations to DeerFlow threads.
Data layout (on disk)::
{
"<channel_name>:<chat_id>": {
"thread_id": "<uuid>",
"user_id": "<platform_user>",
"created_at": 1700000000.0,
"updated_at": 1700000000.0
},
...
}
The store is intentionally simple — a single JSON file that is atomically
rewritten on every mutation. For production workloads with high concurrency,
this can be swapped for a proper database backend.
"""
def __init__(self, path: str | Path | None = None) -> None:
if path is None:
from deerflow.config.paths import get_paths
path = Path(get_paths().base_dir) / "channels" / "store.json"
self._path = Path(path)
self._path.parent.mkdir(parents=True, exist_ok=True)
self._data: dict[str, dict[str, Any]] = self._load()
self._lock = threading.Lock()
# -- persistence -------------------------------------------------------
def _load(self) -> dict[str, dict[str, Any]]:
if self._path.exists():
try:
return json.loads(self._path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
logger.warning("Corrupt channel store at %s, starting fresh", self._path)
return {}
def _save(self) -> None:
fd = tempfile.NamedTemporaryFile(
mode="w",
dir=self._path.parent,
suffix=".tmp",
delete=False,
)
try:
json.dump(self._data, fd, indent=2)
fd.close()
Path(fd.name).replace(self._path)
except BaseException:
fd.close()
Path(fd.name).unlink(missing_ok=True)
raise
# -- key helpers -------------------------------------------------------
@staticmethod
def _key(channel_name: str, chat_id: str, topic_id: str | None = None) -> str:
if topic_id:
return f"{channel_name}:{chat_id}:{topic_id}"
return f"{channel_name}:{chat_id}"
# -- public API --------------------------------------------------------
def get_thread_id(self, channel_name: str, chat_id: str, topic_id: str | None = None) -> str | None:
"""Look up the DeerFlow thread_id for a given IM conversation/topic."""
entry = self._data.get(self._key(channel_name, chat_id, topic_id))
return entry["thread_id"] if entry else None
def set_thread_id(
self,
channel_name: str,
chat_id: str,
thread_id: str,
*,
topic_id: str | None = None,
user_id: str = "",
) -> None:
"""Create or update the mapping for an IM conversation/topic."""
with self._lock:
key = self._key(channel_name, chat_id, topic_id)
now = time.time()
existing = self._data.get(key)
self._data[key] = {
"thread_id": thread_id,
"user_id": user_id,
"created_at": existing["created_at"] if existing else now,
"updated_at": now,
}
self._save()
def remove(self, channel_name: str, chat_id: str, topic_id: str | None = None) -> bool:
"""Remove a mapping.
If ``topic_id`` is provided, only that specific conversation/topic mapping is removed.
If ``topic_id`` is omitted, all mappings whose key starts with
``"<channel_name>:<chat_id>"`` (including topic-specific ones) are removed.
Returns True if at least one mapping was removed.
"""
with self._lock:
# Remove a specific conversation/topic mapping.
if topic_id is not None:
key = self._key(channel_name, chat_id, topic_id)
if key in self._data:
del self._data[key]
self._save()
return True
return False
# Remove all mappings for this channel/chat_id (base and any topic-specific keys).
prefix = self._key(channel_name, chat_id)
keys_to_delete = [k for k in self._data if k == prefix or k.startswith(prefix + ":")]
if not keys_to_delete:
return False
for k in keys_to_delete:
del self._data[k]
self._save()
return True
def list_entries(self, channel_name: str | None = None) -> list[dict[str, Any]]:
"""List all stored mappings, optionally filtered by channel."""
results = []
for key, entry in self._data.items():
parts = key.split(":", 2)
ch = parts[0]
chat = parts[1] if len(parts) > 1 else ""
topic = parts[2] if len(parts) > 2 else None
if channel_name and ch != channel_name:
continue
item: dict[str, Any] = {"channel_name": ch, "chat_id": chat, **entry}
if topic is not None:
item["topic_id"] = topic
results.append(item)
return results

View File

@@ -0,0 +1,317 @@
"""Telegram channel — connects via long-polling (no public IP needed)."""
from __future__ import annotations
import asyncio
import logging
import threading
from typing import Any
from app.channels.base import Channel
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
class TelegramChannel(Channel):
"""Telegram bot channel using long-polling.
Configuration keys (in ``config.yaml`` under ``channels.telegram``):
- ``bot_token``: Telegram Bot API token (from @BotFather).
- ``allowed_users``: (optional) List of allowed Telegram user IDs. Empty = allow all.
"""
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
super().__init__(name="telegram", bus=bus, config=config)
self._application = None
self._thread: threading.Thread | None = None
self._tg_loop: asyncio.AbstractEventLoop | None = None
self._main_loop: asyncio.AbstractEventLoop | None = None
self._allowed_users: set[int] = set()
for uid in config.get("allowed_users", []):
try:
self._allowed_users.add(int(uid))
except (ValueError, TypeError):
pass
# chat_id -> last sent message_id for threaded replies
self._last_bot_message: dict[str, int] = {}
async def start(self) -> None:
if self._running:
return
try:
from telegram.ext import ApplicationBuilder, CommandHandler, MessageHandler, filters
except ImportError:
logger.error("python-telegram-bot is not installed. Install it with: uv add python-telegram-bot")
return
bot_token = self.config.get("bot_token", "")
if not bot_token:
logger.error("Telegram channel requires bot_token")
return
self._main_loop = asyncio.get_event_loop()
self._running = True
self.bus.subscribe_outbound(self._on_outbound)
# Build the application
app = ApplicationBuilder().token(bot_token).build()
# Command handlers
app.add_handler(CommandHandler("start", self._cmd_start))
app.add_handler(CommandHandler("new", self._cmd_generic))
app.add_handler(CommandHandler("status", self._cmd_generic))
app.add_handler(CommandHandler("models", self._cmd_generic))
app.add_handler(CommandHandler("memory", self._cmd_generic))
app.add_handler(CommandHandler("help", self._cmd_generic))
# General message handler
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self._on_text))
self._application = app
# Run polling in a dedicated thread with its own event loop
self._thread = threading.Thread(target=self._run_polling, daemon=True)
self._thread.start()
logger.info("Telegram channel started")
async def stop(self) -> None:
self._running = False
self.bus.unsubscribe_outbound(self._on_outbound)
if self._tg_loop and self._tg_loop.is_running():
self._tg_loop.call_soon_threadsafe(self._tg_loop.stop)
if self._thread:
self._thread.join(timeout=10)
self._thread = None
self._application = None
logger.info("Telegram channel stopped")
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if not self._application:
return
try:
chat_id = int(msg.chat_id)
except (ValueError, TypeError):
logger.error("Invalid Telegram chat_id: %s", msg.chat_id)
return
kwargs: dict[str, Any] = {"chat_id": chat_id, "text": msg.text}
# Reply to the last bot message in this chat for threading
reply_to = self._last_bot_message.get(msg.chat_id)
if reply_to:
kwargs["reply_to_message_id"] = reply_to
bot = self._application.bot
last_exc: Exception | None = None
for attempt in range(_max_retries):
try:
sent = await bot.send_message(**kwargs)
self._last_bot_message[msg.chat_id] = sent.message_id
return
except Exception as exc:
last_exc = exc
if attempt < _max_retries - 1:
delay = 2**attempt # 1s, 2s
logger.warning(
"[Telegram] send failed (attempt %d/%d), retrying in %ds: %s",
attempt + 1,
_max_retries,
delay,
exc,
)
await asyncio.sleep(delay)
logger.error("[Telegram] send failed after %d attempts: %s", _max_retries, last_exc)
if last_exc is None:
raise RuntimeError("Telegram send failed without an exception from any attempt")
raise last_exc
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if not self._application:
return False
try:
chat_id = int(msg.chat_id)
except (ValueError, TypeError):
logger.error("[Telegram] Invalid chat_id: %s", msg.chat_id)
return False
# Telegram limits: 10MB for photos, 50MB for documents
if attachment.size > 50 * 1024 * 1024:
logger.warning("[Telegram] file too large (%d bytes), skipping: %s", attachment.size, attachment.filename)
return False
bot = self._application.bot
reply_to = self._last_bot_message.get(msg.chat_id)
try:
if attachment.is_image and attachment.size <= 10 * 1024 * 1024:
with open(attachment.actual_path, "rb") as f:
kwargs: dict[str, Any] = {"chat_id": chat_id, "photo": f}
if reply_to:
kwargs["reply_to_message_id"] = reply_to
sent = await bot.send_photo(**kwargs)
else:
from telegram import InputFile
with open(attachment.actual_path, "rb") as f:
input_file = InputFile(f, filename=attachment.filename)
kwargs = {"chat_id": chat_id, "document": input_file}
if reply_to:
kwargs["reply_to_message_id"] = reply_to
sent = await bot.send_document(**kwargs)
self._last_bot_message[msg.chat_id] = sent.message_id
logger.info("[Telegram] file sent: %s to chat=%s", attachment.filename, msg.chat_id)
return True
except Exception:
logger.exception("[Telegram] failed to send file: %s", attachment.filename)
return False
# -- helpers -----------------------------------------------------------
async def _send_running_reply(self, chat_id: str, reply_to_message_id: int) -> None:
"""Send a 'Working on it...' reply to the user's message."""
if not self._application:
return
try:
bot = self._application.bot
await bot.send_message(
chat_id=int(chat_id),
text="Working on it...",
reply_to_message_id=reply_to_message_id,
)
logger.info("[Telegram] 'Working on it...' reply sent in chat=%s", chat_id)
except Exception:
logger.exception("[Telegram] failed to send running reply in chat=%s", chat_id)
# -- internal ----------------------------------------------------------
@staticmethod
def _log_future_error(fut, name: str, msg_id: str):
try:
exc = fut.exception()
if exc:
logger.error("[Telegram] %s failed for msg_id=%s: %s", name, msg_id, exc)
except Exception:
logger.exception("[Telegram] Failed to inspect future for %s (msg_id=%s)", name, msg_id)
def _run_polling(self) -> None:
"""Run telegram polling in a dedicated thread."""
self._tg_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._tg_loop)
try:
# Cannot use run_polling() because it calls add_signal_handler(),
# which only works in the main thread. Instead, manually
# initialize the application and start the updater.
self._tg_loop.run_until_complete(self._application.initialize())
self._tg_loop.run_until_complete(self._application.start())
self._tg_loop.run_until_complete(self._application.updater.start_polling())
self._tg_loop.run_forever()
except Exception:
if self._running:
logger.exception("Telegram polling error")
finally:
# Graceful shutdown
try:
if self._application.updater.running:
self._tg_loop.run_until_complete(self._application.updater.stop())
self._tg_loop.run_until_complete(self._application.stop())
self._tg_loop.run_until_complete(self._application.shutdown())
except Exception:
logger.exception("Error during Telegram shutdown")
def _check_user(self, user_id: int) -> bool:
if not self._allowed_users:
return True
return user_id in self._allowed_users
async def _cmd_start(self, update, context) -> None:
"""Handle /start command."""
if not self._check_user(update.effective_user.id):
return
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
await self._send_running_reply(chat_id, msg_id)
await self.bus.publish_inbound(inbound)
async def _cmd_generic(self, update, context) -> None:
"""Forward slash commands to the channel manager."""
if not self._check_user(update.effective_user.id):
return
text = update.message.text
chat_id = str(update.effective_chat.id)
user_id = str(update.effective_user.id)
msg_id = str(update.message.message_id)
# Use the same topic_id logic as _on_text so that commands
# like /new target the correct thread mapping.
if update.effective_chat.type == "private":
topic_id = None
else:
reply_to = update.message.reply_to_message
if reply_to:
topic_id = str(reply_to.message_id)
else:
topic_id = msg_id
inbound = self._make_inbound(
chat_id=chat_id,
user_id=user_id,
text=text,
msg_type=InboundMessageType.COMMAND,
thread_ts=msg_id,
)
inbound.topic_id = topic_id
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
fut.add_done_callback(lambda f: self._log_future_error(f, "process_incoming_with_reply", update.message.message_id))
else:
logger.warning("[Telegram] Main loop not running. Cannot publish inbound message.")
async def _on_text(self, update, context) -> None:
"""Handle regular text messages."""
if not self._check_user(update.effective_user.id):
return
text = update.message.text.strip()
if not text:
return
chat_id = str(update.effective_chat.id)
user_id = str(update.effective_user.id)
msg_id = str(update.message.message_id)
# topic_id determines which DeerFlow thread the message maps to.
# In private chats, use None so that all messages share a single
# thread (the store key becomes "channel:chat_id").
# In group chats, use the reply-to message id or the current
# message id to keep separate conversation threads.
if update.effective_chat.type == "private":
topic_id = None
else:
reply_to = update.message.reply_to_message
if reply_to:
topic_id = str(reply_to.message_id)
else:
topic_id = msg_id
inbound = self._make_inbound(
chat_id=chat_id,
user_id=user_id,
text=text,
msg_type=InboundMessageType.CHAT,
thread_ts=msg_id,
)
inbound.topic_id = topic_id
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
fut.add_done_callback(lambda f: self._log_future_error(f, "process_incoming_with_reply", update.message.message_id))
else:
logger.warning("[Telegram] Main loop not running. Cannot publish inbound message.")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,394 @@
from __future__ import annotations
import asyncio
import base64
import hashlib
import logging
from collections.abc import Awaitable, Callable
from typing import Any, cast
from app.channels.base import Channel
from app.channels.message_bus import (
InboundMessageType,
MessageBus,
OutboundMessage,
ResolvedAttachment,
)
logger = logging.getLogger(__name__)
class WeComChannel(Channel):
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
super().__init__(name="wecom", bus=bus, config=config)
self._bot_id: str | None = None
self._bot_secret: str | None = None
self._ws_client = None
self._ws_task: asyncio.Task | None = None
self._ws_frames: dict[str, dict[str, Any]] = {}
self._ws_stream_ids: dict[str, str] = {}
self._working_message = "Working on it..."
def _clear_ws_context(self, thread_ts: str | None) -> None:
if not thread_ts:
return
self._ws_frames.pop(thread_ts, None)
self._ws_stream_ids.pop(thread_ts, None)
async def _send_ws_upload_command(self, req_id: str, body: dict[str, Any], cmd: str) -> dict[str, Any]:
if not self._ws_client:
raise RuntimeError("WeCom WebSocket client is not available")
ws_manager = getattr(self._ws_client, "_ws_manager", None)
send_reply = getattr(ws_manager, "send_reply", None)
if not callable(send_reply):
raise RuntimeError("Installed wecom-aibot-python-sdk does not expose the WebSocket media upload API expected by DeerFlow. Use wecom-aibot-python-sdk==0.1.6 or update the adapter.")
send_reply_async = cast(Callable[[str, dict[str, Any], str], Awaitable[dict[str, Any]]], send_reply)
return await send_reply_async(req_id, body, cmd)
async def start(self) -> None:
if self._running:
return
bot_id = self.config.get("bot_id")
bot_secret = self.config.get("bot_secret")
working_message = self.config.get("working_message")
self._bot_id = bot_id if isinstance(bot_id, str) and bot_id else None
self._bot_secret = bot_secret if isinstance(bot_secret, str) and bot_secret else None
self._working_message = working_message if isinstance(working_message, str) and working_message else "Working on it..."
if not self._bot_id or not self._bot_secret:
logger.error("WeCom channel requires bot_id and bot_secret")
return
try:
from aibot import WSClient, WSClientOptions
except ImportError:
logger.error("wecom-aibot-python-sdk is not installed. Install it with: uv add wecom-aibot-python-sdk")
return
else:
self._ws_client = WSClient(WSClientOptions(bot_id=self._bot_id, secret=self._bot_secret, logger=logger))
self._ws_client.on("message.text", self._on_ws_text)
self._ws_client.on("message.mixed", self._on_ws_mixed)
self._ws_client.on("message.image", self._on_ws_image)
self._ws_client.on("message.file", self._on_ws_file)
self._ws_task = asyncio.create_task(self._ws_client.connect())
self._running = True
self.bus.subscribe_outbound(self._on_outbound)
logger.info("WeCom channel started")
async def stop(self) -> None:
self._running = False
self.bus.unsubscribe_outbound(self._on_outbound)
if self._ws_task:
try:
self._ws_task.cancel()
except Exception:
pass
self._ws_task = None
if self._ws_client:
try:
self._ws_client.disconnect()
except Exception:
pass
self._ws_client = None
self._ws_frames.clear()
self._ws_stream_ids.clear()
logger.info("WeCom channel stopped")
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if self._ws_client:
await self._send_ws(msg, _max_retries=_max_retries)
return
logger.warning("[WeCom] send called but WebSocket client is not available")
async def _on_outbound(self, msg: OutboundMessage) -> None:
if msg.channel_name != self.name:
return
try:
await self.send(msg)
except Exception:
logger.exception("Failed to send outbound message on channel %s", self.name)
if msg.is_final:
self._clear_ws_context(msg.thread_ts)
return
for attachment in msg.attachments:
try:
success = await self.send_file(msg, attachment)
if not success:
logger.warning("[%s] file upload skipped for %s", self.name, attachment.filename)
except Exception:
logger.exception("[%s] failed to upload file %s", self.name, attachment.filename)
if msg.is_final:
self._clear_ws_context(msg.thread_ts)
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if not msg.is_final:
return True
if not self._ws_client:
return False
if not msg.thread_ts:
return False
frame = self._ws_frames.get(msg.thread_ts)
if not frame:
return False
media_type = "image" if attachment.is_image else "file"
size_limit = 2 * 1024 * 1024 if attachment.is_image else 20 * 1024 * 1024
if attachment.size > size_limit:
logger.warning(
"[WeCom] %s too large (%d bytes), skipping: %s",
media_type,
attachment.size,
attachment.filename,
)
return False
try:
media_id = await self._upload_media_ws(
media_type=media_type,
filename=attachment.filename,
path=str(attachment.actual_path),
size=attachment.size,
)
if not media_id:
return False
body = {media_type: {"media_id": media_id}, "msgtype": media_type}
await self._ws_client.reply(frame, body)
logger.debug("[WeCom] %s sent via ws: %s", media_type, attachment.filename)
return True
except Exception:
logger.exception("[WeCom] failed to upload/send file via ws: %s", attachment.filename)
return False
async def _on_ws_text(self, frame: dict[str, Any]) -> None:
body = frame.get("body", {}) or {}
text = ((body.get("text") or {}).get("content") or "").strip()
quote = body.get("quote", {}).get("text", {}).get("content", "").strip()
if not text and not quote:
return
await self._publish_ws_inbound(frame, text + (f"\nQuote message: {quote}" if quote else ""))
async def _on_ws_mixed(self, frame: dict[str, Any]) -> None:
body = frame.get("body", {}) or {}
mixed = body.get("mixed") or {}
items = mixed.get("msg_item") or []
parts: list[str] = []
files: list[dict[str, Any]] = []
for item in items:
item_type = (item or {}).get("msgtype")
if item_type == "text":
content = (((item or {}).get("text") or {}).get("content") or "").strip()
if content:
parts.append(content)
elif item_type in ("image", "file"):
payload = (item or {}).get(item_type) or {}
url = payload.get("url")
aeskey = payload.get("aeskey")
if isinstance(url, str) and url:
files.append(
{
"type": item_type,
"url": url,
"aeskey": (aeskey if isinstance(aeskey, str) and aeskey else None),
}
)
text = "\n\n".join(parts).strip()
if not text and not files:
return
if not text:
text = "receive image/file"
await self._publish_ws_inbound(frame, text, files=files)
async def _on_ws_image(self, frame: dict[str, Any]) -> None:
body = frame.get("body", {}) or {}
image = body.get("image") or {}
url = image.get("url")
aeskey = image.get("aeskey")
if not isinstance(url, str) or not url:
return
await self._publish_ws_inbound(
frame,
"receive image ",
files=[
{
"type": "image",
"url": url,
"aeskey": aeskey if isinstance(aeskey, str) and aeskey else None,
}
],
)
async def _on_ws_file(self, frame: dict[str, Any]) -> None:
body = frame.get("body", {}) or {}
file_obj = body.get("file") or {}
url = file_obj.get("url")
aeskey = file_obj.get("aeskey")
if not isinstance(url, str) or not url:
return
await self._publish_ws_inbound(
frame,
"receive file",
files=[
{
"type": "file",
"url": url,
"aeskey": aeskey if isinstance(aeskey, str) and aeskey else None,
}
],
)
async def _publish_ws_inbound(
self,
frame: dict[str, Any],
text: str,
*,
files: list[dict[str, Any]] | None = None,
) -> None:
if not self._ws_client:
return
try:
from aibot import generate_req_id
except Exception:
return
body = frame.get("body", {}) or {}
msg_id = body.get("msgid")
if not msg_id:
return
user_id = (body.get("from") or {}).get("userid")
inbound_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=user_id, # keep user's conversation in memory
user_id=user_id,
text=text,
msg_type=inbound_type,
thread_ts=msg_id,
files=files or [],
metadata={"aibotid": body.get("aibotid"), "chattype": body.get("chattype")},
)
inbound.topic_id = user_id # keep the same thread
stream_id = generate_req_id("stream")
self._ws_frames[msg_id] = frame
self._ws_stream_ids[msg_id] = stream_id
try:
await self._ws_client.reply_stream(frame, stream_id, self._working_message, False)
except Exception:
pass
await self.bus.publish_inbound(inbound)
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if not self._ws_client:
return
try:
from aibot import generate_req_id
except Exception:
generate_req_id = None
if msg.thread_ts and msg.thread_ts in self._ws_frames:
frame = self._ws_frames[msg.thread_ts]
stream_id = self._ws_stream_ids.get(msg.thread_ts)
if not stream_id and generate_req_id:
stream_id = generate_req_id("stream")
self._ws_stream_ids[msg.thread_ts] = stream_id
if not stream_id:
return
last_exc: Exception | None = None
for attempt in range(_max_retries):
try:
await self._ws_client.reply_stream(frame, stream_id, msg.text, bool(msg.is_final))
return
except Exception as exc:
last_exc = exc
if attempt < _max_retries - 1:
await asyncio.sleep(2**attempt)
if last_exc:
raise last_exc
body = {"msgtype": "markdown", "markdown": {"content": msg.text}}
last_exc = None
for attempt in range(_max_retries):
try:
await self._ws_client.send_message(msg.chat_id, body)
return
except Exception as exc:
last_exc = exc
if attempt < _max_retries - 1:
await asyncio.sleep(2**attempt)
if last_exc:
raise last_exc
async def _upload_media_ws(
self,
*,
media_type: str,
filename: str,
path: str,
size: int,
) -> str | None:
if not self._ws_client:
return None
try:
from aibot import generate_req_id
except Exception:
return None
chunk_size = 512 * 1024
total_chunks = (size + chunk_size - 1) // chunk_size
if total_chunks < 1 or total_chunks > 100:
logger.warning("[WeCom] invalid total_chunks=%d for %s", total_chunks, filename)
return None
md5_hasher = hashlib.md5()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(1024 * 1024), b""):
md5_hasher.update(chunk)
md5 = md5_hasher.hexdigest()
init_req_id = generate_req_id("aibot_upload_media_init")
init_body = {
"type": media_type,
"filename": filename,
"total_size": int(size),
"total_chunks": int(total_chunks),
"md5": md5,
}
init_ack = await self._send_ws_upload_command(init_req_id, init_body, "aibot_upload_media_init")
upload_id = (init_ack.get("body") or {}).get("upload_id")
if not upload_id:
logger.warning("[WeCom] upload init returned no upload_id: %s", init_ack)
return None
with open(path, "rb") as f:
for idx in range(total_chunks):
data = f.read(chunk_size)
if not data:
break
chunk_req_id = generate_req_id("aibot_upload_media_chunk")
chunk_body = {
"upload_id": upload_id,
"chunk_index": int(idx),
"base64_data": base64.b64encode(data).decode("utf-8"),
}
await self._send_ws_upload_command(chunk_req_id, chunk_body, "aibot_upload_media_chunk")
finish_req_id = generate_req_id("aibot_upload_media_finish")
finish_ack = await self._send_ws_upload_command(finish_req_id, {"upload_id": upload_id}, "aibot_upload_media_finish")
media_id = (finish_ack.get("body") or {}).get("media_id")
if not media_id:
logger.warning("[WeCom] upload finish returned no media_id: %s", finish_ack)
return None
return media_id

View File

@@ -0,0 +1,4 @@
from .app import app, create_app
from .config import GatewayConfig, get_gateway_config
__all__ = ["app", "create_app", "GatewayConfig", "get_gateway_config"]

View File

@@ -0,0 +1,221 @@
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.gateway.config import get_gateway_config
from app.gateway.deps import langgraph_runtime
from app.gateway.routers import (
agents,
artifacts,
assistants_compat,
channels,
mcp,
memory,
models,
runs,
skills,
suggestions,
thread_runs,
threads,
uploads,
)
from deerflow.config.app_config import get_app_config
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan handler."""
# Load config and check necessary environment variables at startup
try:
get_app_config()
logger.info("Configuration loaded successfully")
except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}"
logger.exception(error_msg)
raise RuntimeError(error_msg) from e
config = get_gateway_config()
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
async with langgraph_runtime(app):
logger.info("LangGraph runtime initialised")
# Start IM channel service if any channels are configured
try:
from app.channels.service import start_channel_service
channel_service = await start_channel_service()
logger.info("Channel service started: %s", channel_service.get_status())
except Exception:
logger.exception("No IM channels configured or channel service failed to start")
yield
# Stop channel service on shutdown
try:
from app.channels.service import stop_channel_service
await stop_channel_service()
except Exception:
logger.exception("Failed to stop channel service")
logger.info("Shutting down API Gateway")
def create_app() -> FastAPI:
"""Create and configure the FastAPI application.
Returns:
Configured FastAPI application instance.
"""
app = FastAPI(
title="DeerFlow API Gateway",
description="""
## DeerFlow API Gateway
API Gateway for DeerFlow - A LangGraph-based AI agent backend with sandbox execution capabilities.
### Features
- **Models Management**: Query and retrieve available AI models
- **MCP Configuration**: Manage Model Context Protocol (MCP) server configurations
- **Memory Management**: Access and manage global memory data for personalized conversations
- **Skills Management**: Query and manage skills and their enabled status
- **Artifacts**: Access thread artifacts and generated files
- **Health Monitoring**: System health check endpoints
### Architecture
LangGraph requests are handled by nginx reverse proxy.
This gateway provides custom endpoints for models, MCP configuration, skills, and artifacts.
""",
version="0.1.0",
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
openapi_tags=[
{
"name": "models",
"description": "Operations for querying available AI models and their configurations",
},
{
"name": "mcp",
"description": "Manage Model Context Protocol (MCP) server configurations",
},
{
"name": "memory",
"description": "Access and manage global memory data for personalized conversations",
},
{
"name": "skills",
"description": "Manage skills and their configurations",
},
{
"name": "artifacts",
"description": "Access and download thread artifacts and generated files",
},
{
"name": "uploads",
"description": "Upload and manage user files for threads",
},
{
"name": "threads",
"description": "Manage DeerFlow thread-local filesystem data",
},
{
"name": "agents",
"description": "Create and manage custom agents with per-agent config and prompts",
},
{
"name": "suggestions",
"description": "Generate follow-up question suggestions for conversations",
},
{
"name": "channels",
"description": "Manage IM channel integrations (Feishu, Slack, Telegram)",
},
{
"name": "assistants-compat",
"description": "LangGraph Platform-compatible assistants API (stub)",
},
{
"name": "runs",
"description": "LangGraph Platform-compatible runs lifecycle (create, stream, cancel)",
},
{
"name": "health",
"description": "Health check and system status endpoints",
},
],
)
# CORS is handled by nginx - no need for FastAPI middleware
# Include routers
# Models API is mounted at /api/models
app.include_router(models.router)
# MCP API is mounted at /api/mcp
app.include_router(mcp.router)
# Memory API is mounted at /api/memory
app.include_router(memory.router)
# Skills API is mounted at /api/skills
app.include_router(skills.router)
# Artifacts API is mounted at /api/threads/{thread_id}/artifacts
app.include_router(artifacts.router)
# Uploads API is mounted at /api/threads/{thread_id}/uploads
app.include_router(uploads.router)
# Thread cleanup API is mounted at /api/threads/{thread_id}
app.include_router(threads.router)
# Agents API is mounted at /api/agents
app.include_router(agents.router)
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
app.include_router(suggestions.router)
# Channels API is mounted at /api/channels
app.include_router(channels.router)
# Assistants compatibility API (LangGraph Platform stub)
app.include_router(assistants_compat.router)
# Thread Runs API (LangGraph Platform-compatible runs lifecycle)
app.include_router(thread_runs.router)
# Stateless Runs API (stream/wait without a pre-existing thread)
app.include_router(runs.router)
@app.get("/health", tags=["health"])
async def health_check() -> dict:
"""Health check endpoint.
Returns:
Service health status information.
"""
return {"status": "healthy", "service": "deer-flow-gateway"}
return app
# Create app instance for uvicorn
app = create_app()

View File

@@ -0,0 +1,27 @@
import os
from pydantic import BaseModel, Field
class GatewayConfig(BaseModel):
"""Configuration for the API Gateway."""
host: str = Field(default="0.0.0.0", description="Host to bind the gateway server")
port: int = Field(default=8001, description="Port to bind the gateway server")
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"], description="Allowed CORS origins")
_gateway_config: GatewayConfig | None = None
def get_gateway_config() -> GatewayConfig:
"""Get gateway config, loading from environment if available."""
global _gateway_config
if _gateway_config is None:
cors_origins_str = os.getenv("CORS_ORIGINS", "http://localhost:3000")
_gateway_config = GatewayConfig(
host=os.getenv("GATEWAY_HOST", "0.0.0.0"),
port=int(os.getenv("GATEWAY_PORT", "8001")),
cors_origins=cors_origins_str.split(","),
)
return _gateway_config

View File

@@ -0,0 +1,70 @@
"""Centralized accessors for singleton objects stored on ``app.state``.
**Getters** (used by routers): raise 503 when a required dependency is
missing, except ``get_store`` which returns ``None``.
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import AsyncExitStack, asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from deerflow.runtime import RunManager, StreamBridge
@asynccontextmanager
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
"""Bootstrap and tear down all LangGraph runtime singletons.
Usage in ``app.py``::
async with langgraph_runtime(app):
yield
"""
from deerflow.agents.checkpointer.async_provider import make_checkpointer
from deerflow.runtime import make_store, make_stream_bridge
async with AsyncExitStack() as stack:
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
app.state.store = await stack.enter_async_context(make_store())
app.state.run_manager = RunManager()
yield
# ---------------------------------------------------------------------------
# Getters called by routers per-request
# ---------------------------------------------------------------------------
def get_stream_bridge(request: Request) -> StreamBridge:
"""Return the global :class:`StreamBridge`, or 503."""
bridge = getattr(request.app.state, "stream_bridge", None)
if bridge is None:
raise HTTPException(status_code=503, detail="Stream bridge not available")
return bridge
def get_run_manager(request: Request) -> RunManager:
"""Return the global :class:`RunManager`, or 503."""
mgr = getattr(request.app.state, "run_manager", None)
if mgr is None:
raise HTTPException(status_code=503, detail="Run manager not available")
return mgr
def get_checkpointer(request: Request):
"""Return the global checkpointer, or 503."""
cp = getattr(request.app.state, "checkpointer", None)
if cp is None:
raise HTTPException(status_code=503, detail="Checkpointer not available")
return cp
def get_store(request: Request):
"""Return the global store (may be ``None`` if not configured)."""
return getattr(request.app.state, "store", None)

View File

@@ -0,0 +1,28 @@
"""Shared path resolution for thread virtual paths (e.g. mnt/user-data/outputs/...)."""
from pathlib import Path
from fastapi import HTTPException
from deerflow.config.paths import get_paths
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
"""Resolve a virtual path to the actual filesystem path under thread user-data.
Args:
thread_id: The thread ID.
virtual_path: The virtual path as seen inside the sandbox
(e.g., /mnt/user-data/outputs/file.txt).
Returns:
The resolved filesystem path.
Raises:
HTTPException: If the path is invalid or outside allowed directories.
"""
try:
return get_paths().resolve_virtual_path(thread_id, virtual_path)
except ValueError as e:
status = 403 if "traversal" in str(e) else 400
raise HTTPException(status_code=status, detail=str(e))

View File

@@ -0,0 +1,3 @@
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads
__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]

View File

@@ -0,0 +1,383 @@
"""CRUD API for custom agents."""
import logging
import re
import shutil
import yaml
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["agents"])
AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
class AgentResponse(BaseModel):
"""Response model for a custom agent."""
name: str = Field(..., description="Agent name (hyphen-case)")
description: str = Field(default="", description="Agent description")
model: str | None = Field(default=None, description="Optional model override")
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
soul: str | None = Field(default=None, description="SOUL.md content")
class AgentsListResponse(BaseModel):
"""Response model for listing all custom agents."""
agents: list[AgentResponse]
class AgentCreateRequest(BaseModel):
"""Request body for creating a custom agent."""
name: str = Field(..., description="Agent name (must match ^[A-Za-z0-9-]+$, stored as lowercase)")
description: str = Field(default="", description="Agent description")
model: str | None = Field(default=None, description="Optional model override")
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
soul: str = Field(default="", description="SOUL.md content — agent personality and behavioral guardrails")
class AgentUpdateRequest(BaseModel):
"""Request body for updating a custom agent."""
description: str | None = Field(default=None, description="Updated description")
model: str | None = Field(default=None, description="Updated model override")
tool_groups: list[str] | None = Field(default=None, description="Updated tool group whitelist")
soul: str | None = Field(default=None, description="Updated SOUL.md content")
def _validate_agent_name(name: str) -> None:
"""Validate agent name against allowed pattern.
Args:
name: The agent name to validate.
Raises:
HTTPException: 422 if the name is invalid.
"""
if not AGENT_NAME_PATTERN.match(name):
raise HTTPException(
status_code=422,
detail=f"Invalid agent name '{name}'. Must match ^[A-Za-z0-9-]+$ (letters, digits, and hyphens only).",
)
def _normalize_agent_name(name: str) -> str:
"""Normalize agent name to lowercase for filesystem storage."""
return name.lower()
def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False) -> AgentResponse:
"""Convert AgentConfig to AgentResponse."""
soul: str | None = None
if include_soul:
soul = load_agent_soul(agent_cfg.name) or ""
return AgentResponse(
name=agent_cfg.name,
description=agent_cfg.description,
model=agent_cfg.model,
tool_groups=agent_cfg.tool_groups,
soul=soul,
)
@router.get(
"/agents",
response_model=AgentsListResponse,
summary="List Custom Agents",
description="List all custom agents available in the agents directory, including their soul content.",
)
async def list_agents() -> AgentsListResponse:
"""List all custom agents.
Returns:
List of all custom agents with their metadata and soul content.
"""
try:
agents = list_custom_agents()
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
except Exception as e:
logger.error(f"Failed to list agents: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
@router.get(
"/agents/check",
summary="Check Agent Name",
description="Validate an agent name and check if it is available (case-insensitive).",
)
async def check_agent_name(name: str) -> dict:
"""Check whether an agent name is valid and not yet taken.
Args:
name: The agent name to check.
Returns:
``{"available": true/false, "name": "<normalized>"}``
Raises:
HTTPException: 422 if the name is invalid.
"""
_validate_agent_name(name)
normalized = _normalize_agent_name(name)
available = not get_paths().agent_dir(normalized).exists()
return {"available": available, "name": normalized}
@router.get(
"/agents/{name}",
response_model=AgentResponse,
summary="Get Custom Agent",
description="Retrieve details and SOUL.md content for a specific custom agent.",
)
async def get_agent(name: str) -> AgentResponse:
"""Get a specific custom agent by name.
Args:
name: The agent name.
Returns:
Agent details including SOUL.md content.
Raises:
HTTPException: 404 if agent not found.
"""
_validate_agent_name(name)
name = _normalize_agent_name(name)
try:
agent_cfg = load_agent_config(name)
return _agent_config_to_response(agent_cfg, include_soul=True)
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
except Exception as e:
logger.error(f"Failed to get agent '{name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to get agent: {str(e)}")
@router.post(
"/agents",
response_model=AgentResponse,
status_code=201,
summary="Create Custom Agent",
description="Create a new custom agent with its config and SOUL.md.",
)
async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
"""Create a new custom agent.
Args:
request: The agent creation request.
Returns:
The created agent details.
Raises:
HTTPException: 409 if agent already exists, 422 if name is invalid.
"""
_validate_agent_name(request.name)
normalized_name = _normalize_agent_name(request.name)
agent_dir = get_paths().agent_dir(normalized_name)
if agent_dir.exists():
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
try:
agent_dir.mkdir(parents=True, exist_ok=True)
# Write config.yaml
config_data: dict = {"name": normalized_name}
if request.description:
config_data["description"] = request.description
if request.model is not None:
config_data["model"] = request.model
if request.tool_groups is not None:
config_data["tool_groups"] = request.tool_groups
config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f:
yaml.dump(config_data, f, default_flow_style=False, allow_unicode=True)
# Write SOUL.md
soul_file = agent_dir / "SOUL.md"
soul_file.write_text(request.soul, encoding="utf-8")
logger.info(f"Created agent '{normalized_name}' at {agent_dir}")
agent_cfg = load_agent_config(normalized_name)
return _agent_config_to_response(agent_cfg, include_soul=True)
except HTTPException:
raise
except Exception as e:
# Clean up on failure
if agent_dir.exists():
shutil.rmtree(agent_dir)
logger.error(f"Failed to create agent '{request.name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to create agent: {str(e)}")
@router.put(
"/agents/{name}",
response_model=AgentResponse,
summary="Update Custom Agent",
description="Update an existing custom agent's config and/or SOUL.md.",
)
async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
"""Update an existing custom agent.
Args:
name: The agent name.
request: The update request (all fields optional).
Returns:
The updated agent details.
Raises:
HTTPException: 404 if agent not found.
"""
_validate_agent_name(name)
name = _normalize_agent_name(name)
try:
agent_cfg = load_agent_config(name)
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
agent_dir = get_paths().agent_dir(name)
try:
# Update config if any config fields changed
config_changed = any(v is not None for v in [request.description, request.model, request.tool_groups])
if config_changed:
updated: dict = {
"name": agent_cfg.name,
"description": request.description if request.description is not None else agent_cfg.description,
}
new_model = request.model if request.model is not None else agent_cfg.model
if new_model is not None:
updated["model"] = new_model
new_tool_groups = request.tool_groups if request.tool_groups is not None else agent_cfg.tool_groups
if new_tool_groups is not None:
updated["tool_groups"] = new_tool_groups
config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f:
yaml.dump(updated, f, default_flow_style=False, allow_unicode=True)
# Update SOUL.md if provided
if request.soul is not None:
soul_path = agent_dir / "SOUL.md"
soul_path.write_text(request.soul, encoding="utf-8")
logger.info(f"Updated agent '{name}'")
refreshed_cfg = load_agent_config(name)
return _agent_config_to_response(refreshed_cfg, include_soul=True)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to update agent '{name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to update agent: {str(e)}")
class UserProfileResponse(BaseModel):
"""Response model for the global user profile (USER.md)."""
content: str | None = Field(default=None, description="USER.md content, or null if not yet created")
class UserProfileUpdateRequest(BaseModel):
"""Request body for setting the global user profile."""
content: str = Field(default="", description="USER.md content — describes the user's background and preferences")
@router.get(
"/user-profile",
response_model=UserProfileResponse,
summary="Get User Profile",
description="Read the global USER.md file that is injected into all custom agents.",
)
async def get_user_profile() -> UserProfileResponse:
"""Return the current USER.md content.
Returns:
UserProfileResponse with content=None if USER.md does not exist yet.
"""
try:
user_md_path = get_paths().user_md_file
if not user_md_path.exists():
return UserProfileResponse(content=None)
raw = user_md_path.read_text(encoding="utf-8").strip()
return UserProfileResponse(content=raw or None)
except Exception as e:
logger.error(f"Failed to read user profile: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to read user profile: {str(e)}")
@router.put(
"/user-profile",
response_model=UserProfileResponse,
summary="Update User Profile",
description="Write the global USER.md file that is injected into all custom agents.",
)
async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileResponse:
"""Create or overwrite the global USER.md.
Args:
request: The update request with the new USER.md content.
Returns:
UserProfileResponse with the saved content.
"""
try:
paths = get_paths()
paths.base_dir.mkdir(parents=True, exist_ok=True)
paths.user_md_file.write_text(request.content, encoding="utf-8")
logger.info(f"Updated USER.md at {paths.user_md_file}")
return UserProfileResponse(content=request.content or None)
except Exception as e:
logger.error(f"Failed to update user profile: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to update user profile: {str(e)}")
@router.delete(
"/agents/{name}",
status_code=204,
summary="Delete Custom Agent",
description="Delete a custom agent and all its files (config, SOUL.md, memory).",
)
async def delete_agent(name: str) -> None:
"""Delete a custom agent.
Args:
name: The agent name.
Raises:
HTTPException: 404 if agent not found.
"""
_validate_agent_name(name)
name = _normalize_agent_name(name)
agent_dir = get_paths().agent_dir(name)
if not agent_dir.exists():
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
try:
shutil.rmtree(agent_dir)
logger.info(f"Deleted agent '{name}' from {agent_dir}")
except Exception as e:
logger.error(f"Failed to delete agent '{name}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to delete agent: {str(e)}")

View File

@@ -0,0 +1,181 @@
import logging
import mimetypes
import zipfile
from pathlib import Path
from urllib.parse import quote
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import FileResponse, PlainTextResponse, Response
from app.gateway.path_utils import resolve_thread_virtual_path
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["artifacts"])
ACTIVE_CONTENT_MIME_TYPES = {
"text/html",
"application/xhtml+xml",
"image/svg+xml",
}
def _build_content_disposition(disposition_type: str, filename: str) -> str:
"""Build an RFC 5987 encoded Content-Disposition header value."""
return f"{disposition_type}; filename*=UTF-8''{quote(filename)}"
def _build_attachment_headers(filename: str, extra_headers: dict[str, str] | None = None) -> dict[str, str]:
headers = {"Content-Disposition": _build_content_disposition("attachment", filename)}
if extra_headers:
headers.update(extra_headers)
return headers
def is_text_file_by_content(path: Path, sample_size: int = 8192) -> bool:
"""Check if file is text by examining content for null bytes."""
try:
with open(path, "rb") as f:
chunk = f.read(sample_size)
# Text files shouldn't contain null bytes
return b"\x00" not in chunk
except Exception:
return False
def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None:
"""Extract a file from a .skill ZIP archive.
Args:
zip_path: Path to the .skill file (ZIP archive).
internal_path: Path to the file inside the archive (e.g., "SKILL.md").
Returns:
The file content as bytes, or None if not found.
"""
if not zipfile.is_zipfile(zip_path):
return None
try:
with zipfile.ZipFile(zip_path, "r") as zip_ref:
# List all files in the archive
namelist = zip_ref.namelist()
# Try direct path first
if internal_path in namelist:
return zip_ref.read(internal_path)
# Try with any top-level directory prefix (e.g., "skill-name/SKILL.md")
for name in namelist:
if name.endswith("/" + internal_path) or name == internal_path:
return zip_ref.read(name)
# Not found
return None
except (zipfile.BadZipFile, KeyError):
return None
@router.get(
"/threads/{thread_id}/artifacts/{path:path}",
summary="Get Artifact File",
description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
)
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
"""Get an artifact file by its path.
The endpoint automatically detects file types and returns appropriate content types.
Use the `download` query parameter to force file download for non-active content.
Args:
thread_id: The thread ID.
path: The artifact path with virtual prefix (e.g., mnt/user-data/outputs/file.txt).
request: FastAPI request object (automatically injected).
Returns:
The file content as a FileResponse with appropriate content type:
- Active content (HTML/XHTML/SVG): Served as download attachment
- Text files: Plain text with proper MIME type
- Binary files: Inline display with download option
Raises:
HTTPException:
- 400 if path is invalid or not a file
- 403 if access denied (path traversal detected)
- 404 if file not found
Query Parameters:
download (bool): If true, forces attachment download for file types that are
otherwise returned inline or as plain text. Active HTML/XHTML/SVG content
is always downloaded regardless of this flag.
Example:
- Get text file inline: `/api/threads/abc123/artifacts/mnt/user-data/outputs/notes.txt`
- Download file: `/api/threads/abc123/artifacts/mnt/user-data/outputs/data.csv?download=true`
- Active web content such as `.html`, `.xhtml`, and `.svg` artifacts is always downloaded
"""
# Check if this is a request for a file inside a .skill archive (e.g., xxx.skill/SKILL.md)
if ".skill/" in path:
# Split the path at ".skill/" to get the ZIP file path and internal path
skill_marker = ".skill/"
marker_pos = path.find(skill_marker)
skill_file_path = path[: marker_pos + len(".skill")] # e.g., "mnt/user-data/outputs/my-skill.skill"
internal_path = path[marker_pos + len(skill_marker) :] # e.g., "SKILL.md"
actual_skill_path = resolve_thread_virtual_path(thread_id, skill_file_path)
if not actual_skill_path.exists():
raise HTTPException(status_code=404, detail=f"Skill file not found: {skill_file_path}")
if not actual_skill_path.is_file():
raise HTTPException(status_code=400, detail=f"Path is not a file: {skill_file_path}")
# Extract the file from the .skill archive
content = _extract_file_from_skill_archive(actual_skill_path, internal_path)
if content is None:
raise HTTPException(status_code=404, detail=f"File '{internal_path}' not found in skill archive")
# Determine MIME type based on the internal file
mime_type, _ = mimetypes.guess_type(internal_path)
# Add cache headers to avoid repeated ZIP extraction (cache for 5 minutes)
cache_headers = {"Cache-Control": "private, max-age=300"}
download_name = Path(internal_path).name or actual_skill_path.stem
if download or mime_type in ACTIVE_CONTENT_MIME_TYPES:
return Response(content=content, media_type=mime_type or "application/octet-stream", headers=_build_attachment_headers(download_name, cache_headers))
if mime_type and mime_type.startswith("text/"):
return PlainTextResponse(content=content.decode("utf-8"), media_type=mime_type, headers=cache_headers)
# Default to plain text for unknown types that look like text
try:
return PlainTextResponse(content=content.decode("utf-8"), media_type="text/plain", headers=cache_headers)
except UnicodeDecodeError:
return Response(content=content, media_type=mime_type or "application/octet-stream", headers=cache_headers)
actual_path = resolve_thread_virtual_path(thread_id, path)
logger.info(f"Resolving artifact path: thread_id={thread_id}, requested_path={path}, actual_path={actual_path}")
if not actual_path.exists():
raise HTTPException(status_code=404, detail=f"Artifact not found: {path}")
if not actual_path.is_file():
raise HTTPException(status_code=400, detail=f"Path is not a file: {path}")
mime_type, _ = mimetypes.guess_type(actual_path)
if download:
return FileResponse(path=actual_path, filename=actual_path.name, media_type=mime_type, headers=_build_attachment_headers(actual_path.name))
# Always force download for active content types to prevent script execution
# in the application origin when users open generated artifacts.
if mime_type in ACTIVE_CONTENT_MIME_TYPES:
return FileResponse(path=actual_path, filename=actual_path.name, media_type=mime_type, headers=_build_attachment_headers(actual_path.name))
if mime_type and mime_type.startswith("text/"):
return PlainTextResponse(content=actual_path.read_text(encoding="utf-8"), media_type=mime_type)
if is_text_file_by_content(actual_path):
return PlainTextResponse(content=actual_path.read_text(encoding="utf-8"), media_type=mime_type)
return Response(content=actual_path.read_bytes(), media_type=mime_type, headers={"Content-Disposition": _build_content_disposition("inline", actual_path.name)})

View File

@@ -0,0 +1,149 @@
"""Assistants compatibility endpoints.
Provides LangGraph Platform-compatible assistants API backed by the
``langgraph.json`` graph registry and ``config.yaml`` agent definitions.
This is a minimal stub that satisfies the ``useStream`` React hook's
initialization requirements (``assistants.search()`` and ``assistants.get()``).
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime
from typing import Any
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/assistants", tags=["assistants-compat"])
class AssistantResponse(BaseModel):
assistant_id: str
graph_id: str
name: str
config: dict[str, Any] = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
description: str | None = None
created_at: str = ""
updated_at: str = ""
version: int = 1
class AssistantSearchRequest(BaseModel):
graph_id: str | None = None
name: str | None = None
metadata: dict[str, Any] | None = None
limit: int = 10
offset: int = 0
def _get_default_assistant() -> AssistantResponse:
"""Return the default lead_agent assistant."""
now = datetime.now(UTC).isoformat()
return AssistantResponse(
assistant_id="lead_agent",
graph_id="lead_agent",
name="lead_agent",
config={},
metadata={"created_by": "system"},
description="DeerFlow lead agent",
created_at=now,
updated_at=now,
version=1,
)
def _list_assistants() -> list[AssistantResponse]:
"""List all available assistants from config."""
assistants = [_get_default_assistant()]
# Also include custom agents from config.yaml agents directory
try:
from deerflow.config.agents_config import list_custom_agents
for agent_cfg in list_custom_agents():
now = datetime.now(UTC).isoformat()
assistants.append(
AssistantResponse(
assistant_id=agent_cfg.name,
graph_id="lead_agent", # All agents use the same graph
name=agent_cfg.name,
config={},
metadata={"created_by": "user"},
description=agent_cfg.description or "",
created_at=now,
updated_at=now,
version=1,
)
)
except Exception:
logger.debug("Could not load custom agents for assistants list")
return assistants
@router.post("/search", response_model=list[AssistantResponse])
async def search_assistants(body: AssistantSearchRequest | None = None) -> list[AssistantResponse]:
"""Search assistants.
Returns all registered assistants (lead_agent + custom agents from config).
"""
assistants = _list_assistants()
if body and body.graph_id:
assistants = [a for a in assistants if a.graph_id == body.graph_id]
if body and body.name:
assistants = [a for a in assistants if body.name.lower() in a.name.lower()]
offset = body.offset if body else 0
limit = body.limit if body else 10
return assistants[offset : offset + limit]
@router.get("/{assistant_id}", response_model=AssistantResponse)
async def get_assistant_compat(assistant_id: str) -> AssistantResponse:
"""Get an assistant by ID."""
for a in _list_assistants():
if a.assistant_id == assistant_id:
return a
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
@router.get("/{assistant_id}/graph")
async def get_assistant_graph(assistant_id: str) -> dict:
"""Get the graph structure for an assistant.
Returns a minimal graph description. Full graph introspection is
not supported in the Gateway — this stub satisfies SDK validation.
"""
found = any(a.assistant_id == assistant_id for a in _list_assistants())
if not found:
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
return {
"graph_id": "lead_agent",
"nodes": [],
"edges": [],
}
@router.get("/{assistant_id}/schemas")
async def get_assistant_schemas(assistant_id: str) -> dict:
"""Get JSON schemas for an assistant's input/output/state.
Returns empty schemas — full introspection not supported in Gateway.
"""
found = any(a.assistant_id == assistant_id for a in _list_assistants())
if not found:
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
return {
"graph_id": "lead_agent",
"input_schema": {},
"output_schema": {},
"state_schema": {},
"config_schema": {},
}

View File

@@ -0,0 +1,52 @@
"""Gateway router for IM channel management."""
from __future__ import annotations
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/channels", tags=["channels"])
class ChannelStatusResponse(BaseModel):
service_running: bool
channels: dict[str, dict]
class ChannelRestartResponse(BaseModel):
success: bool
message: str
@router.get("/", response_model=ChannelStatusResponse)
async def get_channels_status() -> ChannelStatusResponse:
"""Get the status of all IM channels."""
from app.channels.service import get_channel_service
service = get_channel_service()
if service is None:
return ChannelStatusResponse(service_running=False, channels={})
status = service.get_status()
return ChannelStatusResponse(**status)
@router.post("/{name}/restart", response_model=ChannelRestartResponse)
async def restart_channel(name: str) -> ChannelRestartResponse:
"""Restart a specific IM channel."""
from app.channels.service import get_channel_service
service = get_channel_service()
if service is None:
raise HTTPException(status_code=503, detail="Channel service is not running")
success = await service.restart_channel(name)
if success:
logger.info("Channel %s restarted successfully", name)
return ChannelRestartResponse(success=True, message=f"Channel {name} restarted successfully")
else:
logger.warning("Failed to restart channel %s", name)
return ChannelRestartResponse(success=False, message=f"Failed to restart channel {name}")

View File

@@ -0,0 +1,169 @@
import json
import logging
from pathlib import Path
from typing import Literal
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.config.extensions_config import ExtensionsConfig, get_extensions_config, reload_extensions_config
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["mcp"])
class McpOAuthConfigResponse(BaseModel):
"""OAuth configuration for an MCP server."""
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
token_url: str = Field(default="", description="OAuth token endpoint URL")
grant_type: Literal["client_credentials", "refresh_token"] = Field(default="client_credentials", description="OAuth grant type")
client_id: str | None = Field(default=None, description="OAuth client ID")
client_secret: str | None = Field(default=None, description="OAuth client secret")
refresh_token: str | None = Field(default=None, description="OAuth refresh token")
scope: str | None = Field(default=None, description="OAuth scope")
audience: str | None = Field(default=None, description="OAuth audience")
token_field: str = Field(default="access_token", description="Token response field containing access token")
token_type_field: str = Field(default="token_type", description="Token response field containing token type")
expires_in_field: str = Field(default="expires_in", description="Token response field containing expires-in seconds")
default_token_type: str = Field(default="Bearer", description="Default token type when response omits token_type")
refresh_skew_seconds: int = Field(default=60, description="Refresh this many seconds before expiry")
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
class McpServerConfigResponse(BaseModel):
"""Response model for MCP server configuration."""
enabled: bool = Field(default=True, description="Whether this MCP server is enabled")
type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'")
command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)")
args: list[str] = Field(default_factory=list, description="Arguments to pass to the command (for stdio type)")
env: dict[str, str] = Field(default_factory=dict, description="Environment variables for the MCP server")
url: str | None = Field(default=None, description="URL of the MCP server (for sse or http type)")
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
oauth: McpOAuthConfigResponse | None = Field(default=None, description="OAuth configuration for MCP HTTP/SSE servers")
description: str = Field(default="", description="Human-readable description of what this MCP server provides")
class McpConfigResponse(BaseModel):
"""Response model for MCP configuration."""
mcp_servers: dict[str, McpServerConfigResponse] = Field(
default_factory=dict,
description="Map of MCP server name to configuration",
)
class McpConfigUpdateRequest(BaseModel):
"""Request model for updating MCP configuration."""
mcp_servers: dict[str, McpServerConfigResponse] = Field(
...,
description="Map of MCP server name to configuration",
)
@router.get(
"/mcp/config",
response_model=McpConfigResponse,
summary="Get MCP Configuration",
description="Retrieve the current Model Context Protocol (MCP) server configurations.",
)
async def get_mcp_configuration() -> McpConfigResponse:
"""Get the current MCP configuration.
Returns:
The current MCP configuration with all servers.
Example:
```json
{
"mcp_servers": {
"github": {
"enabled": true,
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {"GITHUB_TOKEN": "ghp_xxx"},
"description": "GitHub MCP server for repository operations"
}
}
}
```
"""
config = get_extensions_config()
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in config.mcp_servers.items()})
@router.put(
"/mcp/config",
response_model=McpConfigResponse,
summary="Update MCP Configuration",
description="Update Model Context Protocol (MCP) server configurations and save to file.",
)
async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfigResponse:
"""Update the MCP configuration.
This will:
1. Save the new configuration to the mcp_config.json file
2. Reload the configuration cache
3. Reset MCP tools cache to trigger reinitialization
Args:
request: The new MCP configuration to save.
Returns:
The updated MCP configuration.
Raises:
HTTPException: 500 if the configuration file cannot be written.
Example Request:
```json
{
"mcp_servers": {
"github": {
"enabled": true,
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {"GITHUB_TOKEN": "$GITHUB_TOKEN"},
"description": "GitHub MCP server for repository operations"
}
}
}
```
"""
try:
# Get the current config path (or determine where to save it)
config_path = ExtensionsConfig.resolve_config_path()
# If no config file exists, create one in the parent directory (project root)
if config_path is None:
config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
# Load current config to preserve skills configuration
current_config = get_extensions_config()
# Convert request to dict format for JSON serialization
config_data = {
"mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
}
# Write the configuration to file
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
logger.info(f"MCP configuration updated and saved to: {config_path}")
# NOTE: No need to reload/reset cache here - LangGraph Server (separate process)
# will detect config file changes via mtime and reinitialize MCP tools automatically
# Reload the configuration and update the global cache
reloaded_config = reload_extensions_config()
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_config.mcp_servers.items()})
except Exception as e:
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to update MCP configuration: {str(e)}")

View File

@@ -0,0 +1,353 @@
"""Memory API router for retrieving and managing global memory data."""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.agents.memory.updater import (
clear_memory_data,
create_memory_fact,
delete_memory_fact,
get_memory_data,
import_memory_data,
reload_memory_data,
update_memory_fact,
)
from deerflow.config.memory_config import get_memory_config
router = APIRouter(prefix="/api", tags=["memory"])
class ContextSection(BaseModel):
"""Model for context sections (user and history)."""
summary: str = Field(default="", description="Summary content")
updatedAt: str = Field(default="", description="Last update timestamp")
class UserContext(BaseModel):
"""Model for user context."""
workContext: ContextSection = Field(default_factory=ContextSection)
personalContext: ContextSection = Field(default_factory=ContextSection)
topOfMind: ContextSection = Field(default_factory=ContextSection)
class HistoryContext(BaseModel):
"""Model for history context."""
recentMonths: ContextSection = Field(default_factory=ContextSection)
earlierContext: ContextSection = Field(default_factory=ContextSection)
longTermBackground: ContextSection = Field(default_factory=ContextSection)
class Fact(BaseModel):
"""Model for a memory fact."""
id: str = Field(..., description="Unique identifier for the fact")
content: str = Field(..., description="Fact content")
category: str = Field(default="context", description="Fact category")
confidence: float = Field(default=0.5, description="Confidence score (0-1)")
createdAt: str = Field(default="", description="Creation timestamp")
source: str = Field(default="unknown", description="Source thread ID")
sourceError: str | None = Field(default=None, description="Optional description of the prior mistake or wrong approach")
class MemoryResponse(BaseModel):
"""Response model for memory data."""
version: str = Field(default="1.0", description="Memory schema version")
lastUpdated: str = Field(default="", description="Last update timestamp")
user: UserContext = Field(default_factory=UserContext)
history: HistoryContext = Field(default_factory=HistoryContext)
facts: list[Fact] = Field(default_factory=list)
def _map_memory_fact_value_error(exc: ValueError) -> HTTPException:
"""Convert updater validation errors into stable API responses."""
if exc.args and exc.args[0] == "confidence":
detail = "Invalid confidence value; must be between 0 and 1."
else:
detail = "Memory fact content cannot be empty."
return HTTPException(status_code=400, detail=detail)
class FactCreateRequest(BaseModel):
"""Request model for creating a memory fact."""
content: str = Field(..., min_length=1, description="Fact content")
category: str = Field(default="context", description="Fact category")
confidence: float = Field(default=0.5, ge=0.0, le=1.0, description="Confidence score (0-1)")
class FactPatchRequest(BaseModel):
"""PATCH request model that preserves existing values for omitted fields."""
content: str | None = Field(default=None, min_length=1, description="Fact content")
category: str | None = Field(default=None, description="Fact category")
confidence: float | None = Field(default=None, ge=0.0, le=1.0, description="Confidence score (0-1)")
class MemoryConfigResponse(BaseModel):
"""Response model for memory configuration."""
enabled: bool = Field(..., description="Whether memory is enabled")
storage_path: str = Field(..., description="Path to memory storage file")
debounce_seconds: int = Field(..., description="Debounce time for memory updates")
max_facts: int = Field(..., description="Maximum number of facts to store")
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
class MemoryStatusResponse(BaseModel):
"""Response model for memory status."""
config: MemoryConfigResponse
data: MemoryResponse
@router.get(
"/memory",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Get Memory Data",
description="Retrieve the current global memory data including user context, history, and facts.",
)
async def get_memory() -> MemoryResponse:
"""Get the current global memory data.
Returns:
The current memory data with user context, history, and facts.
Example Response:
```json
{
"version": "1.0",
"lastUpdated": "2024-01-15T10:30:00Z",
"user": {
"workContext": {"summary": "Working on DeerFlow project", "updatedAt": "..."},
"personalContext": {"summary": "Prefers concise responses", "updatedAt": "..."},
"topOfMind": {"summary": "Building memory API", "updatedAt": "..."}
},
"history": {
"recentMonths": {"summary": "Recent development activities", "updatedAt": "..."},
"earlierContext": {"summary": "", "updatedAt": ""},
"longTermBackground": {"summary": "", "updatedAt": ""}
},
"facts": [
{
"id": "fact_abc123",
"content": "User prefers TypeScript over JavaScript",
"category": "preference",
"confidence": 0.9,
"createdAt": "2024-01-15T10:30:00Z",
"source": "thread_xyz"
}
]
}
```
"""
memory_data = get_memory_data()
return MemoryResponse(**memory_data)
@router.post(
"/memory/reload",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Reload Memory Data",
description="Reload memory data from the storage file, refreshing the in-memory cache.",
)
async def reload_memory() -> MemoryResponse:
"""Reload memory data from file.
This forces a reload of the memory data from the storage file,
useful when the file has been modified externally.
Returns:
The reloaded memory data.
"""
memory_data = reload_memory_data()
return MemoryResponse(**memory_data)
@router.delete(
"/memory",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Clear All Memory Data",
description="Delete all saved memory data and reset the memory structure to an empty state.",
)
async def clear_memory() -> MemoryResponse:
"""Clear all persisted memory data."""
try:
memory_data = clear_memory_data()
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
return MemoryResponse(**memory_data)
@router.post(
"/memory/facts",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Create Memory Fact",
description="Create a single saved memory fact manually.",
)
async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryResponse:
"""Create a single fact manually."""
try:
memory_data = create_memory_fact(
content=request.content,
category=request.category,
confidence=request.confidence,
)
except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
return MemoryResponse(**memory_data)
@router.delete(
"/memory/facts/{fact_id}",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Delete Memory Fact",
description="Delete a single saved memory fact by its fact id.",
)
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
"""Delete a single fact from memory by fact id."""
try:
memory_data = delete_memory_fact(fact_id)
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
return MemoryResponse(**memory_data)
@router.patch(
"/memory/facts/{fact_id}",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Patch Memory Fact",
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
)
async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -> MemoryResponse:
"""Partially update a single fact manually."""
try:
memory_data = update_memory_fact(
fact_id=fact_id,
content=request.content,
category=request.category,
confidence=request.confidence,
)
except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
return MemoryResponse(**memory_data)
@router.get(
"/memory/export",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Export Memory Data",
description="Export the current global memory data as JSON for backup or transfer.",
)
async def export_memory() -> MemoryResponse:
"""Export the current memory data."""
memory_data = get_memory_data()
return MemoryResponse(**memory_data)
@router.post(
"/memory/import",
response_model=MemoryResponse,
response_model_exclude_none=True,
summary="Import Memory Data",
description="Import and overwrite the current global memory data from a JSON payload.",
)
async def import_memory(request: MemoryResponse) -> MemoryResponse:
"""Import and persist memory data."""
try:
memory_data = import_memory_data(request.model_dump())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
return MemoryResponse(**memory_data)
@router.get(
"/memory/config",
response_model=MemoryConfigResponse,
summary="Get Memory Configuration",
description="Retrieve the current memory system configuration.",
)
async def get_memory_config_endpoint() -> MemoryConfigResponse:
"""Get the memory system configuration.
Returns:
The current memory configuration settings.
Example Response:
```json
{
"enabled": true,
"storage_path": ".deer-flow/memory.json",
"debounce_seconds": 30,
"max_facts": 100,
"fact_confidence_threshold": 0.7,
"injection_enabled": true,
"max_injection_tokens": 2000
}
```
"""
config = get_memory_config()
return MemoryConfigResponse(
enabled=config.enabled,
storage_path=config.storage_path,
debounce_seconds=config.debounce_seconds,
max_facts=config.max_facts,
fact_confidence_threshold=config.fact_confidence_threshold,
injection_enabled=config.injection_enabled,
max_injection_tokens=config.max_injection_tokens,
)
@router.get(
"/memory/status",
response_model=MemoryStatusResponse,
response_model_exclude_none=True,
summary="Get Memory Status",
description="Retrieve both memory configuration and current data in a single request.",
)
async def get_memory_status() -> MemoryStatusResponse:
"""Get the memory system status including configuration and data.
Returns:
Combined memory configuration and current data.
"""
config = get_memory_config()
memory_data = get_memory_data()
return MemoryStatusResponse(
config=MemoryConfigResponse(
enabled=config.enabled,
storage_path=config.storage_path,
debounce_seconds=config.debounce_seconds,
max_facts=config.max_facts,
fact_confidence_threshold=config.fact_confidence_threshold,
injection_enabled=config.injection_enabled,
max_injection_tokens=config.max_injection_tokens,
),
data=MemoryResponse(**memory_data),
)

View File

@@ -0,0 +1,116 @@
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from deerflow.config import get_app_config
router = APIRouter(prefix="/api", tags=["models"])
class ModelResponse(BaseModel):
"""Response model for model information."""
name: str = Field(..., description="Unique identifier for the model")
model: str = Field(..., description="Actual provider model identifier")
display_name: str | None = Field(None, description="Human-readable name")
description: str | None = Field(None, description="Model description")
supports_thinking: bool = Field(default=False, description="Whether model supports thinking mode")
supports_reasoning_effort: bool = Field(default=False, description="Whether model supports reasoning effort")
class ModelsListResponse(BaseModel):
"""Response model for listing all models."""
models: list[ModelResponse]
@router.get(
"/models",
response_model=ModelsListResponse,
summary="List All Models",
description="Retrieve a list of all available AI models configured in the system.",
)
async def list_models() -> ModelsListResponse:
"""List all available models from configuration.
Returns model information suitable for frontend display,
excluding sensitive fields like API keys and internal configuration.
Returns:
A list of all configured models with their metadata.
Example Response:
```json
{
"models": [
{
"name": "gpt-4",
"display_name": "GPT-4",
"description": "OpenAI GPT-4 model",
"supports_thinking": false
},
{
"name": "claude-3-opus",
"display_name": "Claude 3 Opus",
"description": "Anthropic Claude 3 Opus model",
"supports_thinking": true
}
]
}
```
"""
config = get_app_config()
models = [
ModelResponse(
name=model.name,
model=model.model,
display_name=model.display_name,
description=model.description,
supports_thinking=model.supports_thinking,
supports_reasoning_effort=model.supports_reasoning_effort,
)
for model in config.models
]
return ModelsListResponse(models=models)
@router.get(
"/models/{model_name}",
response_model=ModelResponse,
summary="Get Model Details",
description="Retrieve detailed information about a specific AI model by its name.",
)
async def get_model(model_name: str) -> ModelResponse:
"""Get a specific model by name.
Args:
model_name: The unique name of the model to retrieve.
Returns:
Model information if found.
Raises:
HTTPException: 404 if model not found.
Example Response:
```json
{
"name": "gpt-4",
"display_name": "GPT-4",
"description": "OpenAI GPT-4 model",
"supports_thinking": false
}
```
"""
config = get_app_config()
model = config.get_model_config(model_name)
if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
return ModelResponse(
name=model.name,
model=model.model,
display_name=model.display_name,
description=model.description,
supports_thinking=model.supports_thinking,
supports_reasoning_effort=model.supports_reasoning_effort,
)

View File

@@ -0,0 +1,87 @@
"""Stateless runs endpoints -- stream and wait without a pre-existing thread.
These endpoints auto-create a temporary thread when no ``thread_id`` is
supplied in the request body. When a ``thread_id`` **is** provided, it
is reused so that conversation history is preserved across calls.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.routers.thread_runs import RunCreateRequest
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import serialize_channel_values
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/runs", tags=["runs"])
def _resolve_thread_id(body: RunCreateRequest) -> str:
"""Return the thread_id from the request body, or generate a new one."""
thread_id = (body.config or {}).get("configurable", {}).get("thread_id")
if thread_id:
return str(thread_id)
return str(uuid.uuid4())
@router.post("/stream")
async def stateless_stream(body: RunCreateRequest, request: Request) -> StreamingResponse:
"""Create a run and stream events via SSE.
If ``config.configurable.thread_id`` is provided, the run is created
on the given thread so that conversation history is preserved.
Otherwise a new temporary thread is created.
"""
thread_id = _resolve_thread_id(body)
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = await start_run(body, thread_id, request)
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
},
)
@router.post("/wait", response_model=dict)
async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until completion.
If ``config.configurable.thread_id`` is provided, the run is created
on the given thread so that conversation history is preserved.
Otherwise a new temporary thread is created.
"""
thread_id = _resolve_thread_id(body)
record = await start_run(body, thread_id, request)
if record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}

View File

@@ -0,0 +1,356 @@
import json
import logging
import shutil
from pathlib import Path
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from app.gateway.path_utils import resolve_thread_virtual_path
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
from deerflow.skills.manager import (
append_history,
atomic_write,
custom_skill_exists,
ensure_custom_skill_is_editable,
get_custom_skill_dir,
get_custom_skill_file,
get_skill_history_file,
read_custom_skill_content,
read_history,
validate_skill_markdown_content,
)
from deerflow.skills.security_scanner import scan_skill_content
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["skills"])
class SkillResponse(BaseModel):
"""Response model for skill information."""
name: str = Field(..., description="Name of the skill")
description: str = Field(..., description="Description of what the skill does")
license: str | None = Field(None, description="License information")
category: str = Field(..., description="Category of the skill (public or custom)")
enabled: bool = Field(default=True, description="Whether this skill is enabled")
class SkillsListResponse(BaseModel):
"""Response model for listing all skills."""
skills: list[SkillResponse]
class SkillUpdateRequest(BaseModel):
"""Request model for updating a skill."""
enabled: bool = Field(..., description="Whether to enable or disable the skill")
class SkillInstallRequest(BaseModel):
"""Request model for installing a skill from a .skill file."""
thread_id: str = Field(..., description="The thread ID where the .skill file is located")
path: str = Field(..., description="Virtual path to the .skill file (e.g., mnt/user-data/outputs/my-skill.skill)")
class SkillInstallResponse(BaseModel):
"""Response model for skill installation."""
success: bool = Field(..., description="Whether the installation was successful")
skill_name: str = Field(..., description="Name of the installed skill")
message: str = Field(..., description="Installation result message")
class CustomSkillContentResponse(SkillResponse):
content: str = Field(..., description="Raw SKILL.md content")
class CustomSkillUpdateRequest(BaseModel):
content: str = Field(..., description="Replacement SKILL.md content")
class CustomSkillHistoryResponse(BaseModel):
history: list[dict]
class SkillRollbackRequest(BaseModel):
history_index: int = Field(default=-1, description="History entry index to restore from, defaulting to the latest change.")
def _skill_to_response(skill: Skill) -> SkillResponse:
"""Convert a Skill object to a SkillResponse."""
return SkillResponse(
name=skill.name,
description=skill.description,
license=skill.license,
category=skill.category,
enabled=skill.enabled,
)
@router.get(
"/skills",
response_model=SkillsListResponse,
summary="List All Skills",
description="Retrieve a list of all available skills from both public and custom directories.",
)
async def list_skills() -> SkillsListResponse:
try:
skills = load_skills(enabled_only=False)
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error(f"Failed to load skills: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to load skills: {str(e)}")
@router.post(
"/skills/install",
response_model=SkillInstallResponse,
summary="Install Skill",
description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.",
)
async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
try:
skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
result = install_skill_from_archive(skill_file_path)
await refresh_skills_system_prompt_cache_async()
return SkillInstallResponse(**result)
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except SkillAlreadyExistsError as e:
raise HTTPException(status_code=409, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to install skill: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to install skill: {str(e)}")
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
async def list_custom_skills() -> SkillsListResponse:
try:
skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error("Failed to list custom skills: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to list custom skills: {str(e)}")
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
try:
skills = load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
if skill is None:
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name))
except HTTPException:
raise
except Exception as e:
logger.error("Failed to get custom skill %s: %s", skill_name, e, exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to get custom skill: {str(e)}")
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse:
try:
ensure_custom_skill_is_editable(skill_name)
validate_skill_markdown_content(skill_name, request.content)
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md")
if scan.decision == "block":
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
skill_file = get_custom_skill_dir(skill_name) / "SKILL.md"
prev_content = skill_file.read_text(encoding="utf-8")
atomic_write(skill_file, request.content)
append_history(
skill_name,
{
"action": "human_edit",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"prev_content": prev_content,
"new_content": request.content,
"scanner": {"decision": scan.decision, "reason": scan.reason},
},
)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
except HTTPException:
raise
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error("Failed to update custom skill %s: %s", skill_name, e, exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to update custom skill: {str(e)}")
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
try:
ensure_custom_skill_is_editable(skill_name)
skill_dir = get_custom_skill_dir(skill_name)
prev_content = read_custom_skill_content(skill_name)
append_history(
skill_name,
{
"action": "human_delete",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"prev_content": prev_content,
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
)
shutil.rmtree(skill_dir)
await refresh_skills_system_prompt_cache_async()
return {"success": True}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error("Failed to delete custom skill %s: %s", skill_name, e, exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to delete custom skill: {str(e)}")
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillHistoryResponse(history=read_history(skill_name))
except HTTPException:
raise
except Exception as e:
logger.error("Failed to read history for %s: %s", skill_name, e, exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to read history: {str(e)}")
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
history = read_history(skill_name)
if not history:
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
record = history[request.history_index]
target_content = record.get("prev_content")
if target_content is None:
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
validate_skill_markdown_content(skill_name, target_content)
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md")
skill_file = get_custom_skill_file(skill_name)
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
history_entry = {
"action": "rollback",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"prev_content": current_content,
"new_content": target_content,
"rollback_from_ts": record.get("ts"),
"scanner": {"decision": scan.decision, "reason": scan.reason},
}
if scan.decision == "block":
append_history(skill_name, history_entry)
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
atomic_write(skill_file, target_content)
append_history(skill_name, history_entry)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
except HTTPException:
raise
except IndexError:
raise HTTPException(status_code=400, detail="history_index is out of range")
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error("Failed to roll back custom skill %s: %s", skill_name, e, exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to roll back custom skill: {str(e)}")
@router.get(
"/skills/{skill_name}",
response_model=SkillResponse,
summary="Get Skill Details",
description="Retrieve detailed information about a specific skill by its name.",
)
async def get_skill(skill_name: str) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
return _skill_to_response(skill)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get skill {skill_name}: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to get skill: {str(e)}")
@router.put(
"/skills/{skill_name}",
response_model=SkillResponse,
summary="Update Skill",
description="Update a skill's enabled status by modifying the extensions_config.json file.",
)
async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
raise HTTPException(status_code=404, detail=f"Skill '{skill_name}' not found")
config_path = ExtensionsConfig.resolve_config_path()
if config_path is None:
config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
extensions_config = get_extensions_config()
extensions_config.skills[skill_name] = SkillStateConfig(enabled=request.enabled)
config_data = {
"mcpServers": {name: server.model_dump() for name, server in extensions_config.mcp_servers.items()},
"skills": {name: {"enabled": skill_config.enabled} for name, skill_config in extensions_config.skills.items()},
}
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
logger.info(f"Skills configuration updated and saved to: {config_path}")
reload_extensions_config()
await refresh_skills_system_prompt_cache_async()
skills = load_skills(enabled_only=False)
updated_skill = next((s for s in skills if s.name == skill_name), None)
if updated_skill is None:
raise HTTPException(status_code=500, detail=f"Failed to reload skill '{skill_name}' after update")
logger.info(f"Skill '{skill_name}' enabled status updated to {request.enabled}")
return _skill_to_response(updated_skill)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to update skill {skill_name}: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to update skill: {str(e)}")

View File

@@ -0,0 +1,132 @@
import json
import logging
from fastapi import APIRouter
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["suggestions"])
class SuggestionMessage(BaseModel):
role: str = Field(..., description="Message role: user|assistant")
content: str = Field(..., description="Message content as plain text")
class SuggestionsRequest(BaseModel):
messages: list[SuggestionMessage] = Field(..., description="Recent conversation messages")
n: int = Field(default=3, ge=1, le=5, description="Number of suggestions to generate")
model_name: str | None = Field(default=None, description="Optional model override")
class SuggestionsResponse(BaseModel):
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
def _strip_markdown_code_fence(text: str) -> str:
stripped = text.strip()
if not stripped.startswith("```"):
return stripped
lines = stripped.splitlines()
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
return "\n".join(lines[1:-1]).strip()
return stripped
def _parse_json_string_list(text: str) -> list[str] | None:
candidate = _strip_markdown_code_fence(text)
start = candidate.find("[")
end = candidate.rfind("]")
if start == -1 or end == -1 or end <= start:
return None
candidate = candidate[start : end + 1]
try:
data = json.loads(candidate)
except Exception:
return None
if not isinstance(data, list):
return None
out: list[str] = []
for item in data:
if not isinstance(item, str):
continue
s = item.strip()
if not s:
continue
out.append(s)
return out
def _extract_response_text(content: object) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for block in content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, dict) and block.get("type") in {"text", "output_text"}:
text = block.get("text")
if isinstance(text, str):
parts.append(text)
return "\n".join(parts) if parts else ""
if content is None:
return ""
return str(content)
def _format_conversation(messages: list[SuggestionMessage]) -> str:
parts: list[str] = []
for m in messages:
role = m.role.strip().lower()
if role in ("user", "human"):
parts.append(f"User: {m.content.strip()}")
elif role in ("assistant", "ai"):
parts.append(f"Assistant: {m.content.strip()}")
else:
parts.append(f"{m.role}: {m.content.strip()}")
return "\n".join(parts).strip()
@router.post(
"/threads/{thread_id}/suggestions",
response_model=SuggestionsResponse,
summary="Generate Follow-up Questions",
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
if not request.messages:
return SuggestionsResponse(suggestions=[])
n = request.n
conversation = _format_conversation(request.messages)
if not conversation:
return SuggestionsResponse(suggestions=[])
system_instruction = (
"You are generating follow-up questions to help the user continue the conversation.\n"
f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
"Requirements:\n"
"- Questions must be relevant to the preceding conversation.\n"
"- Questions must be written in the same language as the user.\n"
"- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
"- Do NOT include numbering, markdown, or any extra text.\n"
"- Output MUST be a JSON array of strings only.\n"
)
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try:
model = create_chat_model(name=request.model_name, thinking_enabled=False)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or []
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
cleaned = cleaned[:n]
return SuggestionsResponse(suggestions=cleaned)
except Exception as exc:
logger.exception("Failed to generate suggestions: thread_id=%s err=%s", thread_id, exc)
return SuggestionsResponse(suggestions=[])

View File

@@ -0,0 +1,267 @@
"""Runs endpoints — create, stream, wait, cancel.
Implements the LangGraph Platform runs API on top of
:class:`deerflow.agents.runs.RunManager` and
:class:`deerflow.agents.stream_bridge.StreamBridge`.
SSE format is aligned with the LangGraph Platform protocol so that
the ``useStream`` React hook from ``@langchain/langgraph-sdk/react``
works without modification.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Literal
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["runs"])
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class RunCreateRequest(BaseModel):
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
input: dict[str, Any] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
webhook: str | None = Field(default=None, description="Completion callback URL")
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
after_seconds: float | None = Field(default=None, description="Delayed execution")
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
class RunResponse(BaseModel):
run_id: str
thread_id: str
assistant_id: str | None = None
status: str
metadata: dict[str, Any] = Field(default_factory=dict)
kwargs: dict[str, Any] = Field(default_factory=dict)
multitask_strategy: str = "reject"
created_at: str = ""
updated_at: str = ""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _record_to_response(record: RunRecord) -> RunResponse:
return RunResponse(
run_id=record.run_id,
thread_id=record.thread_id,
assistant_id=record.assistant_id,
status=record.status.value,
metadata=record.metadata,
kwargs=record.kwargs,
multitask_strategy=record.multitask_strategy,
created_at=record.created_at,
updated_at=record.updated_at,
)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/{thread_id}/runs", response_model=RunResponse)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
"""Create a background run (returns immediately)."""
record = await start_run(body, thread_id, request)
return _record_to_response(record)
@router.post("/{thread_id}/runs/stream")
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
"""Create a run and stream events via SSE.
The response includes a ``Content-Location`` header with the run's
resource URL, matching the LangGraph Platform protocol. The
``useStream`` React hook uses this to extract run metadata.
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = await start_run(body, thread_id, request)
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
# LangGraph Platform includes run metadata in this header.
# The SDK uses a greedy regex to extract the run id from this path,
# so it must point at the canonical run resource without extra suffixes.
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
},
)
@router.post("/{thread_id}/runs/wait", response_model=dict)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until it completes, returning the final state."""
record = await start_run(body, thread_id, request)
if record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
"""List all runs for a thread."""
run_mgr = get_run_manager(request)
records = await run_mgr.list_by_thread(thread_id)
return [_record_to_response(r) for r in records]
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
"""Get details of a specific run."""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
return _record_to_response(record)
@router.post("/{thread_id}/runs/{run_id}/cancel")
async def cancel_run(
thread_id: str,
run_id: str,
request: Request,
wait: bool = Query(default=False, description="Block until run completes after cancel"),
action: Literal["interrupt", "rollback"] = Query(default="interrupt", description="Cancel action"),
) -> Response:
"""Cancel a running or pending run.
- action=interrupt: Stop execution, keep current checkpoint (can be resumed)
- action=rollback: Stop execution, revert to pre-run checkpoint state
- wait=true: Block until the run fully stops, return 204
- wait=false: Return immediately with 202
"""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
cancelled = await run_mgr.cancel(run_id, action=action)
if not cancelled:
raise HTTPException(
status_code=409,
detail=f"Run {run_id} is not cancellable (status: {record.status.value})",
)
if wait and record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
return Response(status_code=204)
return Response(status_code=202)
@router.get("/{thread_id}/runs/{run_id}/join")
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
"""Join an existing run's SSE stream."""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
async def stream_existing_run(
thread_id: str,
run_id: str,
request: Request,
action: Literal["interrupt", "rollback"] | None = Query(default=None, description="Cancel action"),
wait: int = Query(default=0, description="Block until cancelled (1) or return immediately (0)"),
):
"""Join an existing run's SSE stream (GET), or cancel-then-stream (POST).
The LangGraph SDK's ``joinStream`` and ``useStream`` stop button both use
``POST`` to this endpoint. When ``action=interrupt`` or ``action=rollback``
is present the run is cancelled first; the response then streams any
remaining buffered events so the client observes a clean shutdown.
"""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
# Cancel if an action was requested (stop-button / interrupt flow)
if action is not None:
cancelled = await run_mgr.cancel(run_id, action=action)
if cancelled and wait and record.task is not None:
try:
await record.task
except (asyncio.CancelledError, Exception):
pass
return Response(status_code=204)
bridge = get_stream_bridge(request)
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)

View File

@@ -0,0 +1,682 @@
"""Thread CRUD, state, and history endpoints.
Combines the existing thread-local filesystem cleanup with LangGraph
Platform-compatible thread management backed by the checkpointer.
Channel values returned in state responses are serialized through
:func:`deerflow.runtime.serialization.serialize_channel_values` to
ensure LangChain message objects are converted to JSON-safe dicts
matching the LangGraph Platform wire format expected by the
``useStream`` React hook.
"""
from __future__ import annotations
import logging
import time
import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from app.gateway.deps import get_checkpointer, get_store
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
# ---------------------------------------------------------------------------
# Store namespace
# ---------------------------------------------------------------------------
THREADS_NS: tuple[str, ...] = ("threads",)
"""Namespace used by the Store for thread metadata records."""
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["threads"])
# ---------------------------------------------------------------------------
# Response / request models
# ---------------------------------------------------------------------------
class ThreadDeleteResponse(BaseModel):
"""Response model for thread cleanup."""
success: bool
message: str
class ThreadResponse(BaseModel):
"""Response model for a single thread."""
thread_id: str = Field(description="Unique thread identifier")
status: str = Field(default="idle", description="Thread status: idle, busy, interrupted, error")
created_at: str = Field(default="", description="ISO timestamp")
updated_at: str = Field(default="", description="ISO timestamp")
metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
values: dict[str, Any] = Field(default_factory=dict, description="Current state channel values")
interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
class ThreadCreateRequest(BaseModel):
"""Request body for creating a thread."""
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
class ThreadSearchRequest(BaseModel):
"""Request body for searching threads."""
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
offset: int = Field(default=0, ge=0, description="Pagination offset")
status: str | None = Field(default=None, description="Filter by thread status")
class ThreadStateResponse(BaseModel):
"""Response model for thread state."""
values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
next: list[str] = Field(default_factory=list, description="Next tasks to execute")
metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
created_at: str | None = Field(default=None, description="Checkpoint timestamp")
tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
class ThreadPatchRequest(BaseModel):
"""Request body for patching thread metadata."""
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge")
class ThreadStateUpdateRequest(BaseModel):
"""Request body for updating thread state (human-in-the-loop resume)."""
values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
as_node: str | None = Field(default=None, description="Node identity for the update")
class HistoryEntry(BaseModel):
"""Single checkpoint history entry."""
checkpoint_id: str
parent_checkpoint_id: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
values: dict[str, Any] = Field(default_factory=dict)
created_at: str | None = None
next: list[str] = Field(default_factory=list)
class ThreadHistoryRequest(BaseModel):
"""Request body for checkpoint history."""
limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
before: str | None = Field(default=None, description="Cursor for pagination")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread."""
path_manager = paths or get_paths()
try:
path_manager.delete_thread_dir(thread_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
except FileNotFoundError:
# Not critical — thread data may not exist on disk
logger.debug("No local thread data to delete for %s", thread_id)
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
except Exception as exc:
logger.exception("Failed to delete thread data for %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
logger.info("Deleted local thread data for %s", thread_id)
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
async def _store_get(store, thread_id: str) -> dict | None:
"""Fetch a thread record from the Store; returns ``None`` if absent."""
item = await store.aget(THREADS_NS, thread_id)
return item.value if item is not None else None
async def _store_put(store, record: dict) -> None:
"""Write a thread record to the Store."""
await store.aput(THREADS_NS, record["thread_id"], record)
async def _store_upsert(store, thread_id: str, *, metadata: dict | None = None, values: dict | None = None) -> None:
"""Create or refresh a thread record in the Store.
On creation the record is written with ``status="idle"``. On update only
``updated_at`` (and optionally ``metadata`` / ``values``) are changed so
that existing fields are preserved.
``values`` carries the agent-state snapshot exposed to the frontend
(currently just ``{"title": "..."}``).
"""
now = time.time()
existing = await _store_get(store, thread_id)
if existing is None:
await _store_put(
store,
{
"thread_id": thread_id,
"status": "idle",
"created_at": now,
"updated_at": now,
"metadata": metadata or {},
"values": values or {},
},
)
else:
val = dict(existing)
val["updated_at"] = now
if metadata:
val.setdefault("metadata", {}).update(metadata)
if values:
val.setdefault("values", {}).update(values)
await _store_put(store, val)
def _derive_thread_status(checkpoint_tuple) -> str:
"""Derive thread status from checkpoint metadata."""
if checkpoint_tuple is None:
return "idle"
pending_writes = getattr(checkpoint_tuple, "pending_writes", None) or []
# Check for error in pending writes
for pw in pending_writes:
if len(pw) >= 2 and pw[1] == "__error__":
return "error"
# Check for pending next tasks (indicates interrupt)
tasks = getattr(checkpoint_tuple, "tasks", None)
if tasks:
return "interrupted"
return "idle"
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread.
Cleans DeerFlow-managed thread directories, removes checkpoint data,
and removes the thread record from the Store.
"""
# Clean local filesystem
response = _delete_thread_data(thread_id)
# Remove from Store (best-effort)
store = get_store(request)
if store is not None:
try:
await store.adelete(THREADS_NS, thread_id)
except Exception:
logger.debug("Could not delete store record for thread %s (not critical)", thread_id)
# Remove checkpoints (best-effort)
checkpointer = getattr(request.app.state, "checkpointer", None)
if checkpointer is not None:
try:
if hasattr(checkpointer, "adelete_thread"):
await checkpointer.adelete_thread(thread_id)
except Exception:
logger.debug("Could not delete checkpoints for thread %s (not critical)", thread_id)
return response
@router.post("", response_model=ThreadResponse)
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
"""Create a new thread.
The thread record is written to the Store (for fast listing) and an
empty checkpoint is written to the checkpointer (for state reads).
Idempotent: returns the existing record when ``thread_id`` already exists.
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
thread_id = body.thread_id or str(uuid.uuid4())
now = time.time()
# Idempotency: return existing record from Store when already present
if store is not None:
existing_record = await _store_get(store, thread_id)
if existing_record is not None:
return ThreadResponse(
thread_id=thread_id,
status=existing_record.get("status", "idle"),
created_at=str(existing_record.get("created_at", "")),
updated_at=str(existing_record.get("updated_at", "")),
metadata=existing_record.get("metadata", {}),
)
# Write thread record to Store
if store is not None:
try:
await _store_put(
store,
{
"thread_id": thread_id,
"status": "idle",
"created_at": now,
"updated_at": now,
"metadata": body.metadata,
},
)
except Exception:
logger.exception("Failed to write thread %s to store", thread_id)
raise HTTPException(status_code=500, detail="Failed to create thread")
# Write an empty checkpoint so state endpoints work immediately
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
from langgraph.checkpoint.base import empty_checkpoint
ckpt_metadata = {
"step": -1,
"source": "input",
"writes": None,
"parents": {},
**body.metadata,
"created_at": now,
}
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
except Exception:
logger.exception("Failed to create checkpoint for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to create thread")
logger.info("Thread created: %s", thread_id)
return ThreadResponse(
thread_id=thread_id,
status="idle",
created_at=str(now),
updated_at=str(now),
metadata=body.metadata,
)
@router.post("/search", response_model=list[ThreadResponse])
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
"""Search and list threads.
Two-phase approach:
**Phase 1 — Store (fast path, O(threads))**: returns threads that were
created or run through this Gateway. Store records are tiny metadata
dicts so fetching all of them at once is cheap.
**Phase 2 — Checkpointer supplement (lazy migration)**: threads that
were created directly by LangGraph Server (and therefore absent from the
Store) are discovered here by iterating the shared checkpointer. Any
newly found thread is immediately written to the Store so that the next
search skips Phase 2 for that thread — the Store converges to a full
index over time without a one-shot migration job.
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
# -----------------------------------------------------------------------
# Phase 1: Store
# -----------------------------------------------------------------------
merged: dict[str, ThreadResponse] = {}
if store is not None:
try:
items = await store.asearch(THREADS_NS, limit=10_000)
except Exception:
logger.warning("Store search failed — falling back to checkpointer only", exc_info=True)
items = []
for item in items:
val = item.value
merged[val["thread_id"]] = ThreadResponse(
thread_id=val["thread_id"],
status=val.get("status", "idle"),
created_at=str(val.get("created_at", "")),
updated_at=str(val.get("updated_at", "")),
metadata=val.get("metadata", {}),
values=val.get("values", {}),
)
# -----------------------------------------------------------------------
# Phase 2: Checkpointer supplement
# Discovers threads not yet in the Store (e.g. created by LangGraph
# Server) and lazily migrates them so future searches skip this phase.
# -----------------------------------------------------------------------
try:
async for checkpoint_tuple in checkpointer.alist(None):
cfg = getattr(checkpoint_tuple, "config", {})
thread_id = cfg.get("configurable", {}).get("thread_id")
if not thread_id or thread_id in merged:
continue
# Skip sub-graph checkpoints (checkpoint_ns is non-empty for those)
if cfg.get("configurable", {}).get("checkpoint_ns", ""):
continue
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
# Strip LangGraph internal keys from the user-visible metadata dict
user_meta = {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
# Extract state values (title) from the checkpoint's channel_values
checkpoint_data = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint_data.get("channel_values", {})
ckpt_values = {}
if title := channel_values.get("title"):
ckpt_values["title"] = title
thread_resp = ThreadResponse(
thread_id=thread_id,
status=_derive_thread_status(checkpoint_tuple),
created_at=str(ckpt_meta.get("created_at", "")),
updated_at=str(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
metadata=user_meta,
values=ckpt_values,
)
merged[thread_id] = thread_resp
# Lazy migration — write to Store so the next search finds it there
if store is not None:
try:
await _store_upsert(store, thread_id, metadata=user_meta, values=ckpt_values or None)
except Exception:
logger.debug("Failed to migrate thread %s to store (non-fatal)", thread_id)
except Exception:
logger.exception("Checkpointer scan failed during thread search")
# Don't raise — return whatever was collected from Store + partial scan
# -----------------------------------------------------------------------
# Phase 3: Filter → sort → paginate
# -----------------------------------------------------------------------
results = list(merged.values())
if body.metadata:
results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]
if body.status:
results = [r for r in results if r.status == body.status]
results.sort(key=lambda r: r.updated_at, reverse=True)
return results[body.offset : body.offset + body.limit]
@router.patch("/{thread_id}", response_model=ThreadResponse)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
"""Merge metadata into a thread record."""
store = get_store(request)
if store is None:
raise HTTPException(status_code=503, detail="Store not available")
record = await _store_get(store, thread_id)
if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
now = time.time()
updated = dict(record)
updated.setdefault("metadata", {}).update(body.metadata)
updated["updated_at"] = now
try:
await _store_put(store, updated)
except Exception:
logger.exception("Failed to patch thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to update thread")
return ThreadResponse(
thread_id=thread_id,
status=updated.get("status", "idle"),
created_at=str(updated.get("created_at", "")),
updated_at=str(now),
metadata=updated.get("metadata", {}),
)
@router.get("/{thread_id}", response_model=ThreadResponse)
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
"""Get thread info.
Reads metadata from the Store and derives the accurate execution
status from the checkpointer. Falls back to the checkpointer alone
for threads that pre-date Store adoption (backward compat).
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
record: dict | None = None
if store is not None:
record = await _store_get(store, thread_id)
# Derive accurate status from the checkpointer
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception:
logger.exception("Failed to get checkpoint for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to get thread")
if record is None and checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# If the thread exists in the checkpointer but not the store (e.g. legacy
# data), synthesize a minimal store record from the checkpoint metadata.
if record is None and checkpoint_tuple is not None:
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
record = {
"thread_id": thread_id,
"status": "idle",
"created_at": ckpt_meta.get("created_at", ""),
"updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
}
if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {}
channel_values = checkpoint.get("channel_values", {})
return ThreadResponse(
thread_id=thread_id,
status=status,
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
values=serialize_channel_values(channel_values),
)
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
"""Get the latest state snapshot for a thread.
Channel values are serialized to ensure LangChain message objects
are converted to JSON-safe dicts.
"""
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception:
logger.exception("Failed to get state for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
checkpoint_id = None
ckpt_config = getattr(checkpoint_tuple, "config", {})
if ckpt_config:
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id")
channel_values = checkpoint.get("channel_values", {})
parent_config = getattr(checkpoint_tuple, "parent_config", None)
parent_checkpoint_id = None
if parent_config:
parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id")
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
return ThreadStateResponse(
values=serialize_channel_values(channel_values),
next=next_tasks,
metadata=metadata,
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_checkpoint_id,
created_at=str(metadata.get("created_at", "")),
tasks=tasks,
)
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
"""Update thread state (e.g. for human-in-the-loop resume or title rename).
Writes a new checkpoint that merges *body.values* into the latest
channel values, then syncs any updated ``title`` field back to the Store
so that ``/threads/search`` reflects the change immediately.
"""
checkpointer = get_checkpointer(request)
store = get_store(request)
# checkpoint_ns must be present in the config for aput — default to ""
# (the root graph namespace). checkpoint_id is optional; omitting it
# fetches the latest checkpoint for the thread.
read_config: dict[str, Any] = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": "",
}
}
if body.checkpoint_id:
read_config["configurable"]["checkpoint_id"] = body.checkpoint_id
try:
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
except Exception:
logger.exception("Failed to get state for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# Work on mutable copies so we don't accidentally mutate cached objects.
checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
if body.values:
channel_values.update(body.values)
checkpoint["channel_values"] = channel_values
metadata["updated_at"] = time.time()
if body.as_node:
metadata["source"] = "update"
metadata["step"] = metadata.get("step", 0) + 1
metadata["writes"] = {body.as_node: body.values}
# aput requires checkpoint_ns in the config — use the same config used for the
# read (which always includes checkpoint_ns=""). Do NOT include checkpoint_id
# so that aput generates a fresh checkpoint ID for the new snapshot.
write_config: dict[str, Any] = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": "",
}
}
try:
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
except Exception:
logger.exception("Failed to update state for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to update thread state")
new_checkpoint_id: str | None = None
if isinstance(new_config, dict):
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
# Sync title changes to the Store so /threads/search reflects them immediately.
if store is not None and body.values and "title" in body.values:
try:
await _store_upsert(store, thread_id, values={"title": body.values["title"]})
except Exception:
logger.debug("Failed to sync title to store for thread %s (non-fatal)", thread_id)
return ThreadStateResponse(
values=serialize_channel_values(channel_values),
next=[],
metadata=metadata,
checkpoint_id=new_checkpoint_id,
created_at=str(metadata.get("created_at", "")),
)
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
"""Get checkpoint history for a thread."""
checkpointer = get_checkpointer(request)
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
if body.before:
config["configurable"]["checkpoint_id"] = body.before
entries: list[HistoryEntry] = []
try:
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
ckpt_config = getattr(checkpoint_tuple, "config", {})
parent_config = getattr(checkpoint_tuple, "parent_config", None)
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
parent_id = None
if parent_config:
parent_id = parent_config.get("configurable", {}).get("checkpoint_id")
channel_values = checkpoint.get("channel_values", {})
# Derive next tasks
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
entries.append(
HistoryEntry(
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_id,
metadata=metadata,
values=serialize_channel_values(channel_values),
created_at=str(metadata.get("created_at", "")),
next=next_tasks,
)
)
except Exception:
logger.exception("Failed to get history for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to get thread history")
return entries

View File

@@ -0,0 +1,168 @@
"""Upload router for handling file uploads."""
import logging
import os
import stat
from fastapi import APIRouter, File, HTTPException, UploadFile
from pydantic import BaseModel
from deerflow.config.paths import get_paths
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.uploads.manager import (
PathTraversalError,
delete_file_safe,
enrich_file_listing,
ensure_uploads_dir,
get_uploads_dir,
list_files_in_dir,
normalize_filename,
upload_artifact_url,
upload_virtual_path,
)
from deerflow.utils.file_conversion import CONVERTIBLE_EXTENSIONS, convert_file_to_markdown
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads/{thread_id}/uploads", tags=["uploads"])
class UploadResponse(BaseModel):
"""Response model for file upload."""
success: bool
files: list[dict[str, str]]
message: str
def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
"""Ensure uploaded files remain writable when mounted into non-local sandboxes.
In AIO sandbox mode, the gateway writes the authoritative host-side file
first, then the sandbox runtime may rewrite the same mounted path. Granting
world-writable access here prevents permission mismatches between the
gateway user and the sandbox runtime user.
"""
file_stat = os.lstat(file_path)
if stat.S_ISLNK(file_stat.st_mode):
logger.warning("Skipping sandbox chmod for symlinked upload path: %s", file_path)
return
writable_mode = stat.S_IMODE(file_stat.st_mode) | stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH
chmod_kwargs = {"follow_symlinks": False} if os.chmod in os.supports_follow_symlinks else {}
os.chmod(file_path, writable_mode, **chmod_kwargs)
@router.post("", response_model=UploadResponse)
async def upload_files(
thread_id: str,
files: list[UploadFile] = File(...),
) -> UploadResponse:
"""Upload multiple files to a thread's uploads directory."""
if not files:
raise HTTPException(status_code=400, detail="No files provided")
try:
uploads_dir = ensure_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
uploaded_files = []
sandbox_provider = get_sandbox_provider()
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
for file in files:
if not file.filename:
continue
try:
safe_filename = normalize_filename(file.filename)
except ValueError:
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
continue
try:
content = await file.read()
file_path = uploads_dir / safe_filename
file_path.write_bytes(content)
virtual_path = upload_virtual_path(safe_filename)
if sandbox_id != "local":
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, content)
file_info = {
"filename": safe_filename,
"size": str(len(content)),
"path": str(sandbox_uploads / safe_filename),
"virtual_path": virtual_path,
"artifact_url": upload_artifact_url(thread_id, safe_filename),
}
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
file_ext = file_path.suffix.lower()
if file_ext in CONVERTIBLE_EXTENSIONS:
md_path = await convert_file_to_markdown(file_path)
if md_path:
md_virtual_path = upload_virtual_path(md_path.name)
if sandbox_id != "local":
_make_file_sandbox_writable(md_path)
sandbox.update_file(md_virtual_path, md_path.read_bytes())
file_info["markdown_file"] = md_path.name
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
file_info["markdown_virtual_path"] = md_virtual_path
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
uploaded_files.append(file_info)
except Exception as e:
logger.error(f"Failed to upload {file.filename}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
return UploadResponse(
success=True,
files=uploaded_files,
message=f"Successfully uploaded {len(uploaded_files)} file(s)",
)
@router.get("/list", response_model=dict)
async def list_uploaded_files(thread_id: str) -> dict:
"""List all files in a thread's uploads directory."""
try:
uploads_dir = get_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
result = list_files_in_dir(uploads_dir)
enrich_file_listing(result, thread_id)
# Gateway additionally includes the sandbox-relative path.
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
for f in result["files"]:
f["path"] = str(sandbox_uploads / f["filename"])
return result
@router.delete("/{filename}")
async def delete_uploaded_file(thread_id: str, filename: str) -> dict:
"""Delete a file from a thread's uploads directory."""
try:
uploads_dir = get_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
try:
return delete_file_safe(uploads_dir, filename, convertible_extensions=CONVERTIBLE_EXTENSIONS)
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"File not found: {filename}")
except PathTraversalError:
raise HTTPException(status_code=400, detail="Invalid path")
except Exception as e:
logger.error(f"Failed to delete {filename}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to delete {filename}: {str(e)}")

View File

@@ -0,0 +1,367 @@
"""Run lifecycle service layer.
Centralizes the business logic for creating runs, formatting SSE
frames, and consuming stream bridge events. Router modules
(``thread_runs``, ``runs``) are thin HTTP handlers that delegate here.
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import time
from typing import Any
from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage
from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
from deerflow.runtime import (
END_SENTINEL,
HEARTBEAT_SENTINEL,
ConflictError,
DisconnectMode,
RunManager,
RunRecord,
RunStatus,
StreamBridge,
UnsupportedStrategyError,
run_agent,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# SSE formatting
# ---------------------------------------------------------------------------
def format_sse(event: str, data: Any, *, event_id: str | None = None) -> str:
"""Format a single SSE frame.
Field order: ``event:`` -> ``data:`` -> ``id:`` (optional) -> blank line.
This matches the LangGraph Platform wire format consumed by the
``useStream`` React hook and the Python ``langgraph-sdk`` SSE decoder.
"""
payload = json.dumps(data, default=str, ensure_ascii=False)
parts = [f"event: {event}", f"data: {payload}"]
if event_id:
parts.append(f"id: {event_id}")
parts.append("")
parts.append("")
return "\n".join(parts)
# ---------------------------------------------------------------------------
# Input / config helpers
# ---------------------------------------------------------------------------
def normalize_stream_modes(raw: list[str] | str | None) -> list[str]:
"""Normalize the stream_mode parameter to a list.
Default matches what ``useStream`` expects: values + messages-tuple.
"""
if raw is None:
return ["values"]
if isinstance(raw, str):
return [raw]
return raw if raw else ["values"]
def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
"""Convert LangGraph Platform input format to LangChain state dict."""
if raw_input is None:
return {}
messages = raw_input.get("messages")
if messages and isinstance(messages, list):
converted = []
for msg in messages:
if isinstance(msg, dict):
role = msg.get("role", msg.get("type", "user"))
content = msg.get("content", "")
if role in ("user", "human"):
converted.append(HumanMessage(content=content))
else:
# TODO: handle other message types (system, ai, tool)
converted.append(HumanMessage(content=content))
else:
converted.append(msg)
return {**raw_input, "messages": converted}
return raw_input
_DEFAULT_ASSISTANT_ID = "lead_agent"
def resolve_agent_factory(assistant_id: str | None):
"""Resolve the agent factory callable from config.
Custom agents are implemented as ``lead_agent`` + an ``agent_name``
injected into ``configurable`` — see :func:`build_run_config`. All
``assistant_id`` values therefore map to the same factory; the routing
happens inside ``make_lead_agent`` when it reads ``cfg["agent_name"]``.
"""
from deerflow.agents.lead_agent.agent import make_lead_agent
return make_lead_agent
def build_run_config(
thread_id: str,
request_config: dict[str, Any] | None,
metadata: dict[str, Any] | None,
*,
assistant_id: str | None = None,
) -> dict[str, Any]:
"""Build a RunnableConfig dict for the agent.
When *assistant_id* refers to a custom agent (anything other than
``"lead_agent"`` / ``None``), the name is forwarded as
``configurable["agent_name"]``. ``make_lead_agent`` reads this key to
load the matching ``agents/<name>/SOUL.md`` and per-agent config —
without it the agent silently runs as the default lead agent.
This mirrors the channel manager's ``_resolve_run_params`` logic so that
the LangGraph Platform-compatible HTTP API and the IM channel path behave
identically.
"""
config: dict[str, Any] = {"recursion_limit": 100}
if request_config:
# LangGraph >= 0.6.0 introduced ``context`` as the preferred way to
# pass thread-level data and rejects requests that include both
# ``configurable`` and ``context``. If the caller already sends
# ``context``, honour it and skip our own ``configurable`` dict.
if "context" in request_config:
if "configurable" in request_config:
logger.warning(
"build_run_config: client sent both 'context' and 'configurable'; preferring 'context' (LangGraph >= 0.6.0). thread_id=%s, caller_configurable keys=%s",
thread_id,
list(request_config.get("configurable", {}).keys()),
)
config["context"] = request_config["context"]
else:
configurable = {"thread_id": thread_id}
configurable.update(request_config.get("configurable", {}))
config["configurable"] = configurable
for k, v in request_config.items():
if k not in ("configurable", "context"):
config[k] = v
else:
config["configurable"] = {"thread_id": thread_id}
# Inject custom agent name when the caller specified a non-default assistant.
# Honour an explicit configurable["agent_name"] in the request if already set.
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "configurable" in config:
if "agent_name" not in config["configurable"]:
normalized = assistant_id.strip().lower().replace("_", "-")
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
config["configurable"]["agent_name"] = normalized
if metadata:
config.setdefault("metadata", {}).update(metadata)
return config
# ---------------------------------------------------------------------------
# Run lifecycle
# ---------------------------------------------------------------------------
async def _upsert_thread_in_store(store, thread_id: str, metadata: dict | None) -> None:
"""Create or refresh the thread record in the Store.
Called from :func:`start_run` so that threads created via the stateless
``/runs/stream`` endpoint (which never calls ``POST /threads``) still
appear in ``/threads/search`` results.
"""
# Deferred import to avoid circular import with the threads router module.
from app.gateway.routers.threads import _store_upsert
try:
await _store_upsert(store, thread_id, metadata=metadata)
except Exception:
logger.warning("Failed to upsert thread %s in store (non-fatal)", thread_id)
async def _sync_thread_title_after_run(
run_task: asyncio.Task,
thread_id: str,
checkpointer: Any,
store: Any,
) -> None:
"""Wait for *run_task* to finish, then persist the generated title to the Store.
TitleMiddleware writes the generated title to the LangGraph agent state
(checkpointer) but the Gateway's Store record is not updated automatically.
This coroutine closes that gap by reading the final checkpoint after the
run completes and syncing ``values.title`` into the Store record so that
subsequent ``/threads/search`` responses include the correct title.
Runs as a fire-and-forget :func:`asyncio.create_task`; failures are
logged at DEBUG level and never propagate.
"""
# Wait for the background run task to complete (any outcome).
# asyncio.wait does not propagate task exceptions — it just returns
# when the task is done, cancelled, or failed.
await asyncio.wait({run_task})
# Deferred import to avoid circular import with the threads router module.
from app.gateway.routers.threads import _store_get, _store_put
try:
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
if ckpt_tuple is None:
return
channel_values = ckpt_tuple.checkpoint.get("channel_values", {})
title = channel_values.get("title")
if not title:
return
existing = await _store_get(store, thread_id)
if existing is None:
return
updated = dict(existing)
updated.setdefault("values", {})["title"] = title
updated["updated_at"] = time.time()
await _store_put(store, updated)
logger.debug("Synced title %r for thread %s", title, thread_id)
except Exception:
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id, exc_info=True)
async def start_run(
body: Any,
thread_id: str,
request: Request,
) -> RunRecord:
"""Create a RunRecord and launch the background agent task.
Parameters
----------
body : RunCreateRequest
The validated request body (typed as Any to avoid circular import
with the router module that defines the Pydantic model).
thread_id : str
Target thread.
request : Request
FastAPI request — used to retrieve singletons from ``app.state``.
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
checkpointer = get_checkpointer(request)
store = get_store(request)
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
try:
record = await run_mgr.create_or_reject(
thread_id,
body.assistant_id,
on_disconnect=disconnect,
metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
)
except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
except UnsupportedStrategyError as exc:
raise HTTPException(status_code=501, detail=str(exc)) from exc
# Ensure the thread is visible in /threads/search, even for threads that
# were never explicitly created via POST /threads (e.g. stateless runs).
store = get_store(request)
if store is not None:
await _upsert_thread_in_store(store, thread_id, body.metadata)
agent_factory = resolve_agent_factory(body.assistant_id)
graph_input = normalize_input(body.input)
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
# Merge DeerFlow-specific context overrides into configurable.
# The ``context`` field is a custom extension for the langgraph-compat layer
# that carries agent configuration (model_name, thinking_enabled, etc.).
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
context = getattr(body, "context", None)
if context:
_CONTEXT_CONFIGURABLE_KEYS = {
"model_name",
"mode",
"thinking_enabled",
"reasoning_effort",
"is_plan_mode",
"subagent_enabled",
"max_concurrent_subagents",
}
configurable = config.setdefault("configurable", {})
for key in _CONTEXT_CONFIGURABLE_KEYS:
if key in context:
configurable.setdefault(key, context[key])
stream_modes = normalize_stream_modes(body.stream_mode)
task = asyncio.create_task(
run_agent(
bridge,
run_mgr,
record,
checkpointer=checkpointer,
store=store,
agent_factory=agent_factory,
graph_input=graph_input,
config=config,
stream_modes=stream_modes,
stream_subgraphs=body.stream_subgraphs,
interrupt_before=body.interrupt_before,
interrupt_after=body.interrupt_after,
)
)
record.task = task
# After the run completes, sync the title generated by TitleMiddleware from
# the checkpointer into the Store record so that /threads/search returns the
# correct title instead of an empty values dict.
if store is not None:
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store))
return record
async def sse_consumer(
bridge: StreamBridge,
record: RunRecord,
request: Request,
run_mgr: RunManager,
):
"""Async generator that yields SSE frames from the bridge.
The ``finally`` block implements ``on_disconnect`` semantics:
- ``cancel``: abort the background task on client disconnect.
- ``continue``: let the task run; events are discarded.
"""
last_event_id = request.headers.get("Last-Event-ID")
try:
async for entry in bridge.subscribe(record.run_id, last_event_id=last_event_id):
if await request.is_disconnected():
break
if entry is HEARTBEAT_SENTINEL:
yield ": heartbeat\n\n"
continue
if entry is END_SENTINEL:
yield format_sse("end", None, event_id=entry.id or None)
return
yield format_sse(entry.event, entry.data, event_id=entry.id or None)
finally:
if record.status in (RunStatus.pending, RunStatus.running):
if record.on_disconnect == DisconnectMode.cancel:
await run_mgr.cancel(record.run_id)