Files
deerflow-factory/deer-flow/backend/tests/test_run_worker_rollback.py
DATA 6de0bf9f5b Initial commit: hardened DeerFlow factory
Vendored deer-flow upstream (bytedance/deer-flow) plus prompt-injection
hardening:

- New deerflow.security package: content_delimiter, html_cleaner,
  sanitizer (8 layers — invisible chars, control chars, symbols, NFC,
  PUA, tag chars, horizontal whitespace collapse with newline/tab
  preservation, length cap)
- New deerflow.community.searx package: web_search, web_fetch,
  image_search backed by a private SearX instance, every external
  string sanitized and wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>>
  delimiters
- All native community web providers (ddg_search, tavily, exa,
  firecrawl, jina_ai, infoquest, image_search) replaced with hard-fail
  stubs that raise NativeWebToolDisabledError at import time, so a
  misconfigured tool.use path fails loudly rather than silently falling
  back to unsanitized output
- Native client back-doors (jina_client.py, infoquest_client.py)
  stubbed too
- Native-tool tests quarantined under tests/_disabled_native/
  (collect_ignore_glob via local conftest.py)
- Sanitizer Layer 7 fix: only collapse horizontal whitespace, preserve
  newlines and tabs so list/table structure survives
- Hardened runtime config.yaml references only the searx-backed tools
- Factory overlay (backend/) kept in sync with deer-flow tree as a
  reference / source

See HARDENING.md for the full audit trail and verification steps.
2026-04-12 14:23:57 +02:00

215 lines
7.8 KiB
Python

from unittest.mock import AsyncMock, call
import pytest
from deerflow.runtime.runs.worker import _rollback_to_pre_run_checkpoint
class FakeCheckpointer:
    """Test double exposing the async checkpointer methods these tests inspect.

    Every method is an ``AsyncMock`` so a test can check whether and how it
    was awaited. ``aput`` resolves to the ``put_result`` given at construction.
    """

    def __init__(self, *, put_result):
        # ``put_result`` is keyword-only so each test states explicitly what
        # the fake ``aput`` will hand back.
        self.aput_writes = AsyncMock()
        self.aput = AsyncMock(return_value=put_result)
        self.adelete_thread = AsyncMock()
@pytest.mark.anyio
async def test_rollback_restores_snapshot_without_deleting_thread():
    """Rollback with a snapshot re-puts the checkpoint and replays pending writes.

    Asserts that ``adelete_thread`` is never awaited, that ``aput`` receives
    the snapshot's checkpoint, metadata and channel versions, and that
    ``aput_writes`` is awaited once per task id with that task's
    ``(channel, value)`` pairs and a config equal to the fake's ``put_result``.
    """
    fake = FakeCheckpointer(
        put_result={
            "configurable": {
                "thread_id": "thread-1",
                "checkpoint_ns": "",
                "checkpoint_id": "restored-1",
            }
        }
    )
    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": {
            "id": "ckpt-1",
            "channel_versions": {"messages": 3},
            "channel_values": {"messages": ["before"]},
        },
        "metadata": {"source": "input"},
        "pending_writes": [
            ("task-a", "messages", {"content": "first"}),
            ("task-a", "status", "done"),
            ("task-b", "events", {"type": "tool"}),
        ],
    }

    await _rollback_to_pre_run_checkpoint(
        checkpointer=fake,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id="ckpt-1",
        pre_run_snapshot=snapshot,
        snapshot_capture_failed=False,
    )

    # Expected values are written as fresh literals rather than reusing the
    # input objects, so any in-place mutation of the inputs is still detected.
    fake.adelete_thread.assert_not_awaited()
    fake.aput.assert_awaited_once_with(
        {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
        {
            "id": "ckpt-1",
            "channel_versions": {"messages": 3},
            "channel_values": {"messages": ["before"]},
        },
        {"source": "input"},
        {"messages": 3},
    )

    expected_restore_config = {
        "configurable": {
            "thread_id": "thread-1",
            "checkpoint_ns": "",
            "checkpoint_id": "restored-1",
        }
    }
    expected_write_awaits = [
        call(
            expected_restore_config,
            [("messages", {"content": "first"}), ("status", "done")],
            task_id="task-a",
        ),
        call(
            expected_restore_config,
            [("events", {"type": "tool"})],
            task_id="task-b",
        ),
    ]
    assert fake.aput_writes.await_args_list == expected_write_awaits
@pytest.mark.anyio
async def test_rollback_deletes_thread_when_no_snapshot_exists():
    """With no pre-run checkpoint id and no snapshot, only ``adelete_thread`` is awaited."""
    fake = FakeCheckpointer(put_result=None)

    await _rollback_to_pre_run_checkpoint(
        checkpointer=fake,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id=None,
        pre_run_snapshot=None,
        snapshot_capture_failed=False,
    )

    # The thread is deleted by id; nothing is written back.
    fake.adelete_thread.assert_awaited_once_with("thread-1")
    fake.aput.assert_not_awaited()
    fake.aput_writes.assert_not_awaited()
@pytest.mark.anyio
async def test_rollback_raises_when_restore_config_has_no_checkpoint_id():
    """A config from ``aput`` lacking ``checkpoint_id`` makes rollback raise ``RuntimeError``.

    Also asserts that ``aput`` was awaited exactly once and that neither
    ``aput_writes`` nor ``adelete_thread`` was awaited.
    """
    # The fake's put result deliberately omits "checkpoint_id".
    fake = FakeCheckpointer(
        put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
    )
    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
        "metadata": {},
        "pending_writes": [("task-a", "messages", "value")],
    }

    with pytest.raises(RuntimeError, match="did not return checkpoint_id"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=fake,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot=snapshot,
            snapshot_capture_failed=False,
        )

    fake.adelete_thread.assert_not_awaited()
    fake.aput.assert_awaited_once()
    fake.aput_writes.assert_not_awaited()
