Skip to content

Commit 00e7db4

Browse files
authored
Merge pull request #59 from redis/feature/allow-configuring-memory
Add configurable memory strategies with security validation
2 parents 7580e89 + 93cc5e6 commit 00e7db4

19 files changed

+2807
-53
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,3 +233,4 @@ libs/redis/docs/.Trash*
233233
*.pyc
234234
.ai
235235
.claude
236+
TASK_MEMORY.md

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ A memory layer for AI agents using Redis as the vector database.
66

77
- **Dual Interface**: REST API and Model Context Protocol (MCP) server
88
- **Two-Tier Memory**: Working memory (session-scoped) and long-term memory (persistent)
9+
- **Configurable Memory Strategies**: Customize how memories are extracted (discrete, summary, preferences, custom)
910
- **Semantic Search**: Vector-based similarity search with metadata filtering
1011
- **Flexible Backends**: Pluggable vector store factory system
1112
- **AI Integration**: Automatic topic extraction, entity recognition, and conversation summarization

agent_memory_server/docket_tasks.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
from docket import Docket
88

99
from agent_memory_server.config import settings
10-
from agent_memory_server.extraction import extract_discrete_memories
10+
from agent_memory_server.extraction import (
11+
extract_discrete_memories,
12+
extract_memories_with_strategy,
13+
)
1114
from agent_memory_server.long_term_memory import (
1215
compact_long_term_memories,
1316
delete_long_term_memories,
@@ -31,6 +34,7 @@
3134
index_long_term_memories,
3235
compact_long_term_memories,
3336
extract_discrete_memories,
37+
extract_memories_with_strategy,
3438
promote_working_memory_to_long_term,
3539
delete_long_term_memories,
3640
forget_long_term_memories,

agent_memory_server/extraction.py

Lines changed: 128 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
1010

1111
from agent_memory_server.config import settings
12-
from agent_memory_server.filters import DiscreteMemoryExtracted
12+
from agent_memory_server.filters import DiscreteMemoryExtracted, MemoryType
1313
from agent_memory_server.llms import (
1414
AnthropicClientWrapper,
1515
OpenAIClientWrapper,
@@ -312,8 +312,8 @@ async def extract_discrete_memories(
312312
client = await get_model_client(settings.generation_model)
313313

314314
# Use vectorstore adapter to find messages that need discrete memory extraction
315-
# TODO: Sort out circular imports
316-
from agent_memory_server.filters import MemoryType
315+
# Local imports to avoid circular dependencies:
316+
# long_term_memory imports from extraction, so we import locally here
317317
from agent_memory_server.long_term_memory import index_long_term_memories
318318
from agent_memory_server.vectorstore_factory import get_vectorstore_adapter
319319

@@ -408,3 +408,128 @@ async def extract_discrete_memories(
408408
long_term_memories,
409409
deduplicate=deduplicate,
410410
)
411+
412+
413+
async def extract_memories_with_strategy(
414+
memories: list[MemoryRecord] | None = None,
415+
deduplicate: bool = True,
416+
):
417+
"""
418+
Extract memories using their configured strategies.
419+
420+
This function replaces extract_discrete_memories for strategy-aware extraction.
421+
Each memory record contains its extraction strategy configuration.
422+
"""
423+
# Local imports to avoid circular dependencies:
424+
# long_term_memory imports from extraction, so we import locally here
425+
from agent_memory_server.long_term_memory import index_long_term_memories
426+
from agent_memory_server.memory_strategies import get_memory_strategy
427+
from agent_memory_server.vectorstore_factory import get_vectorstore_adapter
428+
429+
adapter = await get_vectorstore_adapter()
430+
431+
if not memories:
432+
# If no memories are provided, search for any messages in long-term memory
433+
# that haven't been processed for extraction
434+
memories = []
435+
offset = 0
436+
while True:
437+
search_result = await adapter.search_memories(
438+
query="", # Empty query to get all messages
439+
memory_type=MemoryType(eq="message"),
440+
discrete_memory_extracted=DiscreteMemoryExtracted(eq="f"),
441+
limit=25,
442+
offset=offset,
443+
)
444+
445+
logger.info(
446+
f"Found {len(search_result.memories)} memories to extract: {[m.id for m in search_result.memories]}"
447+
)
448+
449+
memories += search_result.memories
450+
451+
if len(search_result.memories) < 25:
452+
break
453+
454+
offset += 25
455+
456+
# Group memories by extraction strategy for batch processing
457+
strategy_groups = {}
458+
for memory in memories:
459+
if not memory or not memory.text:
460+
logger.info(f"Deleting memory with no text: {memory}")
461+
await adapter.delete_memories([memory.id])
462+
continue
463+
464+
strategy_key = (
465+
memory.extraction_strategy,
466+
tuple(sorted(memory.extraction_strategy_config.items())),
467+
)
468+
if strategy_key not in strategy_groups:
469+
strategy_groups[strategy_key] = []
470+
strategy_groups[strategy_key].append(memory)
471+
472+
all_new_memories = []
473+
all_updated_memories = []
474+
475+
# Process each strategy group
476+
for (strategy_name, config_items), strategy_memories in strategy_groups.items():
477+
logger.info(
478+
f"Processing {len(strategy_memories)} memories with strategy: {strategy_name}"
479+
)
480+
481+
# Get strategy instance
482+
config_dict = dict(config_items)
483+
try:
484+
strategy = get_memory_strategy(strategy_name, **config_dict)
485+
except ValueError as e:
486+
logger.error(f"Unknown strategy {strategy_name}: {e}")
487+
# Fall back to discrete strategy
488+
strategy = get_memory_strategy("discrete")
489+
490+
# Process memories with this strategy
491+
for memory in strategy_memories:
492+
try:
493+
extracted_memories = await strategy.extract_memories(memory.text)
494+
all_new_memories.extend(extracted_memories)
495+
496+
# Update the memory to mark it as processed
497+
updated_memory = memory.model_copy(
498+
update={"discrete_memory_extracted": "t"}
499+
)
500+
all_updated_memories.append(updated_memory)
501+
502+
except Exception as e:
503+
logger.error(
504+
f"Error extracting memory {memory.id} with strategy {strategy_name}: {e}"
505+
)
506+
# Still mark as processed to avoid infinite retry
507+
updated_memory = memory.model_copy(
508+
update={"discrete_memory_extracted": "t"}
509+
)
510+
all_updated_memories.append(updated_memory)
511+
512+
# Update processed memories
513+
if all_updated_memories:
514+
await adapter.update_memories(all_updated_memories)
515+
516+
# Index new extracted memories
517+
if all_new_memories:
518+
long_term_memories = [
519+
MemoryRecord(
520+
id=str(ulid.ULID()),
521+
text=new_memory["text"],
522+
memory_type=new_memory.get("type", "episodic"),
523+
topics=new_memory.get("topics", []),
524+
entities=new_memory.get("entities", []),
525+
discrete_memory_extracted="t",
526+
extraction_strategy="discrete", # These are already extracted
527+
extraction_strategy_config={},
528+
)
529+
for new_memory in all_new_memories
530+
]
531+
532+
await index_long_term_memories(
533+
long_term_memories,
534+
deduplicate=deduplicate,
535+
)

agent_memory_server/long_term_memory.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212

1313
from agent_memory_server.config import settings
1414
from agent_memory_server.dependencies import get_background_tasks
15-
from agent_memory_server.extraction import extract_discrete_memories, handle_extraction
15+
from agent_memory_server.extraction import (
16+
extract_memories_with_strategy,
17+
handle_extraction,
18+
)
1619
from agent_memory_server.filters import (
1720
CreatedAt,
1821
Entities,
@@ -848,7 +851,7 @@ async def index_long_term_memories(
848851
# them as separate long-term memory records. This process also
849852
# runs deduplication if requested.
850853
await background_tasks.add_task(
851-
extract_discrete_memories,
854+
extract_memories_with_strategy,
852855
memories=needs_extraction,
853856
deduplicate=deduplicate,
854857
)
@@ -1372,6 +1375,14 @@ async def promote_working_memory_to_long_term(
13721375
current_memory = deduped_memory or memory
13731376
current_memory.persisted_at = datetime.now(UTC)
13741377

1378+
# Set extraction strategy configuration from working memory
1379+
current_memory.extraction_strategy = (
1380+
current_working_memory.long_term_memory_strategy.strategy
1381+
)
1382+
current_memory.extraction_strategy_config = (
1383+
current_working_memory.long_term_memory_strategy.config
1384+
)
1385+
13751386
# Index the memory in long-term storage
13761387
await index_long_term_memories(
13771388
[current_memory],
@@ -1434,6 +1445,14 @@ async def promote_working_memory_to_long_term(
14341445
current_memory = deduped_memory or memory_record
14351446
current_memory.persisted_at = datetime.now(UTC)
14361447

1448+
# Set extraction strategy configuration from working memory
1449+
current_memory.extraction_strategy = (
1450+
current_working_memory.long_term_memory_strategy.strategy
1451+
)
1452+
current_memory.extraction_strategy_config = (
1453+
current_working_memory.long_term_memory_strategy.config
1454+
)
1455+
14371456
# Collect memory record for batch indexing
14381457
message_records_to_index.append(current_memory)
14391458

0 commit comments

Comments
 (0)