diff --git a/CLAUDE.md b/CLAUDE.md index 6a38f34..6953d65 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,5 +1,9 @@ # CLAUDE.md - Redis Agent Memory Server Project Context +## Redis Version +This project uses Redis 8, which is the redis:8 docker image. +Do not use Redis Stack or other earlier versions of Redis. + ## Frequently Used Commands Get started in a new environment by installing `uv`: ```bash @@ -188,7 +192,6 @@ EMBEDDING_MODEL=text-embedding-3-small # Memory Configuration LONG_TERM_MEMORY=true -WINDOW_SIZE=20 ENABLE_TOPIC_EXTRACTION=true ENABLE_NER=true ``` diff --git a/agent-memory-client/agent_memory_client/client.py b/agent-memory-client/agent_memory_client/client.py index adc58aa..2eb3ca6 100644 --- a/agent-memory-client/agent_memory_client/client.py +++ b/agent-memory-client/agent_memory_client/client.py @@ -209,7 +209,6 @@ async def get_working_memory( session_id: str, user_id: str | None = None, namespace: str | None = None, - window_size: int | None = None, model_name: ModelNameLiteral | None = None, context_window_max: int | None = None, ) -> WorkingMemoryResponse: @@ -220,7 +219,6 @@ async def get_working_memory( session_id: The session ID to retrieve working memory for user_id: The user ID to retrieve working memory for namespace: Optional namespace for the session - window_size: Optional number of messages to include model_name: Optional model name to determine context window size context_window_max: Optional direct specification of context window tokens @@ -241,9 +239,6 @@ async def get_working_memory( elif self.config.default_namespace is not None: params["namespace"] = self.config.default_namespace - if window_size is not None: - params["window_size"] = str(window_size) - # Use provided model_name or fall back to config default effective_model_name = model_name or self.config.default_model_name if effective_model_name is not None: @@ -2139,7 +2134,6 @@ async def memory_prompt( query: str, session_id: str | None = None, namespace: str | None = None, - window_size: int | None = None, model_name: str | None = None, context_window_max: int | None = None, long_term_search: dict[str, Any] | None = None, @@ -2154,7 +2148,6 @@ async def memory_prompt( query: The input text to find relevant context for session_id: Optional session ID to include session messages namespace: Optional namespace for the session - window_size: Optional number of messages to include model_name: Optional model name to determine context window size context_window_max: Optional direct specification of context window tokens long_term_search: Optional search parameters for long-term memory @@ -2169,7 +2162,6 @@ async def memory_prompt( prompt = await client.memory_prompt( query="What are my UI preferences?", session_id="current_session", - window_size=10, long_term_search={ "topics": {"any": ["preferences", "ui"]}, "limit": 5 @@ -2190,8 +2182,6 @@ async def memory_prompt( session_params["namespace"] = namespace elif self.config.default_namespace is not None: session_params["namespace"] = self.config.default_namespace - if window_size is not None: - session_params["window_size"] = str(window_size) # Use provided model_name or fall back to config default effective_model_name = model_name or self.config.default_model_name if effective_model_name is not None: diff --git a/agent_memory_server/api.py b/agent_memory_server/api.py index 578795d..a16efad 100644 --- a/agent_memory_server/api.py +++ b/agent_memory_server/api.py @@ -582,19 +582,16 @@ async def memory_prompt( logger.debug(f"Memory prompt params: 
{params}") if params.session: - # Use token limit for memory prompt, fallback to message count for backward compatibility + # Use token limit for memory prompt - model info is required now if params.session.model_name or params.session.context_window_max: token_limit = _get_effective_token_limit( model_name=params.session.model_name, context_window_max=params.session.context_window_max, ) - effective_window_size = ( - token_limit # We'll handle token-based truncation below - ) + effective_token_limit = token_limit else: - effective_window_size = ( - params.session.window_size - ) # Fallback to message count + # No model info provided - use all messages without truncation + effective_token_limit = None working_mem = await working_memory.get_working_memory( session_id=params.session.session_id, namespace=params.session.namespace, @@ -616,11 +613,11 @@ async def memory_prompt( ) ) # Apply token-based truncation if model info is provided - if params.session.model_name or params.session.context_window_max: + if effective_token_limit is not None: # Token-based truncation if ( _calculate_messages_token_count(working_mem.messages) - > effective_window_size + > effective_token_limit ): # Keep removing oldest messages until we're under the limit recent_messages = working_mem.messages[:] @@ -628,34 +625,30 @@ async def memory_prompt( recent_messages = recent_messages[1:] # Remove oldest if ( _calculate_messages_token_count(recent_messages) - <= effective_window_size + <= effective_token_limit ): break else: recent_messages = working_mem.messages - - for msg in recent_messages: - if msg.role == "user": - msg_class = base.UserMessage - else: - msg_class = base.AssistantMessage - _messages.append( - msg_class( - content=TextContent(type="text", text=msg.content), - ) - ) else: - # No token-based truncation - use all messages - for msg in working_mem.messages: - if msg.role == "user": - msg_class = base.UserMessage - else: - msg_class = base.AssistantMessage - _messages.append( - msg_class( - content=TextContent(type="text", text=msg.content), - ) + # No token limit provided - use all messages + recent_messages = working_mem.messages + + for msg in recent_messages: + if msg.role == "user": + msg_class = base.UserMessage + elif msg.role == "assistant": + msg_class = base.AssistantMessage + else: + # For tool messages or other roles, treat as assistant for MCP compatibility + # since MCP base only supports UserMessage and AssistantMessage + msg_class = base.AssistantMessage + + _messages.append( + msg_class( + content=TextContent(type="text", text=msg.content), ) + ) if params.long_term_search: logger.debug( diff --git a/agent_memory_server/config.py b/agent_memory_server/config.py index 73acbb2..4ee837d 100644 --- a/agent_memory_server/config.py +++ b/agent_memory_server/config.py @@ -52,7 +52,7 @@ class Settings(BaseSettings): long_term_memory: bool = True openai_api_key: str | None = None anthropic_api_key: str | None = None - generation_model: str = "gpt-4o-mini" + generation_model: str = "gpt-4o" embedding_model: str = "text-embedding-3-small" port: int = 8000 mcp_port: int = 9000 @@ -118,7 +118,6 @@ class Settings(BaseSettings): auth0_client_secret: str | None = None # Working memory settings - window_size: int = 20 # Default number of recent messages to return summarization_threshold: float = ( 0.7 # Fraction of context window that triggers summarization ) diff --git a/agent_memory_server/long_term_memory.py b/agent_memory_server/long_term_memory.py index c886513..a48ad4e 100644 --- 
a/agent_memory_server/long_term_memory.py +++ b/agent_memory_server/long_term_memory.py @@ -324,21 +324,37 @@ async def compact_long_term_memories( index_name = Keys.search_index_name() # Create aggregation query to group by memory_hash and find duplicates - agg_query = ( - f"FT.AGGREGATE {index_name} {filter_str} " - "GROUPBY 1 @memory_hash " - "REDUCE COUNT 0 AS count " - 'FILTER "@count>1" ' # Only groups with more than 1 memory - "SORTBY 2 @count DESC " - f"LIMIT 0 {limit}" - ) + agg_query = [ + "FT.AGGREGATE", + index_name, + filter_str, + "GROUPBY", + str(1), + "@memory_hash", + "REDUCE", + "COUNT", + str(0), + "AS", + "count", + "FILTER", + "@count>1", # Only groups with more than 1 memory + "SORTBY", + str(2), + "@count", + "DESC", + "LIMIT", + str(0), + str(limit), + ] # Execute aggregation to find duplicate groups - duplicate_groups = await redis_client.execute_command(agg_query) + duplicate_groups = await redis_client.execute_command(*agg_query) if duplicate_groups and duplicate_groups[0] > 0: num_groups = duplicate_groups[0] - logger.info(f"Found {num_groups} groups of hash-based duplicates") + logger.info( + f"Found {num_groups} groups with hash-based duplicates to process" + ) # Process each group of duplicates for i in range(1, len(duplicate_groups), 2): @@ -423,9 +439,11 @@ async def compact_long_term_memories( ) except Exception as e: logger.error(f"Error processing duplicate group: {e}") + else: + logger.info("No hash-based duplicates found") logger.info( - f"Completed hash-based deduplication. Merged {memories_merged} memories." + f"Completed hash-based deduplication. Removed {memories_merged} duplicate memories." ) except Exception as e: logger.error(f"Error during hash-based duplicate compaction: {e}") diff --git a/agent_memory_server/main.py b/agent_memory_server/main.py index c2f4b66..bb4a715 100644 --- a/agent_memory_server/main.py +++ b/agent_memory_server/main.py @@ -135,7 +135,6 @@ async def lifespan(app: FastAPI): logger.info( "Redis Agent Memory Server initialized", - window_size=settings.window_size, generation_model=settings.generation_model, embedding_model=settings.embedding_model, long_term_memory=settings.long_term_memory, diff --git a/agent_memory_server/mcp.py b/agent_memory_server/mcp.py index 6ca1dc3..67d9216 100644 --- a/agent_memory_server/mcp.py +++ b/agent_memory_server/mcp.py @@ -467,7 +467,6 @@ async def memory_prompt( query: str, session_id: SessionId | None = None, namespace: Namespace | None = None, - window_size: int = settings.window_size, model_name: ModelNameLiteral | None = None, context_window_max: int | None = None, topics: Topics | None = None, @@ -579,7 +578,6 @@ async def memory_prompt( session_id=_session_id, namespace=namespace.eq if namespace and namespace.eq else None, user_id=user_id.eq if user_id and user_id.eq else None, - window_size=window_size, model_name=model_name, context_window_max=context_window_max, ) diff --git a/agent_memory_server/models.py b/agent_memory_server/models.py index 204dfdf..a7c84f1 100644 --- a/agent_memory_server/models.py +++ b/agent_memory_server/models.py @@ -7,7 +7,6 @@ from pydantic import BaseModel, Field from ulid import ULID -from agent_memory_server.config import settings from agent_memory_server.filters import ( CreatedAt, Entities, @@ -238,7 +237,6 @@ class WorkingMemoryRequest(BaseModel): session_id: str namespace: str | None = None user_id: str | None = None - window_size: int = settings.window_size model_name: ModelNameLiteral | None = None context_window_max: int | None = None diff 
--git a/agent_memory_server/pluggable-long-term-memory.md b/agent_memory_server/pluggable-long-term-memory.md deleted file mode 100644 index 4096ad7..0000000 --- a/agent_memory_server/pluggable-long-term-memory.md +++ /dev/null @@ -1,152 +0,0 @@ -## Feature: Pluggable Long-Term Memory via LangChain VectorStore Adapter - -**Summary:** -Refactor agent-memory-server's long-term memory component to use the [LangChain VectorStore interface](https://python.langchain.com/docs/integrations/vectorstores/) as its backend abstraction. -This will allow users to select from dozens of supported databases (Chroma, Pinecone, Weaviate, Redis, Qdrant, Milvus, Postgres/PGVector, LanceDB, and more) with minimal custom code. -The backend should be configurable at runtime via environment variables or config, and require no custom adapters for each new supported store. - -**Reference:** -- [agent-memory-server repo](https://github.com/redis-developer/agent-memory-server) -- [LangChain VectorStore docs](https://python.langchain.com/docs/integrations/vectorstores/) - ---- - -### Requirements - -1. **Adopt LangChain VectorStore as the Storage Interface** - - All long-term memory operations (`add`, `search`, `delete`, `update`) must delegate to a LangChain-compatible VectorStore instance. - - Avoid any database-specific code paths for core CRUD/search; rely on VectorStore's interface. - - The VectorStore instance must be initialized at server startup, using connection parameters from environment variables or config. - -2. **Backend Swappability** - - The backend type (e.g., Chroma, Pinecone, Redis, Postgres, etc.) must be selectable at runtime via a config variable (e.g., `LONG_TERM_MEMORY_BACKEND`). - - All required connection/config parameters for the backend should be loaded from environment/config. - - Adding new supported databases should require no new adapter code—just list them in documentation and config. - -3. **API Mapping and Model Translation** - - Ensure your memory API endpoints map directly to the underlying VectorStore methods (e.g., `add_texts`, `similarity_search`, `delete`). - - Translate between your internal MemoryRecord model and LangChain's `Document` (or other types as needed) at the service boundary. - - Support metadata storage and filtering as allowed by the backend; document any differences in filter syntax or capability. - -4. **Configuration and Documentation** - - Document all supported backends, their config options, and any installation requirements (e.g., which Python extras to install for each backend). - - Update `.env.example` with required variables for each backend type. - - Add a table in the README listing supported databases and any notable feature support/limitations (e.g., advanced filters, hybrid search). - -5. **Testing and CI** - - Add tests to verify core flows (add, search, delete, filter) work with at least two VectorStore backends (e.g., Chroma and Redis). - - (Optional) Use in-memory stores for unit tests where possible. - -6. **(Optional but Preferred) Dependency Handling** - - Optional dependencies for each backend should be installed only if required (using extras, e.g., `pip install agent-memory-server[chroma]`). - ---- - -### Implementation Steps - -1. **Create a Thin Adapter Layer** - - Implement a `VectorStoreMemoryAdapter` class that wraps a LangChain VectorStore instance and exposes memory operations. - - Adapter methods should map 1:1 to LangChain methods (e.g., `add_texts`, `similarity_search`, `delete`), translating data models as needed. - -2. 
**Backend Selection and Initialization** - - On startup, read `LONG_TERM_MEMORY_BACKEND` and associated connection params. - - Dynamically instantiate the appropriate VectorStore via LangChain, passing required config. - - Store the instance as a singleton/service to be used by API endpoints. - -3. **API Endpoint Refactor** - - Refactor long-term memory API endpoints to call adapter methods only; eliminate any backend-specific logic from the endpoints. - - Ensure filter syntax in your API is converted to the form expected by each VectorStore. Where not possible, document or gracefully reject unsupported filter types. - -4. **Update Documentation** - - Clearly explain backend selection, configuration, and how to install dependencies for each supported backend. - - Add usage examples for at least two backends (Chroma and Redis recommended). - - List any differences in filtering, advanced features, or limits by backend. - -5. **Testing** - - Add or update tests to cover core memory operations with at least two different VectorStore backends. - - Use environment variables or test config files to run tests with different backends in CI. - ---- - -### Acceptance Criteria - -- [x] agent-memory-server supports Redis backends for long-term memory, both selectable at runtime via config/env. -- [x] All long-term memory API operations are delegated through the LangChain VectorStore interface. -- [x] README documents backend selection, configuration, and installation for each supported backend. -- [x] Tests cover all core flows with at least two backends (Redis and Postgres). -- [x] No breaking changes to API or existing users by default. - ---- - -**See [LangChain VectorStore Integrations](https://python.langchain.com/docs/integrations/vectorstores/) for a full list of supported databases and client libraries.** - -## Progress of Development -Keep track of your progress building this feature here. - -### Analysis Phase (Complete) -- [x] **Read existing codebase** - Analyzed current Redis-based implementation in `long_term_memory.py` -- [x] **Understand current architecture** - Current system uses RedisVL with direct Redis connections -- [x] **Identify key components to refactor**: - - `search_long_term_memories()` - Main search function using RedisVL VectorQuery - - `index_long_term_memories()` - Memory indexing with Redis hash storage - - `count_long_term_memories()` - Count operations - - Redis utilities in `utils/redis.py` for connection management and index setup -- [x] **Understand data models** - MemoryRecord contains text, metadata (topics, entities, dates), and embeddings -- [x] **Review configuration** - Current Redis config in `config.py`, need to add backend selection - -### Implementation Plan -1. **Add LangChain dependencies and backend configuration** ✅ -2. **Create VectorStore adapter interface** ✅ -3. **Implement backend factory for different VectorStores** ✅ -4. **Refactor long-term memory functions to use adapter** ✅ -5. **Update API endpoints and add documentation** ✅ -6. 
**Add tests for multiple backends** ✅ - -### Current Status: Implementation Complete ✅ -- [x] **Added LangChain dependencies** - Added langchain-core and optional dependencies for all major vectorstore backends -- [x] **Extended configuration** - Added backend selection and connection parameters for all supported backends -- [x] **Created VectorStoreAdapter interface** - Abstract base class with methods for add/search/delete/count operations -- [x] **Implemented LangChainVectorStoreAdapter** - Generic adapter that works with any LangChain VectorStore -- [x] **Created VectorStore factory** - Factory functions for all supported backends (Redis, Chroma, Pinecone, Weaviate, Qdrant, Milvus, PGVector, LanceDB, OpenSearch) -- [x] **Refactored core long-term memory functions** - `search_long_term_memories()`, `index_long_term_memories()`, and `count_long_term_memories()` now use the adapter -- [x] **Check and update API endpoints** - Ensure all memory API endpoints use the new adapter through the refactored functions -- [x] **Update environment configuration** - Add .env.example entries for all supported backends -- [x] **Create comprehensive documentation** - Document all supported backends, configuration options, and usage examples -- [x] **Add basic tests** - Created test suite for vectorstore adapter functionality -- [x] **Verified implementation** - All core functionality tested and working correctly - -## Summary - -✅ **FEATURE COMPLETE**: The pluggable long-term memory feature has been successfully implemented! - -The Redis Agent Memory Server now supports **9 different vector store backends** through the LangChain VectorStore interface: -- Redis (default), Chroma, Pinecone, Weaviate, Qdrant, Milvus, PostgreSQL/PGVector, LanceDB, and OpenSearch - -**Key Achievements:** -- ✅ **Zero breaking changes** - Existing Redis users continue to work without any changes -- ✅ **Runtime backend selection** - Set `LONG_TERM_MEMORY_BACKEND=` to switch -- ✅ **Unified API interface** - All backends work through the same API endpoints -- ✅ **Production ready** - Full error handling, logging, and documentation -- ✅ **Comprehensive documentation** - Complete setup guides for all backends -- ✅ **Verified functionality** - Core operations tested and working - -**Implementation Details:** -- **VectorStore Adapter Pattern** - Clean abstraction layer between memory server and LangChain VectorStores -- **Backend Factory** - Dynamic instantiation of vectorstore backends based on configuration -- **Metadata Handling** - Proper conversion between MemoryRecord and LangChain Document formats -- **Filtering Support** - Post-processing filters for complex queries (Redis native filtering disabled temporarily due to syntax complexity) -- **Error Handling** - Graceful fallbacks and comprehensive error logging - -**Testing Results:** -- ✅ **CRUD Operations** - Add, search, delete, and count operations working correctly -- ✅ **Semantic Search** - Vector similarity search with proper scoring -- ✅ **Metadata Filtering** - Session, user, namespace, topics, and entities filtering -- ✅ **Data Persistence** - Memories properly stored and retrieved -- ✅ **No Breaking Changes** - Existing functionality preserved - -**Next Steps for Future Development:** -- [ ] **Optimize Redis filtering** - Implement proper Redis JSON path filtering for better performance -- [ ] **Add proper error handling and logging** - Improve error messages for different backend failures -- [ ] **Create tests for multiple backends** - Test core functionality with 
Redis and at least one other backend -- [ ] **Performance benchmarking** - Compare performance across different backends -- [ ] **Migration tooling** - Tools to migrate data between backends diff --git a/agent_memory_server/summarization.py b/agent_memory_server/summarization.py index 6d90be3..bbffb9e 100644 --- a/agent_memory_server/summarization.py +++ b/agent_memory_server/summarization.py @@ -113,24 +113,20 @@ async def _incremental_summary( async def summarize_session( session_id: str, model: str, - window_size: int, + max_context_tokens: int | None = None, ) -> None: """ - Summarize messages in a session when they exceed the window size. + Summarize messages in a session when they exceed the token limit. This function: - 1. Gets the oldest messages up to window size and current context - 2. Generates a new summary that includes these messages + 1. Gets messages and current context + 2. Generates a new summary that includes older messages 3. Removes older, summarized messages and updates the context - Stop summarizing - Args: session_id: The session ID - model: The model to use - window_size: Maximum number of messages to keep - client: The client wrapper (OpenAI or Anthropic) - redis_conn: Redis connection + model: The model to use for summarization + max_context_tokens: Maximum context tokens to keep (defaults to model's context window * summarization_threshold) """ logger.debug(f"Summarizing session {session_id}") redis = await get_redis_conn() @@ -144,11 +140,11 @@ async def summarize_session( num_messages = await pipe.llen(messages_key) # type: ignore logger.debug(f"[summarization] Number of messages: {num_messages}") - if num_messages < window_size: + if num_messages < 2: # Need at least 2 messages to summarize logger.info(f"Not enough messages to summarize for session {session_id}") return - messages_raw = await pipe.lrange(messages_key, 0, window_size - 1) # type: ignore + messages_raw = await pipe.lrange(messages_key, 0, -1) # Get all messages metadata = await pipe.hgetall(metadata_key) # type: ignore pipe.multi() @@ -164,48 +160,89 @@ async def summarize_session( logger.debug(f"[summarization] Messages: {messages}") model_config = get_model_config(model) - max_tokens = model_config.max_tokens - - # Token allocation: - # - For small context (<10k): use 12.5% (min 512) - # - For medium context (10k-50k): use 10% (min 1024) - # - For large context (>50k): use 5% (min 2048) - if max_tokens < 10000: - summary_max_tokens = max(512, max_tokens // 8) # 12.5% - elif max_tokens < 50000: - summary_max_tokens = max(1024, max_tokens // 10) # 10% + full_context_tokens = model_config.max_tokens + + # Use provided max_context_tokens or calculate from model context window + if max_context_tokens is None: + max_context_tokens = int( + full_context_tokens * settings.summarization_threshold + ) + + # Calculate current token usage + encoding = tiktoken.get_encoding("cl100k_base") + current_tokens = sum( + len(encoding.encode(f"{msg.role}: {msg.content}")) + for msg in messages + ) + + # If we're under the limit, no need to summarize + if current_tokens <= max_context_tokens: + logger.info( + f"Messages under token limit ({current_tokens} <= {max_context_tokens}) for session {session_id}" + ) + return + + # Token allocation for summarization + if full_context_tokens < 10000: + summary_max_tokens = max(512, full_context_tokens // 8) # 12.5% + elif full_context_tokens < 50000: + summary_max_tokens = max(1024, full_context_tokens // 10) # 10% else: - summary_max_tokens = max(2048, max_tokens // 20) 
# 5% + summary_max_tokens = max(2048, full_context_tokens // 20) # 5% logger.debug( f"[summarization] Summary max tokens: {summary_max_tokens}" ) # Scale buffer tokens with context size, but keep reasonable bounds - buffer_tokens = min(max(230, max_tokens // 100), 1000) + buffer_tokens = min(max(230, full_context_tokens // 100), 1000) logger.debug(f"[summarization] Buffer tokens: {buffer_tokens}") - max_message_tokens = max_tokens - summary_max_tokens - buffer_tokens - encoding = tiktoken.get_encoding("cl100k_base") + max_message_tokens = ( + full_context_tokens - summary_max_tokens - buffer_tokens + ) + + # Determine how many messages to keep (target ~40% of max_context_tokens for recent messages) + target_remaining_tokens = int(max_context_tokens * 0.4) + + # Work backwards to find recent messages to keep + recent_messages_tokens = 0 + keep_count = 0 + + for i in range(len(messages) - 1, -1, -1): + msg = messages[i] + msg_str = f"{msg.role}: {msg.content}" + msg_tokens = len(encoding.encode(msg_str)) + + if recent_messages_tokens + msg_tokens <= target_remaining_tokens: + recent_messages_tokens += msg_tokens + keep_count += 1 + else: + break + + # Messages to summarize are the ones we're not keeping + messages_to_check = ( + messages[:-keep_count] if keep_count > 0 else messages[:-1] + ) + total_tokens = 0 messages_to_summarize = [] - for msg in messages: + for msg in messages_to_check: msg_str = f"{msg.role}: {msg.content}" msg_tokens = len(encoding.encode(msg_str)) - # TODO: Here, we take a partial message if a single message's - # total size exceeds the buffer. Should this be configurable - # behavior? + # Handle oversized messages if msg_tokens > max_message_tokens: msg_str = msg_str[: max_message_tokens // 2] msg_tokens = len(encoding.encode(msg_str)) - total_tokens += msg_tokens if total_tokens + msg_tokens <= max_message_tokens: total_tokens += msg_tokens messages_to_summarize.append(msg_str) + else: + break if not messages_to_summarize: logger.info(f"No messages to summarize for session {session_id}") @@ -227,9 +264,13 @@ async def summarize_session( pipe.hmset(metadata_key, mapping=metadata) logger.debug(f"[summarization] Metadata: {metadata_key} {metadata}") - # Messages that were summarized - num_summarized = len(messages_to_summarize) - pipe.ltrim(messages_key, 0, num_summarized - 1) + # Keep only the most recent messages that fit in our token budget + if keep_count > 0: + # Keep the last keep_count messages + pipe.ltrim(messages_key, -keep_count, -1) + else: + # Keep at least the last message + pipe.ltrim(messages_key, -1, -1) await pipe.execute() break diff --git a/docker-compose.yml b/docker-compose.yml index 74cdbef..1891032 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -13,7 +13,6 @@ services: - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY} # Optional configurations with defaults - LONG_TERM_MEMORY=True - - WINDOW_SIZE=20 - GENERATION_MODEL=gpt-4o-mini - EMBEDDING_MODEL=text-embedding-3-small - ENABLE_TOPIC_EXTRACTION=True @@ -38,13 +37,6 @@ services: # Add your API keys here or use a .env file - OPENAI_API_KEY=${OPENAI_API_KEY} - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY} - # Optional configurations with defaults - - LONG_TERM_MEMORY=True - - WINDOW_SIZE=20 - - GENERATION_MODEL=gpt-4o-mini - - EMBEDDING_MODEL=text-embedding-3-small - - ENABLE_TOPIC_EXTRACTION=True - - ENABLE_NER=True ports: - "9050:9000" depends_on: @@ -61,12 +53,6 @@ services: - OPENAI_API_KEY=${OPENAI_API_KEY} - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY} # Optional configurations with defaults - - 
LONG_TERM_MEMORY=True - - WINDOW_SIZE=20 - - GENERATION_MODEL=gpt-4o-mini - - EMBEDDING_MODEL=text-embedding-3-small - - ENABLE_TOPIC_EXTRACTION=True - - ENABLE_NER=True depends_on: - redis command: ["uv", "run", "agent-memory", "task-worker"] @@ -80,7 +66,7 @@ services: - "16380:6379" # Redis port volumes: - redis_data:/data - command: redis-server --save "" --loglevel warning --appendonly no --stop-writes-on-bgsave-error no + command: redis-server --save "30 1" --loglevel warning --appendonly no --stop-writes-on-bgsave-error no healthcheck: test: [ "CMD", "redis-cli", "ping" ] interval: 30s diff --git a/docs/api.md b/docs/api.md index aa576b3..b708fd1 100644 --- a/docs/api.md +++ b/docs/api.md @@ -24,13 +24,12 @@ The following endpoints are available: _Query Parameters:_ - `namespace` (string, optional): The namespace to use for the session - - `window_size` (int, optional): Number of messages to include in the response (default from config) - `model_name` (string, optional): The client's LLM model name to determine appropriate context window size - `context_window_max` (int, optional): Direct specification of max context window tokens (overrides model_name) - **PUT /v1/working-memory/{session_id}** Sets working memory for a session, replacing any existing memory. - Automatically summarizes conversations that exceed the window size. + Automatically summarizes conversations that exceed the token limit. _Request Body Example:_ ```json @@ -103,7 +102,8 @@ The following endpoints are available: "session": { "session_id": "session-123", "namespace": "default", - "window_size": 10 + "model_name": "gpt-4o", + "context_window_max": 4000 }, "long_term_search": { "text": "AI discussion", diff --git a/docs/memory-types.md b/docs/memory-types.md index 02bf30d..6ddd59e 100644 --- a/docs/memory-types.md +++ b/docs/memory-types.md @@ -86,7 +86,7 @@ Working memory contains: ```http # Get working memory for a session -GET /v1/working-memory/{session_id}?namespace=demo&window_size=50 +GET /v1/working-memory/{session_id}?namespace=demo&model_name=gpt-4o # Set working memory (replaces existing) PUT /v1/working-memory/{session_id} @@ -300,7 +300,8 @@ response = await memory_prompt({ "query": "Help me plan dinner", "session": { "session_id": "current_chat", - "window_size": 20 + "model_name": "gpt-4o", + "context_window_max": 4000 }, "long_term_search": { "text": "food preferences dietary restrictions", diff --git a/examples/memory_prompt_agent.py b/examples/memory_prompt_agent.py index 2aa665f..29e09f2 100644 --- a/examples/memory_prompt_agent.py +++ b/examples/memory_prompt_agent.py @@ -123,7 +123,7 @@ async def _get_memory_prompt( session_id=session_id, query=user_input, # Optional parameters to control memory retrieval - window_size=30, # Controls working memory messages + model_name="gpt-4o-mini", # Controls token-based truncation long_term_search={"limit": 30}, # Controls long-term memory limit user_id=user_id, ) diff --git a/examples/travel_agent.py b/examples/travel_agent.py index 52ba0ad..ccdde53 100644 --- a/examples/travel_agent.py +++ b/examples/travel_agent.py @@ -255,7 +255,7 @@ async def _get_working_memory(self, session_id: str, user_id: str) -> WorkingMem result = await client.get_working_memory( session_id=session_id, namespace=self._get_namespace(user_id), - window_size=15, + model_name="gpt-4o-mini", # Controls token-based truncation ) return WorkingMemory(**result.model_dump()) diff --git a/manual_oauth_qa/manual_auth0_test.py b/manual_oauth_qa/manual_auth0_test.py index 48b1328..8621011 
100755 --- a/manual_oauth_qa/manual_auth0_test.py +++ b/manual_oauth_qa/manual_auth0_test.py @@ -180,7 +180,7 @@ def run_comprehensive_test(self): "session": { "session_id": "test-session-auth0", "namespace": "test-auth0", - "window_size": 10, + "model_name": "gpt-4o-mini", }, }, ), diff --git a/manual_oauth_qa/test_auth0.py b/manual_oauth_qa/test_auth0.py index 17ae874..d286ca3 100755 --- a/manual_oauth_qa/test_auth0.py +++ b/manual_oauth_qa/test_auth0.py @@ -180,7 +180,7 @@ def run_comprehensive_test(self): "session": { "session_id": "test-session-auth0", "namespace": "test-auth0", - "window_size": 10, + "model_name": "gpt-4o-mini", }, }, ), diff --git a/tests/test_api.py b/tests/test_api.py index c2417f7..f7fb129 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -260,7 +260,7 @@ async def test_post_memory_compacts_long_conversation( "namespace": "test-namespace", "session_id": "test-session", } - mock_settings = Settings(window_size=1, long_term_memory=False) + mock_settings = Settings(long_term_memory=False) with ( patch("agent_memory_server.api.settings", mock_settings), @@ -288,7 +288,7 @@ async def test_post_memory_compacts_long_conversation( # Should return the summarized working memory assert "messages" in data assert "context" in data - # Should have been summarized (only 1 message kept due to window_size=1) + # Should have been summarized (token-based summarization in _summarize_working_memory) assert len(data["messages"]) == 1 assert data["messages"][0]["content"] == "Hi there" assert "Summary:" in data["context"] @@ -400,7 +400,6 @@ async def test_memory_prompt_with_session_id(self, mock_get_working_memory, clie "session": { "session_id": "test-session", "namespace": "test-namespace", - "window_size": 10, "model_name": "gpt-4o", "context_window_max": 1000, }, diff --git a/tests/test_client_api.py b/tests/test_client_api.py index 1202419..8235652 100644 --- a/tests/test_client_api.py +++ b/tests/test_client_api.py @@ -303,7 +303,6 @@ async def test_memory_prompt(memory_test_client: MemoryAPIClient): query=query, session_id=session_id, namespace="test-namespace", - window_size=5, model_name="gpt-4o", context_window_max=4000, ) diff --git a/tests/test_full_integration.py b/tests/test_full_integration.py index 1c0761a..aa0ac6d 100644 --- a/tests/test_full_integration.py +++ b/tests/test_full_integration.py @@ -653,7 +653,6 @@ async def test_memory_prompt_with_working_memory( prompt_result = await client.memory_prompt( query="What programming language should I use?", session_id=unique_session_id, - window_size=4, ) assert "messages" in prompt_result diff --git a/tests/test_summarization.py b/tests/test_summarization.py index 65bf4e9..3c92c0e 100644 --- a/tests/test_summarization.py +++ b/tests/test_summarization.py @@ -88,19 +88,22 @@ async def test_summarize_session( """Test summarize_session with mocked summarization""" session_id = "test-session" model = "gpt-3.5-turbo" - window_size = 4 + max_context_tokens = 1000 pipeline_mock = MagicMock() # pipeline is not a coroutine pipeline_mock.__aenter__ = AsyncMock(return_value=pipeline_mock) pipeline_mock.watch = AsyncMock() mock_async_redis_client.pipeline = MagicMock(return_value=pipeline_mock) - # This needs to match the window size + # Create messages that exceed the token limit + long_content = ( + "This is a very long message that will exceed our token limit " * 50 + ) messages_raw = [ - json.dumps({"role": "user", "content": "Message 1"}), - json.dumps({"role": "assistant", "content": "Message 2"}), - 
json.dumps({"role": "user", "content": "Message 3"}), - json.dumps({"role": "assistant", "content": "Message 4"}), + json.dumps({"role": "user", "content": long_content}), + json.dumps({"role": "assistant", "content": long_content}), + json.dumps({"role": "user", "content": long_content}), + json.dumps({"role": "assistant", "content": "Short recent message"}), ] pipeline_mock.lrange = AsyncMock(return_value=messages_raw) @@ -113,7 +116,7 @@ async def test_summarize_session( pipeline_mock.hmset = MagicMock(return_value=True) pipeline_mock.ltrim = MagicMock(return_value=True) pipeline_mock.execute = AsyncMock(return_value=True) - pipeline_mock.llen = AsyncMock(return_value=window_size) + pipeline_mock.llen = AsyncMock(return_value=4) mock_summarization.return_value = ("New summary", 300) @@ -131,28 +134,33 @@ async def test_summarize_session( await summarize_session( session_id, model, - window_size, + max_context_tokens, ) assert pipeline_mock.lrange.call_count == 1 assert pipeline_mock.lrange.call_args[0][0] == Keys.messages_key(session_id) assert pipeline_mock.lrange.call_args[0][1] == 0 - assert pipeline_mock.lrange.call_args[0][2] == window_size - 1 + assert pipeline_mock.lrange.call_args[0][2] == -1 # Get all messages assert pipeline_mock.hgetall.call_count == 1 assert pipeline_mock.hgetall.call_args[0][0] == Keys.metadata_key(session_id) assert pipeline_mock.hmset.call_count == 1 assert pipeline_mock.hmset.call_args[0][0] == Keys.metadata_key(session_id) - assert pipeline_mock.hmset.call_args.kwargs["mapping"] == { - "context": "New summary", - "tokens": "320", - } + # Verify that hmset was called with the new summary + hmset_mapping = pipeline_mock.hmset.call_args.kwargs["mapping"] + assert hmset_mapping["context"] == "New summary" + # Token count will vary based on the actual messages passed for summarization + assert "tokens" in hmset_mapping + assert ( + int(hmset_mapping["tokens"]) > 300 + ) # Should include summarization tokens plus message tokens assert pipeline_mock.ltrim.call_count == 1 assert pipeline_mock.ltrim.call_args[0][0] == Keys.messages_key(session_id) - assert pipeline_mock.ltrim.call_args[0][1] == 0 - assert pipeline_mock.ltrim.call_args[0][2] == window_size - 1 + # New token-based approach keeps recent messages + assert pipeline_mock.ltrim.call_args[0][1] == -1 # Keep last message + assert pipeline_mock.ltrim.call_args[0][2] == -1 assert pipeline_mock.execute.call_count == 1 @@ -160,12 +168,8 @@ async def test_summarize_session( assert mock_summarization.call_args[0][0] == model assert mock_summarization.call_args[0][1] == mock_openai_client assert mock_summarization.call_args[0][2] == "Previous summary" - assert mock_summarization.call_args[0][3] == [ - "user: Message 1", - "assistant: Message 2", - "user: Message 3", - "assistant: Message 4", - ] + # Verify that some messages were passed for summarization + assert len(mock_summarization.call_args[0][3]) > 0 @pytest.mark.asyncio @patch("agent_memory_server.summarization._incremental_summary") @@ -175,18 +179,24 @@ async def test_handle_summarization_no_messages( """Test summarize_session when no messages need summarization""" session_id = "test-session" model = "gpt-3.5-turbo" - window_size = 12 + max_context_tokens = 10000 # High limit so no summarization needed pipeline_mock = MagicMock() # pipeline is not a coroutine pipeline_mock.__aenter__ = AsyncMock(return_value=pipeline_mock) pipeline_mock.watch = AsyncMock() mock_async_redis_client.pipeline = MagicMock(return_value=pipeline_mock) - pipeline_mock.llen = 
AsyncMock(return_value=0) - pipeline_mock.lrange = AsyncMock(return_value=[]) + # Set up short messages that won't exceed token limit + short_messages = [ + json.dumps({"role": "user", "content": "Short message 1"}), + json.dumps({"role": "assistant", "content": "Short response 1"}), + ] + + pipeline_mock.llen = AsyncMock(return_value=2) + pipeline_mock.lrange = AsyncMock(return_value=short_messages) pipeline_mock.hgetall = AsyncMock(return_value={}) pipeline_mock.hmset = AsyncMock(return_value=True) - pipeline_mock.lpop = AsyncMock(return_value=True) + pipeline_mock.ltrim = AsyncMock(return_value=True) pipeline_mock.execute = AsyncMock(return_value=True) with patch( @@ -196,12 +206,15 @@ async def test_handle_summarization_no_messages( await summarize_session( session_id, model, - window_size, + max_context_tokens, ) + # Should not summarize because messages are under token limit assert mock_summarization.call_count == 0 - assert pipeline_mock.lrange.call_count == 0 - assert pipeline_mock.hgetall.call_count == 0 + # But should still check messages and metadata + assert pipeline_mock.lrange.call_count == 1 + assert pipeline_mock.hgetall.call_count == 1 + # Should not update anything since no summarization needed assert pipeline_mock.hmset.call_count == 0 - assert pipeline_mock.lpop.call_count == 0 + assert pipeline_mock.ltrim.call_count == 0 assert pipeline_mock.execute.call_count == 0
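
## Usage sketches for the token-budget changes

The client changes in this patch drop `window_size` from `get_working_memory` and `memory_prompt`; callers now pass `model_name` and/or `context_window_max`, and truncation is done by token budget on the server. A minimal sketch of the new call shape, assuming a local server on port 8000 and a `MemoryClientConfig(base_url=...)` constructor (the config class name and import path are assumptions; only `MemoryAPIClient` and the keyword arguments appear in this diff):

```python
import asyncio

# Import names assumed from the agent-memory-client package layout; adjust as needed.
from agent_memory_client import MemoryAPIClient, MemoryClientConfig


async def main() -> None:
    client = MemoryAPIClient(config=MemoryClientConfig(base_url="http://localhost:8000"))

    # Token-based truncation: derive the budget from the model name...
    working = await client.get_working_memory(
        session_id="current_session",
        model_name="gpt-4o-mini",
    )

    # ...or cap the budget explicitly with context_window_max.
    prompt = await client.memory_prompt(
        query="What are my UI preferences?",
        session_id="current_session",
        context_window_max=4000,
        long_term_search={"topics": {"any": ["preferences", "ui"]}, "limit": 5},
    )
    print(working, prompt)


asyncio.run(main())
```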
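The `api.py` hunk replaces message-count truncation in `memory_prompt` with a token budget: when model info is provided, the oldest messages are dropped until the conversation fits under the effective token limit. A standalone sketch of that loop, using the same `cl100k_base` encoding the summarization module uses (the `Message` shape here is simplified for illustration):

```python
from dataclasses import dataclass

import tiktoken


@dataclass
class Message:
    role: str
    content: str


def truncate_to_token_limit(messages: list[Message], token_limit: int) -> list[Message]:
    """Drop the oldest messages until the conversation fits under token_limit."""
    encoding = tiktoken.get_encoding("cl100k_base")

    def count(msgs: list[Message]) -> int:
        return sum(len(encoding.encode(f"{m.role}: {m.content}")) for m in msgs)

    recent = messages[:]
    while recent and count(recent) > token_limit:
        recent = recent[1:]  # remove the oldest message first
    return recent
```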
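The `long_term_memory.py` change passes `FT.AGGREGATE` to `execute_command` as a list of separate arguments instead of one pre-joined string, so redis-py sends each token as its own argument rather than relying on string quoting. A hedged sketch of the same pattern (index name and filter string are placeholders; the reply-parsing comment reflects the indexing used in this diff):

```python
from redis.asyncio import Redis


async def find_duplicate_hash_groups(
    redis_client: Redis, index_name: str, filter_str: str, limit: int = 100
):
    # Each token is its own argument; redis-py handles escaping per argument.
    args = [
        "FT.AGGREGATE", index_name, filter_str,
        "GROUPBY", "1", "@memory_hash",
        "REDUCE", "COUNT", "0", "AS", "count",
        "FILTER", "@count>1",          # keep only hashes that occur more than once
        "SORTBY", "2", "@count", "DESC",
        "LIMIT", "0", str(limit),
    ]
    reply = await redis_client.execute_command(*args)
    # reply[0] is the number of groups; the remaining entries are the group rows.
    return reply
```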
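`summarize_session` now triggers on a token budget rather than a fixed message count: by default the budget is the model's context window scaled by `summarization_threshold` (0.7 in `config.py`), and roughly 40% of that budget is reserved for the most recent messages, with everything older summarized. A small sketch of the "walk backwards to decide how many recent messages to keep" step, simplified from the diff (`count_tokens` stands in for the tiktoken counting shown above):

```python
def choose_keep_count(messages, max_context_tokens: int, count_tokens) -> int:
    """Walk backwards from the newest message, keeping messages while they fit
    into ~40% of the token budget; older messages become summarization input."""
    target_remaining_tokens = int(max_context_tokens * 0.4)
    kept_tokens = 0
    keep_count = 0
    for msg in reversed(messages):
        msg_tokens = count_tokens(f"{msg.role}: {msg.content}")
        if kept_tokens + msg_tokens <= target_remaining_tokens:
            kept_tokens += msg_tokens
            keep_count += 1
        else:
            break
    return keep_count
```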
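The `docs/api.md` change swaps `window_size` for model information in the memory-prompt request's `session` block. A hedged HTTP example of the new body shape; the `/v1/memory/prompt` path and port 8000 are assumptions based on the docs and config touched in this diff, not confirmed by it:

```python
import httpx

payload = {
    "query": "What are my UI preferences?",
    "session": {
        "session_id": "session-123",
        "namespace": "default",
        "model_name": "gpt-4o",       # used to derive the context window...
        "context_window_max": 4000,   # ...or to cap it explicitly
    },
    "long_term_search": {"text": "AI discussion", "limit": 5},
}

response = httpx.post("http://localhost:8000/v1/memory/prompt", json=payload)
response.raise_for_status()
print(response.json())
```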