Initial commit: hardened DeerFlow factory
Vendored deer-flow upstream (bytedance/deer-flow) plus prompt-injection hardening: - New deerflow.security package: content_delimiter, html_cleaner, sanitizer (8 layers — invisible chars, control chars, symbols, NFC, PUA, tag chars, horizontal whitespace collapse with newline/tab preservation, length cap) - New deerflow.community.searx package: web_search, web_fetch, image_search backed by a private SearX instance, every external string sanitized and wrapped in <<<EXTERNAL_UNTRUSTED_CONTENT>>> delimiters - All native community web providers (ddg_search, tavily, exa, firecrawl, jina_ai, infoquest, image_search) replaced with hard-fail stubs that raise NativeWebToolDisabledError at import time, so a misconfigured tool.use path fails loud rather than silently falling back to unsanitized output - Native client back-doors (jina_client.py, infoquest_client.py) stubbed too - Native-tool tests quarantined under tests/_disabled_native/ (collect_ignore_glob via local conftest.py) - Sanitizer Layer 7 fix: only collapse horizontal whitespace, preserve newlines and tabs so list/table structure survives - Hardened runtime config.yaml references only the searx-backed tools - Factory overlay (backend/) kept in sync with deer-flow tree as a reference / source See HARDENING.md for the full audit trail and verification steps.
This commit is contained in:
214
deer-flow/backend/tests/test_run_worker_rollback.py
Normal file
214
deer-flow/backend/tests/test_run_worker_rollback.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from unittest.mock import AsyncMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.runs.worker import _rollback_to_pre_run_checkpoint
|
||||
|
||||
|
||||
class FakeCheckpointer:
|
||||
def __init__(self, *, put_result):
|
||||
self.adelete_thread = AsyncMock()
|
||||
self.aput = AsyncMock(return_value=put_result)
|
||||
self.aput_writes = AsyncMock()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_restores_snapshot_without_deleting_thread():
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {
|
||||
"id": "ckpt-1",
|
||||
"channel_versions": {"messages": 3},
|
||||
"channel_values": {"messages": ["before"]},
|
||||
},
|
||||
"metadata": {"source": "input"},
|
||||
"pending_writes": [
|
||||
("task-a", "messages", {"content": "first"}),
|
||||
("task-a", "status", "done"),
|
||||
("task-b", "events", {"type": "tool"}),
|
||||
],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.adelete_thread.assert_not_awaited()
|
||||
checkpointer.aput.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
||||
{
|
||||
"id": "ckpt-1",
|
||||
"channel_versions": {"messages": 3},
|
||||
"channel_values": {"messages": ["before"]},
|
||||
},
|
||||
{"source": "input"},
|
||||
{"messages": 3},
|
||||
)
|
||||
assert checkpointer.aput_writes.await_args_list == [
|
||||
call(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
|
||||
[("messages", {"content": "first"}), ("status", "done")],
|
||||
task_id="task-a",
|
||||
),
|
||||
call(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
|
||||
[("events", {"type": "tool"})],
|
||||
task_id="task-b",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_deletes_thread_when_no_snapshot_exists():
|
||||
checkpointer = FakeCheckpointer(put_result=None)
|
||||
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id=None,
|
||||
pre_run_snapshot=None,
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.adelete_thread.assert_awaited_once_with("thread-1")
|
||||
checkpointer.aput.assert_not_awaited()
|
||||
checkpointer.aput_writes.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_raises_when_restore_config_has_no_checkpoint_id():
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}})
|
||||
|
||||
with pytest.raises(RuntimeError, match="did not return checkpoint_id"):
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [("task-a", "messages", "value")],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.adelete_thread.assert_not_awaited()
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
checkpointer.aput_writes.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": None,
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.aput.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
||||
{"id": "ckpt-1", "channel_versions": {}},
|
||||
{},
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_raises_on_malformed_pending_write_not_a_tuple():
|
||||
"""pending_writes containing a non-3-tuple item should raise RuntimeError."""
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
|
||||
with pytest.raises(RuntimeError, match="rollback failed: pending_write is not a 3-tuple"):
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [
|
||||
("task-a", "messages", "valid"), # valid
|
||||
["only", "two"], # malformed: only 2 elements
|
||||
],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
# aput succeeded but aput_writes should not be called due to malformed data
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
checkpointer.aput_writes.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_raises_on_malformed_pending_write_non_string_channel():
|
||||
"""pending_writes containing a non-string channel should raise RuntimeError."""
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
|
||||
with pytest.raises(RuntimeError, match="rollback failed: pending_write has non-string channel"):
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [
|
||||
("task-a", 123, "value"), # malformed: channel is not a string
|
||||
],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
checkpointer.aput_writes.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_propagates_aput_writes_failure():
|
||||
"""If aput_writes fails, the exception should propagate (not be swallowed)."""
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
# Simulate aput_writes failure
|
||||
checkpointer.aput_writes.side_effect = RuntimeError("Database connection lost")
|
||||
|
||||
with pytest.raises(RuntimeError, match="Database connection lost"):
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [
|
||||
("task-a", "messages", "value"),
|
||||
],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
# aput succeeded, aput_writes was called but failed
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
checkpointer.aput_writes.assert_awaited_once()
|
||||
Reference in New Issue
Block a user