Skip to content

Commit 325bb71

Browse files
committed
fix: Correct HybridBackgroundTasks interface and behavior
- Fix add_task method to be synchronous (matching FastAPI BackgroundTasks interface) - When use_docket=True: Schedule tasks directly in Docket's Redis queue - When use_docket=False: Use FastAPI's background tasks - Update tests to reflect correct behavior - Fix mypy type checking issues - Maintain proper separation between Docket and FastAPI task execution
1 parent 2f4ea93 commit 325bb71

File tree

2 files changed

+53
-25
lines changed

2 files changed

+53
-25
lines changed

agent_memory_server/dependencies.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from collections.abc import Callable
23
from typing import Any
34

@@ -13,26 +14,43 @@
1314
class HybridBackgroundTasks(BackgroundTasks):
1415
"""A BackgroundTasks implementation that can use either Docket or FastAPI background tasks."""
1516

16-
async def add_task(
17-
self, func: Callable[..., Any], *args: Any, **kwargs: Any
18-
) -> None:
17+
def add_task(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
1918
"""Run tasks either directly, through Docket, or through FastAPI background tasks"""
2019
logger.info("Adding task to background tasks...")
2120

2221
if settings.use_docket:
2322
logger.info("Scheduling task through Docket")
23+
# Schedule task directly in Docket's Redis queue
24+
self._schedule_docket_task(func, *args, **kwargs)
25+
else:
26+
logger.info("Using FastAPI background tasks")
27+
# Use FastAPI's background tasks directly
28+
super().add_task(func, *args, **kwargs)
29+
30+
def _schedule_docket_task(
31+
self, func: Callable[..., Any], *args: Any, **kwargs: Any
32+
) -> None:
33+
"""Schedule a task in Docket's Redis queue"""
34+
35+
async def schedule_task():
2436
from docket import Docket
2537

2638
async with Docket(
2739
name=settings.docket_name,
2840
url=settings.redis_url,
2941
) as docket:
30-
# Schedule task through Docket
42+
# Schedule task in Docket's queue
3143
await docket.add(func)(*args, **kwargs)
32-
else:
33-
logger.info("Using FastAPI background tasks")
34-
# Use FastAPI's background tasks
35-
super().add_task(func, *args, **kwargs)
44+
45+
# Run the async scheduling operation
46+
try:
47+
# Try to get the current event loop
48+
loop = asyncio.get_running_loop()
49+
# If we're in an async context, create a task
50+
loop.create_task(schedule_task())
51+
except RuntimeError:
52+
# If no event loop is running, run it synchronously
53+
asyncio.run(schedule_task())
3654

3755

3856
# Backwards compatibility alias

tests/test_dependencies.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for the dependencies module, particularly HybridBackgroundTasks."""
22

3+
import asyncio
34
from unittest.mock import Mock, patch
45

56
import pytest
@@ -24,8 +25,7 @@ def teardown_method(self):
2425
"""Restore original use_docket setting."""
2526
settings.use_docket = self.original_use_docket
2627

27-
@pytest.mark.asyncio
28-
async def test_add_task_with_fastapi_background_tasks(self):
28+
def test_add_task_with_fastapi_background_tasks(self):
2929
"""Test that tasks are added to FastAPI background tasks when use_docket=False."""
3030
settings.use_docket = False
3131

@@ -35,7 +35,7 @@ def test_func(arg1, arg2=None):
3535
return f"test_func called with {arg1}, {arg2}"
3636

3737
# Add task
38-
await bg_tasks.add_task(test_func, "hello", arg2="world")
38+
bg_tasks.add_task(test_func, "hello", arg2="world")
3939

4040
# Verify task was added to FastAPI background tasks
4141
assert len(bg_tasks.tasks) == 1
@@ -71,8 +71,14 @@ async def execution():
7171
async def test_func(arg1, arg2=None):
7272
return f"test_func called with {arg1}, {arg2}"
7373

74-
# Add task
75-
await bg_tasks.add_task(test_func, "hello", arg2="world")
74+
# Add task - this should schedule directly in Docket
75+
bg_tasks.add_task(test_func, "hello", arg2="world")
76+
77+
# Give the async task a moment to complete
78+
await asyncio.sleep(0.1)
79+
80+
# When use_docket=True, tasks should NOT be added to FastAPI background tasks
81+
assert len(bg_tasks.tasks) == 0
7682

7783
# Verify Docket was used
7884
mock_docket_class.assert_called_once_with(
@@ -81,8 +87,7 @@ async def test_func(arg1, arg2=None):
8187
)
8288
mock_docket_instance.add.assert_called_once_with(test_func)
8389

84-
@pytest.mark.asyncio
85-
async def test_add_task_logs_correctly_for_fastapi(self):
90+
def test_add_task_logs_correctly_for_fastapi(self):
8691
"""Test that FastAPI background tasks work without errors."""
8792
settings.use_docket = False
8893

@@ -92,7 +97,7 @@ def test_func():
9297
pass
9398

9499
# Should not raise any errors
95-
await bg_tasks.add_task(test_func)
100+
bg_tasks.add_task(test_func)
96101

97102
# Verify task was added to FastAPI background tasks
98103
assert len(bg_tasks.tasks) == 1
@@ -122,7 +127,10 @@ def test_func():
122127
pass
123128

124129
# Should not raise any errors
125-
await bg_tasks.add_task(test_func)
130+
bg_tasks.add_task(test_func)
131+
132+
# Give the async task a moment to complete
133+
await asyncio.sleep(0.1)
126134

127135
# Verify Docket was called
128136
mock_docket_instance.add.assert_called_once_with(test_func)
@@ -178,7 +186,7 @@ def test_func():
178186

179187
# Test with use_docket=False
180188
settings.use_docket = False
181-
await bg_tasks.add_task(test_func)
189+
bg_tasks.add_task(test_func)
182190
assert len(bg_tasks.tasks) == 1
183191

184192
# Clear tasks and test with use_docket=True
@@ -199,14 +207,18 @@ async def execution():
199207

200208
mock_docket_instance.add.return_value = mock_task_callable
201209

202-
await bg_tasks.add_task(test_func)
210+
bg_tasks.add_task(test_func)
211+
212+
# Give the async task a moment to complete
213+
await asyncio.sleep(0.1)
203214

204215
# Should not add to FastAPI tasks when use_docket=True
205216
assert len(bg_tasks.tasks) == 0
206217
# Should have called Docket
207218
mock_docket_class.assert_called_once()
208219

209-
def test_uses_correct_docket_settings(self):
220+
@pytest.mark.asyncio
221+
async def test_uses_correct_docket_settings(self):
210222
"""Test that HybridBackgroundTasks uses the correct docket name and URL from settings."""
211223
settings.use_docket = True
212224

@@ -226,12 +238,10 @@ async def execution():
226238

227239
bg_tasks = HybridBackgroundTasks()
228240

229-
async def run_test():
230-
await bg_tasks.add_task(lambda: None)
231-
232-
import asyncio
241+
bg_tasks.add_task(lambda: None)
233242

234-
asyncio.run(run_test())
243+
# Give the async task a moment to complete
244+
await asyncio.sleep(0.1)
235245

236246
# Verify Docket was initialized with correct settings
237247
mock_docket_class.assert_called_once_with(

0 commit comments

Comments
 (0)