"""Tests for AioSandbox concurrent command serialization (#1433).""" import threading from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @pytest.fixture() def sandbox(): """Create an AioSandbox with a mocked client.""" with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"): from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox sb = AioSandbox(id="test-sandbox", base_url="http://localhost:8080") return sb class TestExecuteCommandSerialization: """Verify that concurrent exec_command calls are serialized.""" def test_lock_prevents_concurrent_execution(self, sandbox): """Concurrent threads should not overlap inside execute_command.""" call_log = [] barrier = threading.Barrier(3) def slow_exec(command, **kwargs): call_log.append(("enter", command)) import time time.sleep(0.05) call_log.append(("exit", command)) return SimpleNamespace(data=SimpleNamespace(output=f"ok: {command}")) sandbox._client.shell.exec_command = slow_exec def worker(cmd): barrier.wait() # ensure all threads contend for the lock simultaneously sandbox.execute_command(cmd) threads = [] for i in range(3): t = threading.Thread(target=worker, args=(f"cmd-{i}",)) threads.append(t) for t in threads: t.start() for t in threads: t.join() # Verify serialization: each "enter" should be followed by its own # "exit" before the next "enter" (no interleaving). enters = [i for i, (action, _) in enumerate(call_log) if action == "enter"] exits = [i for i, (action, _) in enumerate(call_log) if action == "exit"] assert len(enters) == 3 assert len(exits) == 3 for e_idx, x_idx in zip(enters, exits): assert x_idx == e_idx + 1, f"Interleaved execution detected: {call_log}" class TestErrorObservationRetry: """Verify ErrorObservation detection and fresh-session retry.""" def test_retry_on_error_observation(self, sandbox): """When output contains ErrorObservation, retry with a fresh session.""" call_count = 0 def mock_exec(command, **kwargs): nonlocal call_count call_count += 1 if call_count == 1: return SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'")) return SimpleNamespace(data=SimpleNamespace(output="success")) sandbox._client.shell.exec_command = mock_exec result = sandbox.execute_command("echo hello") assert result == "success" assert call_count == 2 def test_retry_passes_fresh_session_id(self, sandbox): """The retry call should include a new session id kwarg.""" calls = [] def mock_exec(command, **kwargs): calls.append(kwargs) if len(calls) == 1: return SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'")) return SimpleNamespace(data=SimpleNamespace(output="ok")) sandbox._client.shell.exec_command = mock_exec sandbox.execute_command("test") assert len(calls) == 2 assert "id" not in calls[0] assert "id" in calls[1] assert len(calls[1]["id"]) == 36 # UUID format def test_no_retry_on_clean_output(self, sandbox): """Normal output should not trigger a retry.""" call_count = 0 def mock_exec(command, **kwargs): nonlocal call_count call_count += 1 return SimpleNamespace(data=SimpleNamespace(output="all good")) sandbox._client.shell.exec_command = mock_exec result = sandbox.execute_command("echo hello") assert result == "all good" assert call_count == 1 class TestListDirSerialization: """Verify that list_dir also acquires the lock.""" def test_list_dir_uses_lock(self, sandbox): """list_dir should hold the lock during execution.""" lock_was_held = [] original_exec = MagicMock(return_value=SimpleNamespace(data=SimpleNamespace(output="/a\n/b"))) def tracking_exec(command, **kwargs): lock_was_held.append(sandbox._lock.locked()) return original_exec(command, **kwargs) sandbox._client.shell.exec_command = tracking_exec result = sandbox.list_dir("/test") assert result == ["/a", "/b"] assert lock_was_held == [True], "list_dir must hold the lock during exec_command" class TestConcurrentFileWrites: """Verify file write paths do not lose concurrent updates.""" def test_append_should_preserve_both_parallel_writes(self, sandbox): storage = {"content": "seed\n"} active_reads = 0 state_lock = threading.Lock() overlap_detected = threading.Event() def overlapping_read_file(path): nonlocal active_reads with state_lock: active_reads += 1 snapshot = storage["content"] if active_reads == 2: overlap_detected.set() overlap_detected.wait(0.05) with state_lock: active_reads -= 1 return snapshot def write_back(*, file, content, **kwargs): storage["content"] = content return SimpleNamespace(data=SimpleNamespace()) sandbox.read_file = overlapping_read_file sandbox._client.file.write_file = write_back barrier = threading.Barrier(2) def writer(payload: str): barrier.wait() sandbox.write_file("/tmp/shared.log", payload, append=True) threads = [ threading.Thread(target=writer, args=("A\n",)), threading.Thread(target=writer, args=("B\n",)), ] for thread in threads: thread.start() for thread in threads: thread.join() assert storage["content"] in {"seed\nA\nB\n", "seed\nB\nA\n"}