Skip to content

Commit 745ec70

Browse files
committed
fix: Accept single memory object in create_long_term_memory tool
LLMs often try to pass a single memory object instead of an array when calling create_long_term_memory. This change makes the tool more flexible by accepting both formats: - Single object: {text: '...', memory_type: '...'} - Array: [{text: '...', memory_type: '...'}] When a single object is passed, it's automatically wrapped in an array before being sent to the API. Changes: - Updated function signature to accept dict | list[dict] - Added automatic wrapping logic for single objects - Improved description to clarify both formats are accepted - Added test case for single object handling This fixes the issue where LLMs would get errors when trying to create a single memory at a time.
1 parent 18f1e25 commit 745ec70

File tree

4 files changed

+166
-3
lines changed

4 files changed

+166
-3
lines changed

agent-memory-client/agent_memory_client/integrations/langchain.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def get_memory_tools(
158158
},
159159
"create_long_term_memory": {
160160
"name": "create_long_term_memory",
161-
"description": "Create long-term memories directly for immediate storage and retrieval. Use this for important information that should be permanently stored without going through working memory.",
161+
"description": "Create long-term memories directly for immediate storage and retrieval. Use this for important information that should be permanently stored without going through working memory. You can pass a single memory object or a list of memory objects. Each memory needs: text (string), memory_type ('episodic' or 'semantic'), and optionally topics (list), entities (list), event_date (ISO string).",
162162
"func": _create_create_long_term_memory_func(
163163
memory_client, namespace, user_id
164164
),
@@ -345,8 +345,18 @@ def _create_create_long_term_memory_func(
345345
) -> Any:
346346
"""Create create_long_term_memory function."""
347347

348-
async def create_long_term_memory(memories: list[dict[str, Any]]) -> str:
349-
"""Create long-term memories directly for immediate storage."""
348+
async def create_long_term_memory(
349+
memories: list[dict[str, Any]] | dict[str, Any],
350+
) -> str:
351+
"""Create long-term memories directly for immediate storage.
352+
353+
Accepts either a single memory object or a list of memory objects.
354+
Each memory should have: text, memory_type, and optionally topics, entities, event_date.
355+
"""
356+
# Handle single memory object - wrap it in a list
357+
if isinstance(memories, dict):
358+
memories = [memories]
359+
350360
result = await client.resolve_function_call(
351361
function_name="create_long_term_memory",
352362
function_arguments={"memories": memories},

agent-memory-client/tests/test_langchain_integration.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,42 @@ async def test_create_long_term_memory_tool_execution(self):
245245
assert call_kwargs["function_name"] == "create_long_term_memory"
246246
assert len(call_kwargs["function_arguments"]["memories"]) == 2
247247

248+
@pytest.mark.skipif(not _langchain_available(), reason="LangChain not installed")
249+
@pytest.mark.asyncio
250+
async def test_create_long_term_memory_accepts_single_object(self):
251+
"""Test that create_long_term_memory accepts a single memory object (not just array)."""
252+
from agent_memory_client.integrations.langchain import get_memory_tools
253+
254+
client = _create_mock_client()
255+
client.resolve_function_call = AsyncMock(
256+
return_value={
257+
"success": True,
258+
"formatted_response": "Created 1 long-term memory",
259+
}
260+
)
261+
262+
tools = get_memory_tools(
263+
memory_client=client,
264+
session_id="test_session",
265+
user_id="test_user",
266+
tools=["create_long_term_memory"],
267+
)
268+
269+
create_tool = tools[0]
270+
271+
# Execute with single object (what LLMs often do)
272+
single_memory = {"text": "User loves pizza", "memory_type": "semantic"}
273+
result = await create_tool.ainvoke({"memories": single_memory})
274+
275+
assert "Created 1 long-term memory" in result
276+
277+
# Verify it was converted to an array
278+
call_args = client.resolve_function_call.call_args
279+
memories_arg = call_args.kwargs["function_arguments"]["memories"]
280+
assert isinstance(memories_arg, list)
281+
assert len(memories_arg) == 1
282+
assert memories_arg[0] == single_memory
283+
248284
@pytest.mark.skipif(not _langchain_available(), reason="LangChain not installed")
249285
def test_tool_has_correct_schema(self):
250286
"""Test that generated tools have correct schemas."""

check_schema.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""Check the actual schema generated by LangChain for create_long_term_memory."""
2+
3+
import asyncio
4+
import json
5+
6+
from agent_memory_client import MemoryAPIClient, MemoryClientConfig
7+
from agent_memory_client.integrations.langchain import get_memory_tools
8+
9+
10+
async def main():
11+
# Initialize memory client without health check
12+
config = MemoryClientConfig(base_url="http://localhost:8000")
13+
memory_client = MemoryAPIClient(config)
14+
15+
# Get memory tools
16+
tools = get_memory_tools(
17+
memory_client=memory_client,
18+
session_id="test_session",
19+
user_id="alice",
20+
tools=["create_long_term_memory", "add_memory_to_working_memory"],
21+
)
22+
23+
# Check both tools
24+
for tool in tools:
25+
print("=" * 80)
26+
print("Tool Name:", tool.name)
27+
print("\nTool Description:")
28+
print(tool.description)
29+
print("\nTool Args Schema:")
30+
print(json.dumps(tool.args_schema.schema(), indent=2))
31+
print()
32+
33+
34+
if __name__ == "__main__":
35+
asyncio.run(main())

test_single_memory.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Test that create_long_term_memory accepts both single object and array."""
2+
3+
import asyncio
4+
from unittest.mock import AsyncMock
5+
6+
from agent_memory_client import MemoryAPIClient, MemoryClientConfig
7+
from agent_memory_client.integrations.langchain import get_memory_tools
8+
9+
10+
async def main():
11+
# Create mock client
12+
config = MemoryClientConfig(base_url="http://localhost:8000")
13+
client = MemoryAPIClient(config)
14+
15+
# Mock the resolve_function_call method
16+
client.resolve_function_call = AsyncMock(
17+
return_value={
18+
"success": True,
19+
"formatted_response": "Created 1 memory successfully",
20+
}
21+
)
22+
23+
# Get the tool
24+
tools = get_memory_tools(
25+
memory_client=client,
26+
session_id="test",
27+
user_id="alice",
28+
tools=["create_long_term_memory"],
29+
)
30+
31+
create_tool = tools[0]
32+
33+
# Test 1: Single memory object (what LLMs often do)
34+
print("Test 1: Single memory object")
35+
single_memory = {
36+
"text": "User loves pizza",
37+
"memory_type": "semantic",
38+
"topics": ["food", "preferences"],
39+
}
40+
41+
result = await create_tool.func(memories=single_memory)
42+
print(f"Result: {result}")
43+
44+
# Check what was actually called
45+
call_args = client.resolve_function_call.call_args
46+
print(f"Called with: {call_args}")
47+
memories_arg = call_args.kwargs["function_arguments"]["memories"]
48+
print(f"Memories argument type: {type(memories_arg)}")
49+
print(f"Memories argument value: {memories_arg}")
50+
assert isinstance(memories_arg, list), "Should convert single object to list"
51+
assert len(memories_arg) == 1, "Should have one memory"
52+
print("✅ Single object correctly converted to array\n")
53+
54+
# Test 2: Array of memories (correct format)
55+
print("Test 2: Array of memories")
56+
client.resolve_function_call.reset_mock()
57+
58+
memory_array = [
59+
{"text": "User loves pizza", "memory_type": "semantic", "topics": ["food"]},
60+
{
61+
"text": "User works at TechCorp",
62+
"memory_type": "semantic",
63+
"topics": ["work"],
64+
},
65+
]
66+
67+
result = await create_tool.func(memories=memory_array)
68+
print(f"Result: {result}")
69+
70+
call_args = client.resolve_function_call.call_args
71+
memories_arg = call_args.kwargs["function_arguments"]["memories"]
72+
print(f"Memories argument type: {type(memories_arg)}")
73+
print(f"Memories argument value: {memories_arg}")
74+
assert isinstance(memories_arg, list), "Should keep as list"
75+
assert len(memories_arg) == 2, "Should have two memories"
76+
print("✅ Array correctly passed through\n")
77+
78+
print("All tests passed! ✅")
79+
80+
81+
if __name__ == "__main__":
82+
asyncio.run(main())

0 commit comments

Comments
 (0)