"""Tests for ``_rollback_to_pre_run_checkpoint`` from the run worker.

Each test drives the rollback helper against a ``FakeCheckpointer`` whose
async methods are ``AsyncMock`` objects, then inspects which of them were
awaited and with what arguments.
"""

from unittest.mock import AsyncMock, call

import pytest

from deerflow.runtime.runs.worker import _rollback_to_pre_run_checkpoint


class FakeCheckpointer:
    """Stand-in checkpointer exposing only the three async methods the tests inspect."""

    def __init__(self, *, put_result):
        # ``aput`` resolves to ``put_result``; the other two mocks return
        # AsyncMock's default value.
        self.aput = AsyncMock(return_value=put_result)
        self.aput_writes = AsyncMock()
        self.adelete_thread = AsyncMock()


def _root_config(**extra):
    """Return a root-namespace config for ``thread-1`` with optional extra keys."""
    return {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", **extra}}


def _restored_config():
    """Return the config a successful ``aput`` is made to yield in these tests."""
    return _root_config(checkpoint_id="restored-1")


def _minimal_snapshot(pending_writes, *, checkpoint_ns=""):
    """Return a snapshot with an empty checkpoint body and the given pending writes."""
    return {
        "checkpoint_ns": checkpoint_ns,
        "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
        "metadata": {},
        "pending_writes": pending_writes,
    }


async def _run_rollback(checkpointer, snapshot, *, checkpoint_id="ckpt-1"):
    """Invoke the rollback under test with the thread/run ids shared by every test."""
    await _rollback_to_pre_run_checkpoint(
        checkpointer=checkpointer,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id=checkpoint_id,
        pre_run_snapshot=snapshot,
        snapshot_capture_failed=False,
    )


@pytest.mark.anyio
async def test_rollback_restores_snapshot_without_deleting_thread():
    """A full snapshot is re-put and its pending writes are replayed per task."""
    fake = FakeCheckpointer(put_result=_restored_config())
    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": {
            "id": "ckpt-1",
            "channel_versions": {"messages": 3},
            "channel_values": {"messages": ["before"]},
        },
        "metadata": {"source": "input"},
        "pending_writes": [
            ("task-a", "messages", {"content": "first"}),
            ("task-a", "status", "done"),
            ("task-b", "events", {"type": "tool"}),
        ],
    }

    await _run_rollback(fake, snapshot)

    fake.adelete_thread.assert_not_awaited()

    expected_checkpoint = {
        "id": "ckpt-1",
        "channel_versions": {"messages": 3},
        "channel_values": {"messages": ["before"]},
    }
    fake.aput.assert_awaited_once_with(
        _root_config(),
        expected_checkpoint,
        {"source": "input"},
        {"messages": 3},
    )

    # One aput_writes call per task id, in first-seen order, each carrying
    # that task's (channel, value) pairs and the config returned by aput.
    expected_write_calls = [
        call(
            _restored_config(),
            [("messages", {"content": "first"}), ("status", "done")],
            task_id="task-a",
        ),
        call(
            _restored_config(),
            [("events", {"type": "tool"})],
            task_id="task-b",
        ),
    ]
    assert fake.aput_writes.await_args_list == expected_write_calls


@pytest.mark.anyio
async def test_rollback_deletes_thread_when_no_snapshot_exists():
    """With no snapshot and no checkpoint id, the thread is deleted and nothing is put."""
    fake = FakeCheckpointer(put_result=None)

    await _run_rollback(fake, None, checkpoint_id=None)

    fake.adelete_thread.assert_awaited_once_with("thread-1")
    fake.aput.assert_not_awaited()
    fake.aput_writes.assert_not_awaited()


@pytest.mark.anyio
async def test_rollback_raises_when_restore_config_has_no_checkpoint_id():
    """An ``aput`` result lacking ``checkpoint_id`` aborts before pending writes are replayed."""
    fake = FakeCheckpointer(put_result=_root_config())
    snapshot = _minimal_snapshot([("task-a", "messages", "value")])

    with pytest.raises(RuntimeError, match="did not return checkpoint_id"):
        await _run_rollback(fake, snapshot)

    fake.adelete_thread.assert_not_awaited()
    fake.aput.assert_awaited_once()
    fake.aput_writes.assert_not_awaited()


@pytest.mark.anyio
async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
    """A ``None`` checkpoint namespace in the snapshot is passed to ``aput`` as ``""``."""
    fake = FakeCheckpointer(put_result=_restored_config())
    snapshot = _minimal_snapshot([], checkpoint_ns=None)

    await _run_rollback(fake, snapshot)

    fake.aput.assert_awaited_once_with(
        _root_config(),
        {"id": "ckpt-1", "channel_versions": {}},
        {},
        {},
    )


@pytest.mark.anyio
async def test_rollback_raises_on_malformed_pending_write_not_a_tuple():
    """A pending write that is not a 3-item entry makes the rollback raise RuntimeError."""
    fake = FakeCheckpointer(put_result=_restored_config())
    snapshot = _minimal_snapshot(
        [
            ("task-a", "messages", "valid"),  # well-formed entry
            ["only", "two"],  # malformed entry: just two elements
        ]
    )

    with pytest.raises(RuntimeError, match="rollback failed: pending_write is not a 3-tuple"):
        await _run_rollback(fake, snapshot)

    # The checkpoint itself was put, but no writes were replayed.
    fake.aput.assert_awaited_once()
    fake.aput_writes.assert_not_awaited()


@pytest.mark.anyio
async def test_rollback_raises_on_malformed_pending_write_non_string_channel():
    """A pending write whose channel is not a string makes the rollback raise RuntimeError."""
    fake = FakeCheckpointer(put_result=_restored_config())
    snapshot = _minimal_snapshot(
        [
            ("task-a", 123, "value"),  # malformed entry: integer channel
        ]
    )

    with pytest.raises(RuntimeError, match="rollback failed: pending_write has non-string channel"):
        await _run_rollback(fake, snapshot)

    fake.aput.assert_awaited_once()
    fake.aput_writes.assert_not_awaited()


@pytest.mark.anyio
async def test_rollback_propagates_aput_writes_failure():
    """An exception raised by ``aput_writes`` reaches the caller instead of being swallowed."""
    fake = FakeCheckpointer(put_result=_restored_config())
    # Make the write-replay step fail.
    fake.aput_writes.side_effect = RuntimeError("Database connection lost")
    snapshot = _minimal_snapshot([("task-a", "messages", "value")])

    with pytest.raises(RuntimeError, match="Database connection lost"):
        await _run_rollback(fake, snapshot)

    # The put succeeded; the single write-replay attempt is the call that failed.
    fake.aput.assert_awaited_once()
    fake.aput_writes.assert_awaited_once()