Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 136 additions & 3 deletions agent_memory_server/auth.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import secrets
import threading
import time
from datetime import UTC, datetime
from typing import Any

import bcrypt
import httpx
import structlog
from fastapi import Depends, HTTPException, status
Expand All @@ -11,6 +13,8 @@
from pydantic import BaseModel

from agent_memory_server.config import settings
from agent_memory_server.utils.keys import Keys
from agent_memory_server.utils.redis import get_redis_conn


logger = structlog.get_logger()
Expand All @@ -27,6 +31,15 @@ class UserInfo(BaseModel):
roles: list[str] | None = None


class TokenInfo(BaseModel):
"""Token information stored in Redis."""

description: str
created_at: datetime
expires_at: datetime | None = None
token_hash: str


class JWKSCache:
def __init__(self, cache_duration: int = 3600):
self._cache: dict[str, Any] = {}
Expand Down Expand Up @@ -245,10 +258,98 @@ def verify_jwt(token: str) -> UserInfo:
) from e


def generate_token() -> str:
"""Generate a secure random token."""
return secrets.token_urlsafe(32)


def hash_token(token: str) -> str:
"""Hash a token using bcrypt."""
return bcrypt.hashpw(token.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")


def verify_token_hash(token: str, token_hash: str) -> bool:
"""Verify a token against its hash."""
try:
return bcrypt.checkpw(token.encode("utf-8"), token_hash.encode("utf-8"))
except Exception as e:
logger.warning("Token hash verification failed", error=str(e))
return False


async def verify_token(token: str) -> UserInfo:
"""Verify a token and return user info."""
try:
redis = await get_redis_conn()

# Get all auth tokens and check each one
# This is not the most efficient approach, but it works for now
# In a production system, you might want to store a mapping of token prefixes
pattern = Keys.auth_token_key("*")
token_keys = []

async for key in redis.scan_iter(pattern):
token_keys.append(key)

for key in token_keys:
token_data = await redis.get(key)
if not token_data:
continue

try:
token_info = TokenInfo.model_validate_json(token_data)

# Check if token matches
if verify_token_hash(token, token_info.token_hash):
# Check if token is expired
if (
token_info.expires_at
and datetime.now(UTC) > token_info.expires_at
):
logger.warning("Token has expired")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token has expired",
)

# Return user info for valid token
return UserInfo(
sub="token-user",
aud="token-auth",
scope="admin",
roles=["admin"],
exp=int(token_info.expires_at.timestamp())
if token_info.expires_at
else None,
iat=int(token_info.created_at.timestamp()),
)

except HTTPException:
# Re-raise HTTP exceptions (like token expired)
raise
except Exception as e:
logger.warning("Error processing token", error=str(e))
continue

# If no token matched, authentication failed
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
)

except HTTPException:
raise
except Exception as e:
logger.error("Unexpected error during token verification", error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal server error during authentication",
) from e


def get_current_user(
credentials: HTTPAuthorizationCredentials | None = Depends(oauth2_scheme),
) -> UserInfo:
if settings.disable_auth:
if settings.disable_auth or settings.auth_mode == "disabled":
logger.debug("Authentication disabled, returning default user")
return UserInfo(
sub="local-dev-user", aud="local-dev", scope="admin", roles=["admin"]
Expand All @@ -268,6 +369,14 @@ def get_current_user(
headers={"WWW-Authenticate": "Bearer"},
)

# Determine authentication mode
if settings.auth_mode == "token" or settings.token_auth_enabled:
import asyncio

return asyncio.run(verify_token(credentials.credentials))
if settings.auth_mode == "oauth2":
return verify_jwt(credentials.credentials)
# Default to OAuth2 for backward compatibility
return verify_jwt(credentials.credentials)


Expand Down Expand Up @@ -304,18 +413,42 @@ def role_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:


def verify_auth_config():
if settings.disable_auth:
if settings.disable_auth or settings.auth_mode == "disabled":
logger.warning("Authentication is DISABLED - suitable for development only")
return

if settings.auth_mode == "token" or settings.token_auth_enabled:
logger.info("Token authentication configured")
return

if settings.auth_mode == "oauth2":
if not settings.oauth2_issuer_url:
raise ValueError(
"OAUTH2_ISSUER_URL must be set when OAuth2 authentication is enabled"
)

if not settings.oauth2_audience:
logger.warning(
"OAUTH2_AUDIENCE not set - audience validation will be skipped"
)

logger.info(
"OAuth2 authentication configured",
issuer=settings.oauth2_issuer_url,
audience=settings.oauth2_audience or "not-set",
algorithms=settings.oauth2_algorithms,
)
return

# Default to OAuth2 for backward compatibility
if not settings.oauth2_issuer_url:
raise ValueError("OAUTH2_ISSUER_URL must be set when authentication is enabled")

if not settings.oauth2_audience:
logger.warning("OAUTH2_AUDIENCE not set - audience validation will be skipped")

logger.info(
"OAuth2 authentication configured",
"OAuth2 authentication configured (default)",
issuer=settings.oauth2_issuer_url,
audience=settings.oauth2_audience or "not-set",
algorithms=settings.oauth2_algorithms,
Expand Down
Loading