@pytest.mark.anyio
async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
    """A snapshot whose ``checkpoint_ns`` is ``None`` is put with ``checkpoint_ns`` set to ``""``."""
    fake = FakeCheckpointer(
        put_result={
            "configurable": {
                "thread_id": "thread-1",
                "checkpoint_ns": "",
                "checkpoint_id": "restored-1",
            }
        }
    )
    snapshot = {
        "checkpoint_ns": None,
        "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
        "metadata": {},
        "pending_writes": [],
    }

    await _rollback_to_pre_run_checkpoint(
        checkpointer=fake,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id="ckpt-1",
        pre_run_snapshot=snapshot,
        snapshot_capture_failed=False,
    )

    # Positional args: config, checkpoint, metadata, channel versions.
    fake.aput.assert_awaited_once_with(
        {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
        {"id": "ckpt-1", "channel_versions": {}},
        {},
        {},
    )
@pytest.mark.anyio
async def test_rollback_raises_on_malformed_pending_write_not_a_tuple():
    """A ``pending_writes`` entry that is not a 3-tuple makes rollback raise ``RuntimeError``."""
    fake = FakeCheckpointer(
        put_result={
            "configurable": {
                "thread_id": "thread-1",
                "checkpoint_ns": "",
                "checkpoint_id": "restored-1",
            }
        }
    )
    well_formed_write = ("task-a", "messages", "valid")
    two_element_write = ["only", "two"]  # malformed: only 2 elements
    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
        "metadata": {},
        "pending_writes": [well_formed_write, two_element_write],
    }

    with pytest.raises(RuntimeError, match="rollback failed: pending_write is not a 3-tuple"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=fake,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot=snapshot,
            snapshot_capture_failed=False,
        )

    # The checkpoint put went through, but no writes were replayed, not even
    # the well-formed entry that precedes the malformed one.
    fake.aput.assert_awaited_once()
    fake.aput_writes.assert_not_awaited()
@pytest.mark.anyio
async def test_rollback_raises_on_malformed_pending_write_non_string_channel():
    """A ``pending_writes`` entry whose channel is not a string makes rollback raise ``RuntimeError``."""
    fake = FakeCheckpointer(
        put_result={
            "configurable": {
                "thread_id": "thread-1",
                "checkpoint_ns": "",
                "checkpoint_id": "restored-1",
            }
        }
    )
    integer_channel_write = ("task-a", 123, "value")  # malformed: channel is not a string
    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
        "metadata": {},
        "pending_writes": [integer_channel_write],
    }

    with pytest.raises(RuntimeError, match="rollback failed: pending_write has non-string channel"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=fake,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot=snapshot,
            snapshot_capture_failed=False,
        )

    # The checkpoint put went through; no writes were replayed.
    fake.aput.assert_awaited_once()
    fake.aput_writes.assert_not_awaited()
@pytest.mark.anyio
async def test_rollback_propagates_aput_writes_failure():
    """An exception raised by ``aput_writes`` surfaces to the caller instead of being swallowed."""
    fake = FakeCheckpointer(
        put_result={
            "configurable": {
                "thread_id": "thread-1",
                "checkpoint_ns": "",
                "checkpoint_id": "restored-1",
            }
        }
    )
    # Make the fake's aput_writes raise when awaited.
    fake.aput_writes.side_effect = RuntimeError("Database connection lost")
    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
        "metadata": {},
        "pending_writes": [("task-a", "messages", "value")],
    }

    with pytest.raises(RuntimeError, match="Database connection lost"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=fake,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot=snapshot,
            snapshot_capture_failed=False,
        )

    # aput completed; aput_writes was awaited exactly once and raised.
    fake.aput.assert_awaited_once()
    fake.aput_writes.assert_awaited_once()