|
9 | 9 | from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
|
10 | 10 |
|
11 | 11 | from agent_memory_server.config import settings
|
12 |
| -from agent_memory_server.filters import DiscreteMemoryExtracted |
| 12 | +from agent_memory_server.filters import DiscreteMemoryExtracted, MemoryType |
13 | 13 | from agent_memory_server.llms import (
|
14 | 14 | AnthropicClientWrapper,
|
15 | 15 | OpenAIClientWrapper,
|
@@ -312,8 +312,8 @@ async def extract_discrete_memories(
|
312 | 312 | client = await get_model_client(settings.generation_model)
|
313 | 313 |
|
314 | 314 | # Use vectorstore adapter to find messages that need discrete memory extraction
|
315 |
| - # TODO: Sort out circular imports |
316 |
| - from agent_memory_server.filters import MemoryType |
| 315 | + # Local imports to avoid circular dependencies: |
| 316 | + # long_term_memory imports from extraction, so we import locally here |
317 | 317 | from agent_memory_server.long_term_memory import index_long_term_memories
|
318 | 318 | from agent_memory_server.vectorstore_factory import get_vectorstore_adapter
|
319 | 319 |
|
@@ -408,3 +408,128 @@ async def extract_discrete_memories(
|
408 | 408 | long_term_memories,
|
409 | 409 | deduplicate=deduplicate,
|
410 | 410 | )
|
| 411 | + |
| 412 | + |
async def extract_memories_with_strategy(
    memories: list[MemoryRecord] | None = None,
    deduplicate: bool = True,
) -> None:
    """
    Extract memories using their configured strategies.

    This function replaces extract_discrete_memories for strategy-aware extraction.
    Each memory record contains its extraction strategy configuration
    (``extraction_strategy`` and ``extraction_strategy_config``).

    Args:
        memories: Records to run extraction on. When empty or ``None``, every
            record of type "message" whose ``discrete_memory_extracted`` flag
            is "f" is fetched from the vectorstore adapter, page by page.
        deduplicate: Forwarded unchanged to ``index_long_term_memories`` when
            the newly extracted memories are indexed.

    Returns:
        None. Results are written through the vectorstore adapter (source
        records are flagged as processed) and ``index_long_term_memories``
        (extracted records are indexed).
    """
    # Local imports to avoid circular dependencies:
    # long_term_memory imports from extraction, so we import locally here
    import json

    from agent_memory_server.long_term_memory import index_long_term_memories
    from agent_memory_server.memory_strategies import get_memory_strategy
    from agent_memory_server.vectorstore_factory import get_vectorstore_adapter

    # Page size used both as the search limit and as the offset increment.
    page_size = 25

    adapter = await get_vectorstore_adapter()

    if not memories:
        # If no memories are provided, search for any messages in long-term memory
        # that haven't been processed for extraction
        memories = []
        offset = 0
        while True:
            search_result = await adapter.search_memories(
                query="",  # Empty query to get all messages
                memory_type=MemoryType(eq="message"),
                discrete_memory_extracted=DiscreteMemoryExtracted(eq="f"),
                limit=page_size,
                offset=offset,
            )

            logger.info(
                f"Found {len(search_result.memories)} memories to extract: {[m.id for m in search_result.memories]}"
            )

            memories += search_result.memories

            # A short page means the result set is exhausted.
            if len(search_result.memories) < page_size:
                break

            offset += page_size

    # Group memories by extraction strategy for batch processing.
    # The key is (strategy name, canonical JSON of the config) so that configs
    # holding unhashable values (lists, nested dicts) can still be grouped;
    # the config dict itself is stored next to the group's memories.
    strategy_groups: dict[tuple[str, str], tuple[dict, list[MemoryRecord]]] = {}
    for memory in memories:
        if memory is None:
            # Nothing to delete or extract: there is no id to act on.
            logger.warning("Skipping empty (None) entry in memories")
            continue

        if not memory.text:
            logger.info(f"Deleting memory with no text: {memory}")
            await adapter.delete_memories([memory.id])
            continue

        config = dict(memory.extraction_strategy_config or {})
        config_key = json.dumps(config, sort_keys=True, default=str)
        _, grouped = strategy_groups.setdefault(
            (memory.extraction_strategy, config_key), (config, [])
        )
        grouped.append(memory)

    all_new_memories = []
    all_updated_memories = []

    # Process each strategy group
    for (strategy_name, _), (config_dict, strategy_memories) in strategy_groups.items():
        logger.info(
            f"Processing {len(strategy_memories)} memories with strategy: {strategy_name}"
        )

        # Get strategy instance
        try:
            strategy = get_memory_strategy(strategy_name, **config_dict)
        except ValueError as e:
            logger.error(f"Unknown strategy {strategy_name}: {e}")
            # Fall back to discrete strategy
            strategy = get_memory_strategy("discrete")

        # Process memories with this strategy
        for memory in strategy_memories:
            try:
                extracted_memories = await strategy.extract_memories(memory.text)
                all_new_memories.extend(extracted_memories)
            except Exception as e:
                # Best effort: one failing record must not abort the batch.
                logger.error(
                    f"Error extracting memory {memory.id} with strategy {strategy_name}: {e}"
                )

            # Mark the source as processed whether or not extraction succeeded,
            # to avoid retrying a failing record forever.
            all_updated_memories.append(
                memory.model_copy(update={"discrete_memory_extracted": "t"})
            )

    # Update processed memories
    if all_updated_memories:
        await adapter.update_memories(all_updated_memories)

    # Index new extracted memories
    if all_new_memories:
        long_term_memories = [
            MemoryRecord(
                id=str(ulid.ULID()),
                text=new_memory["text"],
                memory_type=new_memory.get("type", "episodic"),
                topics=new_memory.get("topics", []),
                entities=new_memory.get("entities", []),
                discrete_memory_extracted="t",
                extraction_strategy="discrete",  # These are already extracted
                extraction_strategy_config={},
            )
            for new_memory in all_new_memories
        ]

        await index_long_term_memories(
            long_term_memories,
            deduplicate=deduplicate,
        )
0 commit comments