diff --git a/.github/workflows/install_and_test.sh b/.github/workflows/install_and_test.sh index c90027389c..85cb07cb8a 100755 --- a/.github/workflows/install_and_test.sh +++ b/.github/workflows/install_and_test.sh @@ -40,7 +40,7 @@ cd ${TESTDIR} # install, run tests pip install ${PKG} # Redis tests -pytest -m 'not onlycluster' --ignore=tests/test_scenario +pytest -m 'not onlycluster' --ignore=tests/test_scenario --ignore=tests/test_asyncio/test_scenario # RedisCluster tests CLUSTER_URL="redis://localhost:16379/0" CLUSTER_SSL_URL="rediss://localhost:27379/0" diff --git a/dev_requirements.txt b/dev_requirements.txt index 848d6207c4..e61f37f101 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -14,3 +14,4 @@ uvloop vulture>=2.3.0 numpy>=1.24.0 redis-entraid==1.0.0 +pybreaker>=1.4.0 diff --git a/pyproject.toml b/pyproject.toml index ee061953c5..198ac71a7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,9 @@ ocsp = [ jwt = [ "PyJWT>=2.9.0", ] +circuit_breaker = [ + "pybreaker>=1.4.0" +] [project.urls] Changes = "https://github.com/redis/redis-py/releases" diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 3defeceead..ab5a3ac0bd 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -1239,6 +1239,7 @@ async def run( *, exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, poll_timeout: float = 1.0, + pubsub=None, ) -> None: """Process pub/sub messages using registered callbacks. @@ -1263,9 +1264,14 @@ async def run( await self.connect() while True: try: - await self.get_message( - ignore_subscribe_messages=True, timeout=poll_timeout - ) + if pubsub is None: + await self.get_message( + ignore_subscribe_messages=True, timeout=poll_timeout + ) + else: + await pubsub.get_message( + ignore_subscribe_messages=True, timeout=poll_timeout + ) except asyncio.CancelledError: raise except BaseException as e: diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 4e0e06517d..225fd3b79f 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -409,6 +409,7 @@ def __init__( else: self._event_dispatcher = event_dispatcher + self.startup_nodes = startup_nodes self.nodes_manager = NodesManager( startup_nodes, require_full_coverage, @@ -2253,7 +2254,10 @@ async def _reinitialize_on_error(self, error): await self._pipe.cluster_client.nodes_manager.initialize() self.reinitialize_counter = 0 else: - self._pipe.cluster_client.nodes_manager.update_moved_exception(error) + if isinstance(error, AskError): + self._pipe.cluster_client.nodes_manager.update_moved_exception( + error + ) self._executing = False diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index e3eb3bd9f1..c79ed690d9 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -213,6 +213,7 @@ def __init__( self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] self._buffer_cutoff = 6000 self._re_auth_token: Optional[TokenInterface] = None + self._should_reconnect = False try: p = int(protocol) @@ -343,6 +344,12 @@ async def connect_check_health( if task and inspect.isawaitable(task): await task + def mark_for_reconnect(self): + self._should_reconnect = True + + def should_reconnect(self): + return self._should_reconnect + @abstractmethod async def _connect(self): pass @@ -1240,6 +1247,9 @@ async def release(self, connection: AbstractConnection): # Connections should always be returned to the correct pool, # not doing so is an error that will cause an exception here. 
self._in_use_connections.remove(connection) + if connection.should_reconnect(): + await connection.disconnect() + self._available_connections.append(connection) await self._event_dispatcher.dispatch_async( AsyncAfterConnectionReleasedEvent(connection) @@ -1267,6 +1277,14 @@ async def disconnect(self, inuse_connections: bool = True): if exc: raise exc + async def update_active_connections_for_reconnect(self): + """ + Mark all active connections for reconnect. + """ + async with self._lock: + for conn in self._in_use_connections: + conn.mark_for_reconnect() + async def aclose(self) -> None: """Close the pool, disconnecting all connections""" await self.disconnect() diff --git a/redis/asyncio/http/__init__.py b/redis/asyncio/http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/asyncio/http/http_client.py b/redis/asyncio/http/http_client.py new file mode 100644 index 0000000000..688b33b2e3 --- /dev/null +++ b/redis/asyncio/http/http_client.py @@ -0,0 +1,265 @@ +import asyncio +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Mapping, Optional, Union + +from redis.http.http_client import HttpClient, HttpResponse + +DEFAULT_USER_AGENT = "HttpClient/1.0 (+https://example.invalid)" +DEFAULT_TIMEOUT = 30.0 +RETRY_STATUS_CODES = {429, 500, 502, 503, 504} + + +class AsyncHTTPClient(ABC): + @abstractmethod + async def get( + self, + path: str, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP GET request.""" + pass + + @abstractmethod + async def delete( + self, + path: str, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP DELETE request.""" + pass + + @abstractmethod + async def post( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP POST request.""" + pass + + @abstractmethod + async def put( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP PUT request.""" + pass + + @abstractmethod + async def patch( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + """ + Invoke HTTP PATCH request.""" + pass + + @abstractmethod + async def request( + self, + method: str, + path: str, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = 
None, + body: Optional[Union[bytes, str]] = None, + timeout: Optional[float] = None, + ) -> HttpResponse: + """ + Invoke HTTP request with given method.""" + pass + + +class AsyncHTTPClientWrapper(AsyncHTTPClient): + """ + An async wrapper around sync HTTP client with thread pool execution. + """ + + def __init__(self, client: HttpClient, max_workers: int = 10) -> None: + """ + Initialize a new HTTP client instance. + + Args: + client: Sync HTTP client instance. + max_workers: Maximum number of concurrent requests. + + The client supports both regular HTTPS with server verification and mutual TLS + authentication. For server verification, provide CA certificate information via + ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client + certificate and key via client_cert_file and client_key_file. + """ + self.client = client + self._executor = ThreadPoolExecutor(max_workers=max_workers) + + async def get( + self, + path: str, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, self.client.get, path, params, headers, timeout, expect_json + ) + + async def delete( + self, + path: str, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.delete, + path, + params, + headers, + timeout, + expect_json, + ) + + async def post( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.post, + path, + json_body, + data, + params, + headers, + timeout, + expect_json, + ) + + async def put( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.put, + path, + json_body, + data, + params, + headers, + timeout, + expect_json, + ) + + async def patch( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.patch, + path, + json_body, + data, + params, + headers, + timeout, + expect_json, + ) + + async def request( + self, + method: str, + path: str, + params: Optional[ + Mapping[str, Union[None, 
str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + body: Optional[Union[bytes, str]] = None, + timeout: Optional[float] = None, + ) -> HttpResponse: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self._executor, + self.client.request, + method, + path, + params, + headers, + body, + timeout, + ) diff --git a/redis/asyncio/multidb/__init__.py b/redis/asyncio/multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py new file mode 100644 index 0000000000..6bea588196 --- /dev/null +++ b/redis/asyncio/multidb/client.py @@ -0,0 +1,528 @@ +import asyncio +import logging +from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Union + +from redis.asyncio.client import PubSubHandler +from redis.asyncio.multidb.command_executor import DefaultCommandExecutor +from redis.asyncio.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig +from redis.asyncio.multidb.database import AsyncDatabase, Databases +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import HealthCheck, HealthCheckPolicy +from redis.background import BackgroundScheduler +from redis.commands import AsyncCoreCommands, AsyncRedisModuleCommands +from redis.multidb.circuit import CircuitBreaker +from redis.multidb.circuit import State as CBState +from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException +from redis.typing import ChannelT, EncodableT, KeyT + +logger = logging.getLogger(__name__) + + +class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands): + """ + Client that operates on multiple logical Redis databases. + Should be used in Active-Active database setups. 
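+
+    A minimal usage sketch; the URLs, weights, and key below are illustrative
+    assumptions, not part of this change:
+
+        config = MultiDbConfig(
+            databases_config=[
+                DatabaseConfig(from_url="redis://db-east.example.com:6379", weight=1.0),
+                DatabaseConfig(from_url="redis://db-west.example.com:6379", weight=0.5),
+            ],
+        )
+        async with MultiDBClient(config) as client:
+            await client.set("key", "value")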
+ """ + + def __init__(self, config: MultiDbConfig): + self._databases = config.databases() + self._health_checks = config.default_health_checks() + + if config.health_checks is not None: + self._health_checks.extend(config.health_checks) + + self._health_check_interval = config.health_check_interval + self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value( + config.health_check_probes, config.health_check_delay + ) + self._failure_detectors = config.default_failure_detectors() + + if config.failure_detectors is not None: + self._failure_detectors.extend(config.failure_detectors) + + self._failover_strategy = ( + config.default_failover_strategy() + if config.failover_strategy is None + else config.failover_strategy + ) + self._failover_strategy.set_databases(self._databases) + self._auto_fallback_interval = config.auto_fallback_interval + self._event_dispatcher = config.event_dispatcher + self._command_retry = config.command_retry + self._command_retry.update_supported_errors([ConnectionRefusedError]) + self.command_executor = DefaultCommandExecutor( + failure_detectors=self._failure_detectors, + databases=self._databases, + command_retry=self._command_retry, + failover_strategy=self._failover_strategy, + failover_attempts=config.failover_attempts, + failover_delay=config.failover_delay, + event_dispatcher=self._event_dispatcher, + auto_fallback_interval=self._auto_fallback_interval, + ) + self.initialized = False + self._hc_lock = asyncio.Lock() + self._bg_scheduler = BackgroundScheduler() + self._config = config + self._recurring_hc_task = None + self._hc_tasks = [] + self._half_open_state_task = None + + async def __aenter__(self: "MultiDBClient") -> "MultiDBClient": + if not self.initialized: + await self.initialize() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if self._recurring_hc_task: + self._recurring_hc_task.cancel() + if self._half_open_state_task: + self._half_open_state_task.cancel() + for hc_task in self._hc_tasks: + hc_task.cancel() + + async def initialize(self): + """ + Perform initialization of databases to define their initial state. + """ + + async def raise_exception_on_failed_hc(error): + raise error + + # Initial databases check to define initial state + await self._check_databases_health(on_error=raise_exception_on_failed_hc) + + # Starts recurring health checks on the background. + self._recurring_hc_task = asyncio.create_task( + self._bg_scheduler.run_recurring_async( + self._health_check_interval, + self._check_databases_health, + ) + ) + + is_active_db_found = False + + for database, weight in self._databases: + # Set on state changed callback for each circuit. + database.circuit.on_state_changed(self._on_circuit_state_change_callback) + + # Set states according to a weights and circuit state + if database.circuit.state == CBState.CLOSED and not is_active_db_found: + await self.command_executor.set_active_database(database) + is_active_db_found = True + + if not is_active_db_found: + raise NoValidDatabaseException( + "Initial connection failed - no active database found" + ) + + self.initialized = True + + def get_databases(self) -> Databases: + """ + Returns a sorted (by weight) list of all databases. + """ + return self._databases + + async def set_active_database(self, database: AsyncDatabase) -> None: + """ + Promote one of the existing databases to become an active. 
+ """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError("Given database is not a member of database list") + + await self._check_db_health(database) + + if database.circuit.state == CBState.CLOSED: + highest_weighted_db, _ = self._databases.get_top_n(1)[0] + await self.command_executor.set_active_database(database) + return + + raise NoValidDatabaseException( + "Cannot set active database, database is unhealthy" + ) + + async def add_database(self, database: AsyncDatabase): + """ + Adds a new database to the database list. + """ + for existing_db, _ in self._databases: + if existing_db == database: + raise ValueError("Given database already exists") + + await self._check_db_health(database) + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.add(database, database.weight) + await self._change_active_database(database, highest_weighted_db) + + async def _change_active_database( + self, new_database: AsyncDatabase, highest_weight_database: AsyncDatabase + ): + if ( + new_database.weight > highest_weight_database.weight + and new_database.circuit.state == CBState.CLOSED + ): + await self.command_executor.set_active_database(new_database) + + async def remove_database(self, database: AsyncDatabase): + """ + Removes a database from the database list. + """ + weight = self._databases.remove(database) + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + + if ( + highest_weight <= weight + and highest_weighted_db.circuit.state == CBState.CLOSED + ): + await self.command_executor.set_active_database(highest_weighted_db) + + async def update_database_weight(self, database: AsyncDatabase, weight: float): + """ + Updates a database from the database list. + """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError("Given database is not a member of database list") + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.update_weight(database, weight) + database.weight = weight + await self._change_active_database(database, highest_weighted_db) + + def add_failure_detector(self, failure_detector: AsyncFailureDetector): + """ + Adds a new failure detector to the database. + """ + self._failure_detectors.append(failure_detector) + + async def add_health_check(self, healthcheck: HealthCheck): + """ + Adds a new health check to the database. + """ + async with self._hc_lock: + self._health_checks.append(healthcheck) + + async def execute_command(self, *args, **options): + """ + Executes a single command and return its result. + """ + if not self.initialized: + await self.initialize() + + return await self.command_executor.execute_command(*args, **options) + + def pipeline(self): + """ + Enters into pipeline mode of the client. + """ + return Pipeline(self) + + async def transaction( + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool = False, + watch_delay: Optional[float] = None, + ): + """ + Executes callable as transaction. 
+ """ + if not self.initialized: + await self.initialize() + + return await self.command_executor.execute_transaction( + func, + *watches, + shard_hint=shard_hint, + value_from_callable=value_from_callable, + watch_delay=watch_delay, + ) + + async def pubsub(self, **kwargs): + """ + Return a Publish/Subscribe object. With this object, you can + subscribe to channels and listen for messages that get published to + them. + """ + if not self.initialized: + await self.initialize() + + return PubSub(self, **kwargs) + + async def _check_databases_health( + self, + on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, + ): + """ + Runs health checks as a recurring task. + Runs health checks against all databases. + """ + try: + self._hc_tasks = [ + asyncio.create_task(self._check_db_health(database)) + for database, _ in self._databases + ] + results = await asyncio.wait_for( + asyncio.gather( + *self._hc_tasks, + return_exceptions=True, + ), + timeout=self._health_check_interval, + ) + except asyncio.TimeoutError: + raise asyncio.TimeoutError( + "Health check execution exceeds health_check_interval" + ) + + for result in results: + if isinstance(result, UnhealthyDatabaseException): + unhealthy_db = result.database + unhealthy_db.circuit.state = CBState.OPEN + + logger.exception( + "Health check failed, due to exception", + exc_info=result.original_exception, + ) + + if on_error: + on_error(result.original_exception) + + async def _check_db_health(self, database: AsyncDatabase) -> bool: + """ + Runs health checks on the given database until first failure. + """ + # Health check will setup circuit state + is_healthy = await self._health_check_policy.execute( + self._health_checks, database + ) + + if not is_healthy: + if database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + return is_healthy + elif is_healthy and database.circuit.state != CBState.CLOSED: + database.circuit.state = CBState.CLOSED + + return is_healthy + + def _on_circuit_state_change_callback( + self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState + ): + loop = asyncio.get_running_loop() + + if new_state == CBState.HALF_OPEN: + self._half_open_state_task = asyncio.create_task( + self._check_db_health(circuit.database) + ) + return + + if old_state == CBState.CLOSED and new_state == CBState.OPEN: + loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) + + async def aclose(self): + if self.command_executor.active_database: + await self.command_executor.active_database.client.aclose() + + +def _half_open_circuit(circuit: CircuitBreaker): + circuit.state = CBState.HALF_OPEN + + +class Pipeline(AsyncRedisModuleCommands, AsyncCoreCommands): + """ + Pipeline implementation for multiple logical Redis databases. 
+ """ + + def __init__(self, client: MultiDBClient): + self._command_stack = [] + self._client = client + + async def __aenter__(self: "Pipeline") -> "Pipeline": + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + await self._client.__aexit__(exc_type, exc_value, traceback) + + def __await__(self): + return self._async_self().__await__() + + async def _async_self(self): + return self + + def __len__(self) -> int: + return len(self._command_stack) + + def __bool__(self) -> bool: + """Pipeline instances should always evaluate to True""" + return True + + async def reset(self) -> None: + self._command_stack = [] + + async def aclose(self) -> None: + """Close the pipeline""" + await self.reset() + + def pipeline_execute_command(self, *args, **options) -> "Pipeline": + """ + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. + """ + self._command_stack.append((args, options)) + return self + + def execute_command(self, *args, **kwargs): + """Adds a command to the stack""" + return self.pipeline_execute_command(*args, **kwargs) + + async def execute(self) -> List[Any]: + """Execute all the commands in the current pipeline""" + if not self._client.initialized: + await self._client.initialize() + + try: + return await self._client.command_executor.execute_pipeline( + tuple(self._command_stack) + ) + finally: + await self.reset() + + +class PubSub: + """ + PubSub object for multi database client. + """ + + def __init__(self, client: MultiDBClient, **kwargs): + """Initialize the PubSub object for a multi-database client. + + Args: + client: MultiDBClient instance to use for pub/sub operations + **kwargs: Additional keyword arguments to pass to the underlying pubsub implementation + """ + + self._client = client + self._client.command_executor.pubsub(**kwargs) + + async def __aenter__(self) -> "PubSub": + return self + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + await self.aclose() + + async def aclose(self): + return await self._client.command_executor.execute_pubsub_method("aclose") + + @property + def subscribed(self) -> bool: + return self._client.command_executor.active_pubsub.subscribed + + async def execute_command(self, *args: EncodableT): + return await self._client.command_executor.execute_pubsub_method( + "execute_command", *args + ) + + async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): + """ + Subscribe to channel patterns. Patterns supplied as keyword arguments + expect a pattern name as the key and a callable as the value. A + pattern's callable will be invoked automatically when a message is + received on that pattern rather than producing a message via + ``listen()``. + """ + return await self._client.command_executor.execute_pubsub_method( + "psubscribe", *args, **kwargs + ) + + async def punsubscribe(self, *args: ChannelT): + """ + Unsubscribe from the supplied patterns. If empty, unsubscribe from + all patterns. + """ + return await self._client.command_executor.execute_pubsub_method( + "punsubscribe", *args + ) + + async def subscribe(self, *args: ChannelT, **kwargs: Callable): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. 
A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. + """ + return await self._client.command_executor.execute_pubsub_method( + "subscribe", *args, **kwargs + ) + + async def unsubscribe(self, *args): + """ + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels + """ + return await self._client.command_executor.execute_pubsub_method( + "unsubscribe", *args + ) + + async def get_message( + self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number or None to wait indefinitely. + """ + return await self._client.command_executor.execute_pubsub_method( + "get_message", + ignore_subscribe_messages=ignore_subscribe_messages, + timeout=timeout, + ) + + async def run( + self, + *, + exception_handler=None, + poll_timeout: float = 1.0, + ) -> None: + """Process pub/sub messages using registered callbacks. + + This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in + redis-py, but it is a coroutine. To launch it as a separate task, use + ``asyncio.create_task``: + + >>> task = asyncio.create_task(pubsub.run()) + + To shut it down, use asyncio cancellation: + + >>> task.cancel() + >>> await task + """ + return await self._client.command_executor.execute_pubsub_run( + sleep_time=poll_timeout, exception_handler=exception_handler, pubsub=self + ) diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py new file mode 100644 index 0000000000..c09b8b9969 --- /dev/null +++ b/redis/asyncio/multidb/command_executor.py @@ -0,0 +1,339 @@ +from abc import abstractmethod +from asyncio import iscoroutinefunction +from datetime import datetime +from typing import Any, Awaitable, Callable, List, Optional, Union + +from redis.asyncio import RedisCluster +from redis.asyncio.client import Pipeline, PubSub +from redis.asyncio.multidb.database import AsyncDatabase, Database, Databases +from redis.asyncio.multidb.event import ( + AsyncActiveDatabaseChanged, + CloseConnectionOnActiveDatabaseChanged, + RegisterCommandFailure, + ResubscribeOnActiveDatabaseChanged, +) +from redis.asyncio.multidb.failover import ( + DEFAULT_FAILOVER_ATTEMPTS, + DEFAULT_FAILOVER_DELAY, + AsyncFailoverStrategy, + DefaultFailoverStrategyExecutor, + FailoverStrategyExecutor, +) +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.retry import Retry +from redis.event import AsyncOnCommandsFailEvent, EventDispatcherInterface +from redis.multidb.circuit import State as CBState +from redis.multidb.command_executor import BaseCommandExecutor, CommandExecutor +from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.typing import KeyT + + +class AsyncCommandExecutor(CommandExecutor): + @property + @abstractmethod + def databases(self) -> Databases: + """Returns a list of databases.""" + pass + + @property + @abstractmethod + def failure_detectors(self) -> List[AsyncFailureDetector]: + """Returns a list of failure detectors.""" + pass + + @abstractmethod + def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None: + """Adds a new failure detector to the list of failure detectors.""" + pass + + @property + @abstractmethod + def 
active_database(self) -> Optional[AsyncDatabase]: + """Returns currently active database.""" + pass + + @abstractmethod + async def set_active_database(self, database: AsyncDatabase) -> None: + """Sets the currently active database.""" + pass + + @property + @abstractmethod + def active_pubsub(self) -> Optional[PubSub]: + """Returns currently active pubsub.""" + pass + + @active_pubsub.setter + @abstractmethod + def active_pubsub(self, pubsub: PubSub) -> None: + """Sets currently active pubsub.""" + pass + + @property + @abstractmethod + def failover_strategy_executor(self) -> FailoverStrategyExecutor: + """Returns failover strategy executor.""" + pass + + @property + @abstractmethod + def command_retry(self) -> Retry: + """Returns command retry object.""" + pass + + @abstractmethod + async def pubsub(self, **kwargs): + """Initializes a PubSub object on a currently active database""" + pass + + @abstractmethod + async def execute_command(self, *args, **options): + """Executes a command and returns the result.""" + pass + + @abstractmethod + async def execute_pipeline(self, command_stack: tuple): + """Executes a stack of commands in pipeline.""" + pass + + @abstractmethod + async def execute_transaction( + self, transaction: Callable[[Pipeline], None], *watches, **options + ): + """Executes a transaction block wrapped in callback.""" + pass + + @abstractmethod + async def execute_pubsub_method(self, method_name: str, *args, **kwargs): + """Executes a given method on active pub/sub.""" + pass + + @abstractmethod + async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + """Executes pub/sub run in a thread.""" + pass + + +class DefaultCommandExecutor(BaseCommandExecutor, AsyncCommandExecutor): + def __init__( + self, + failure_detectors: List[AsyncFailureDetector], + databases: Databases, + command_retry: Retry, + failover_strategy: AsyncFailoverStrategy, + event_dispatcher: EventDispatcherInterface, + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + ): + """ + Initialize the DefaultCommandExecutor instance. 
+ + Args: + failure_detectors: List of failure detector instances to monitor database health + databases: Collection of available databases to execute commands on + command_retry: Retry policy for failed command execution + failover_strategy: Strategy for handling database failover + event_dispatcher: Interface for dispatching events + failover_attempts: Number of failover attempts + failover_delay: Delay between failover attempts + auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database + """ + super().__init__(auto_fallback_interval) + + for fd in failure_detectors: + fd.set_command_executor(command_executor=self) + + self._databases = databases + self._failure_detectors = failure_detectors + self._command_retry = command_retry + self._failover_strategy_executor = DefaultFailoverStrategyExecutor( + failover_strategy, failover_attempts, failover_delay + ) + self._event_dispatcher = event_dispatcher + self._active_database: Optional[Database] = None + self._active_pubsub: Optional[PubSub] = None + self._active_pubsub_kwargs = {} + self._setup_event_dispatcher() + self._schedule_next_fallback() + + @property + def databases(self) -> Databases: + return self._databases + + @property + def failure_detectors(self) -> List[AsyncFailureDetector]: + return self._failure_detectors + + def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None: + self._failure_detectors.append(failure_detector) + + @property + def active_database(self) -> Optional[AsyncDatabase]: + return self._active_database + + async def set_active_database(self, database: AsyncDatabase) -> None: + old_active = self._active_database + self._active_database = database + + if old_active is not None and old_active is not database: + await self._event_dispatcher.dispatch_async( + AsyncActiveDatabaseChanged( + old_active, + self._active_database, + self, + **self._active_pubsub_kwargs, + ) + ) + + @property + def active_pubsub(self) -> Optional[PubSub]: + return self._active_pubsub + + @active_pubsub.setter + def active_pubsub(self, pubsub: PubSub) -> None: + self._active_pubsub = pubsub + + @property + def failover_strategy_executor(self) -> FailoverStrategyExecutor: + return self._failover_strategy_executor + + @property + def command_retry(self) -> Retry: + return self._command_retry + + def pubsub(self, **kwargs): + if self._active_pubsub is None: + if isinstance(self._active_database.client, RedisCluster): + raise ValueError("PubSub is not supported for RedisCluster") + + self._active_pubsub = self._active_database.client.pubsub(**kwargs) + self._active_pubsub_kwargs = kwargs + + async def execute_command(self, *args, **options): + async def callback(): + response = await self._active_database.client.execute_command( + *args, **options + ) + await self._register_command_execution(args) + return response + + return await self._execute_with_failure_detection(callback, args) + + async def execute_pipeline(self, command_stack: tuple): + async def callback(): + async with self._active_database.client.pipeline() as pipe: + for command, options in command_stack: + pipe.execute_command(*command, **options) + + response = await pipe.execute() + await self._register_command_execution(command_stack) + return response + + return await self._execute_with_failure_detection(callback, command_stack) + + async def execute_transaction( + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool 
= False, + watch_delay: Optional[float] = None, + ): + async def callback(): + response = await self._active_database.client.transaction( + func, + *watches, + shard_hint=shard_hint, + value_from_callable=value_from_callable, + watch_delay=watch_delay, + ) + await self._register_command_execution(()) + return response + + return await self._execute_with_failure_detection(callback) + + async def execute_pubsub_method(self, method_name: str, *args, **kwargs): + async def callback(): + method = getattr(self.active_pubsub, method_name) + if iscoroutinefunction(method): + response = await method(*args, **kwargs) + else: + response = method(*args, **kwargs) + + await self._register_command_execution(args) + return response + + return await self._execute_with_failure_detection(callback, *args) + + async def execute_pubsub_run( + self, sleep_time: float, exception_handler=None, pubsub=None + ) -> Any: + async def callback(): + return await self._active_pubsub.run( + poll_timeout=sleep_time, + exception_handler=exception_handler, + pubsub=pubsub, + ) + + return await self._execute_with_failure_detection(callback) + + async def _execute_with_failure_detection( + self, callback: Callable, cmds: tuple = () + ): + """ + Execute a commands execution callback with failure detection. + """ + + async def wrapper(): + # On each retry we need to check active database as it might change. + await self._check_active_database() + return await callback() + + return await self._command_retry.call_with_retry( + lambda: wrapper(), + lambda error: self._on_command_fail(error, *cmds), + ) + + async def _check_active_database(self): + """ + Checks if active a database needs to be updated. + """ + if ( + self._active_database is None + or self._active_database.circuit.state != CBState.CLOSED + or ( + self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL + and self._next_fallback_attempt <= datetime.now() + ) + ): + await self.set_active_database( + await self._failover_strategy_executor.execute() + ) + self._schedule_next_fallback() + + async def _on_command_fail(self, error, *args): + await self._event_dispatcher.dispatch_async( + AsyncOnCommandsFailEvent(args, error) + ) + + async def _register_command_execution(self, cmd: tuple): + for detector in self._failure_detectors: + await detector.register_command_execution(cmd) + + def _setup_event_dispatcher(self): + """ + Registers necessary listeners. 
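+
+        Command-failure events are routed to the failure detectors; active
+        database changes trigger connection cleanup and pub/sub re-subscription.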
+ """ + failure_listener = RegisterCommandFailure(self._failure_detectors) + resubscribe_listener = ResubscribeOnActiveDatabaseChanged() + close_connection_listener = CloseConnectionOnActiveDatabaseChanged() + self._event_dispatcher.register_listeners( + { + AsyncOnCommandsFailEvent: [failure_listener], + AsyncActiveDatabaseChanged: [ + close_connection_listener, + resubscribe_listener, + ], + } + ) diff --git a/redis/asyncio/multidb/config.py b/redis/asyncio/multidb/config.py new file mode 100644 index 0000000000..71f69ad133 --- /dev/null +++ b/redis/asyncio/multidb/config.py @@ -0,0 +1,210 @@ +from dataclasses import dataclass, field +from typing import List, Optional, Type, Union + +import pybreaker + +from redis.asyncio import ConnectionPool, Redis, RedisCluster +from redis.asyncio.multidb.database import Database, Databases +from redis.asyncio.multidb.failover import ( + DEFAULT_FAILOVER_ATTEMPTS, + DEFAULT_FAILOVER_DELAY, + AsyncFailoverStrategy, + WeightBasedFailoverStrategy, +) +from redis.asyncio.multidb.failure_detector import ( + AsyncFailureDetector, + FailureDetectorAsyncWrapper, +) +from redis.asyncio.multidb.healthcheck import ( + DEFAULT_HEALTH_CHECK_DELAY, + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_HEALTH_CHECK_POLICY, + DEFAULT_HEALTH_CHECK_PROBES, + EchoHealthCheck, + HealthCheck, + HealthCheckPolicies, +) +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialWithJitterBackoff, NoBackoff +from redis.data_structure import WeightedList +from redis.event import EventDispatcher, EventDispatcherInterface +from redis.multidb.circuit import ( + DEFAULT_GRACE_PERIOD, + CircuitBreaker, + PBCircuitBreakerAdapter, +) +from redis.multidb.failure_detector import ( + DEFAULT_FAILURE_RATE_THRESHOLD, + DEFAULT_FAILURES_DETECTION_WINDOW, + DEFAULT_MIN_NUM_FAILURES, + CommandFailureDetector, +) + +DEFAULT_AUTO_FALLBACK_INTERVAL = 120 + + +def default_event_dispatcher() -> EventDispatcherInterface: + return EventDispatcher() + + +@dataclass +class DatabaseConfig: + """ + Dataclass representing the configuration for a database connection. + + This class is used to store configuration settings for a database connection, + including client options, connection sourcing details, circuit breaker settings, + and cluster-specific properties. It provides a structure for defining these + attributes and allows for the creation of customized configurations for various + database setups. + + Attributes: + weight (float): Weight of the database to define the active one. + client_kwargs (dict): Additional parameters for the database client connection. + from_url (Optional[str]): Redis URL way of connecting to the database. + from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. + circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. + grace_period (float): Grace period after which we need to check if the circuit could be closed again. + health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used + on public Redis Enterprise endpoints. + + Methods: + default_circuit_breaker: + Generates and returns a default CircuitBreaker instance adapted for use. 
+ """ + + weight: float = 1.0 + client_kwargs: dict = field(default_factory=dict) + from_url: Optional[str] = None + from_pool: Optional[ConnectionPool] = None + circuit: Optional[CircuitBreaker] = None + grace_period: float = DEFAULT_GRACE_PERIOD + health_check_url: Optional[str] = None + + def default_circuit_breaker(self) -> CircuitBreaker: + circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) + return PBCircuitBreakerAdapter(circuit_breaker) + + +@dataclass +class MultiDbConfig: + """ + Configuration class for managing multiple database connections in a resilient and fail-safe manner. + + Attributes: + databases_config: A list of database configurations. + client_class: The client class used to manage database connections. + command_retry: Retry strategy for executing database commands. + failure_detectors: Optional list of additional failure detectors for monitoring database failures. + min_num_failures: Minimal count of failures required for failover + failure_rate_threshold: Percentage of failures required for failover + failures_detection_window: Time interval for tracking database failures. + health_checks: Optional list of additional health checks performed on databases. + health_check_interval: Time interval for executing health checks. + health_check_probes: Number of attempts to evaluate the health of a database. + health_check_delay: Delay between health check attempts. + failover_strategy: Optional strategy for handling database failover scenarios. + failover_attempts: Number of retries allowed for failover operations. + failover_delay: Delay between failover attempts. + auto_fallback_interval: Time interval to trigger automatic fallback. + event_dispatcher: Interface for dispatching events related to database operations. + + Methods: + databases: + Retrieves a collection of database clients managed by weighted configurations. + Initializes database clients based on the provided configuration and removes + redundant retry objects for lower-level clients to rely on global retry logic. + + default_failure_detectors: + Returns the default list of failure detectors used to monitor database failures. + + default_health_checks: + Returns the default list of health checks used to monitor database health + with specific retry and backoff strategies. + + default_failover_strategy: + Provides the default failover strategy used for handling failover scenarios + with defined retry and backoff configurations. 
+ """ + + databases_config: List[DatabaseConfig] + client_class: Type[Union[Redis, RedisCluster]] = Redis + command_retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ) + failure_detectors: Optional[List[AsyncFailureDetector]] = None + min_num_failures: int = DEFAULT_MIN_NUM_FAILURES + failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD + failures_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW + health_checks: Optional[List[HealthCheck]] = None + health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL + health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES + health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY + health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY + failover_strategy: Optional[AsyncFailoverStrategy] = None + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS + failover_delay: float = DEFAULT_FAILOVER_DELAY + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL + event_dispatcher: EventDispatcherInterface = field( + default_factory=default_event_dispatcher + ) + + def databases(self) -> Databases: + databases = WeightedList() + + for database_config in self.databases_config: + # The retry object is not used in the lower level clients, so we can safely remove it. + # We rely on command_retry in terms of global retries. + database_config.client_kwargs.update( + {"retry": Retry(retries=0, backoff=NoBackoff())} + ) + + if database_config.from_url: + client = self.client_class.from_url( + database_config.from_url, **database_config.client_kwargs + ) + elif database_config.from_pool: + database_config.from_pool.set_retry( + Retry(retries=0, backoff=NoBackoff()) + ) + client = self.client_class.from_pool( + connection_pool=database_config.from_pool + ) + else: + client = self.client_class(**database_config.client_kwargs) + + circuit = ( + database_config.default_circuit_breaker() + if database_config.circuit is None + else database_config.circuit + ) + databases.add( + Database( + client=client, + circuit=circuit, + weight=database_config.weight, + health_check_url=database_config.health_check_url, + ), + database_config.weight, + ) + + return databases + + def default_failure_detectors(self) -> List[AsyncFailureDetector]: + return [ + FailureDetectorAsyncWrapper( + CommandFailureDetector( + min_num_failures=self.min_num_failures, + failure_rate_threshold=self.failure_rate_threshold, + failure_detection_window=self.failures_detection_window, + ) + ), + ] + + def default_health_checks(self) -> List[HealthCheck]: + return [ + EchoHealthCheck(), + ] + + def default_failover_strategy(self) -> AsyncFailoverStrategy: + return WeightBasedFailoverStrategy() diff --git a/redis/asyncio/multidb/database.py b/redis/asyncio/multidb/database.py new file mode 100644 index 0000000000..ecf7a1b972 --- /dev/null +++ b/redis/asyncio/multidb/database.py @@ -0,0 +1,69 @@ +from abc import abstractmethod +from typing import Optional, Union + +from redis.asyncio import Redis, RedisCluster +from redis.data_structure import WeightedList +from redis.multidb.circuit import CircuitBreaker +from redis.multidb.database import AbstractDatabase, BaseDatabase +from redis.typing import Number + + +class AsyncDatabase(AbstractDatabase): + """Database with an underlying asynchronous redis client.""" + + @property + @abstractmethod + def client(self) -> Union[Redis, RedisCluster]: + """The underlying redis client.""" + pass + + @client.setter + @abstractmethod + def client(self, client: Union[Redis, RedisCluster]): + 
"""Set the underlying redis client.""" + pass + + @property + @abstractmethod + def circuit(self) -> CircuitBreaker: + """Circuit breaker for the current database.""" + pass + + @circuit.setter + @abstractmethod + def circuit(self, circuit: CircuitBreaker): + """Set the circuit breaker for the current database.""" + pass + + +Databases = WeightedList[tuple[AsyncDatabase, Number]] + + +class Database(BaseDatabase, AsyncDatabase): + def __init__( + self, + client: Union[Redis, RedisCluster], + circuit: CircuitBreaker, + weight: float, + health_check_url: Optional[str] = None, + ): + self._client = client + self._cb = circuit + self._cb.database = self + super().__init__(weight, health_check_url) + + @property + def client(self) -> Union[Redis, RedisCluster]: + return self._client + + @client.setter + def client(self, client: Union[Redis, RedisCluster]): + self._client = client + + @property + def circuit(self) -> CircuitBreaker: + return self._cb + + @circuit.setter + def circuit(self, circuit: CircuitBreaker): + self._cb = circuit diff --git a/redis/asyncio/multidb/event.py b/redis/asyncio/multidb/event.py new file mode 100644 index 0000000000..ae25f1e37c --- /dev/null +++ b/redis/asyncio/multidb/event.py @@ -0,0 +1,84 @@ +from typing import List + +from redis.asyncio import Redis +from redis.asyncio.multidb.database import AsyncDatabase +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.event import AsyncEventListenerInterface, AsyncOnCommandsFailEvent + + +class AsyncActiveDatabaseChanged: + """ + Event fired when an async active database has been changed. + """ + + def __init__( + self, + old_database: AsyncDatabase, + new_database: AsyncDatabase, + command_executor, + **kwargs, + ): + self._old_database = old_database + self._new_database = new_database + self._command_executor = command_executor + self._kwargs = kwargs + + @property + def old_database(self) -> AsyncDatabase: + return self._old_database + + @property + def new_database(self) -> AsyncDatabase: + return self._new_database + + @property + def command_executor(self): + return self._command_executor + + @property + def kwargs(self): + return self._kwargs + + +class ResubscribeOnActiveDatabaseChanged(AsyncEventListenerInterface): + """ + Re-subscribe the currently active pub / sub to a new active database. + """ + + async def listen(self, event: AsyncActiveDatabaseChanged): + old_pubsub = event.command_executor.active_pubsub + + if old_pubsub is not None: + # Re-assign old channels and patterns so they will be automatically subscribed on connection. + new_pubsub = event.new_database.client.pubsub(**event.kwargs) + new_pubsub.channels = old_pubsub.channels + new_pubsub.patterns = old_pubsub.patterns + await new_pubsub.on_connect(None) + event.command_executor.active_pubsub = new_pubsub + await old_pubsub.aclose() + + +class CloseConnectionOnActiveDatabaseChanged(AsyncEventListenerInterface): + """ + Close connection to the old active database. + """ + + async def listen(self, event: AsyncActiveDatabaseChanged): + await event.old_database.client.aclose() + + if isinstance(event.old_database.client, Redis): + await event.old_database.client.connection_pool.update_active_connections_for_reconnect() + await event.old_database.client.connection_pool.disconnect() + + +class RegisterCommandFailure(AsyncEventListenerInterface): + """ + Event listener that registers command failures and passing it to the failure detectors. 
+ """ + + def __init__(self, failure_detectors: List[AsyncFailureDetector]): + self._failure_detectors = failure_detectors + + async def listen(self, event: AsyncOnCommandsFailEvent) -> None: + for failure_detector in self._failure_detectors: + await failure_detector.register_failure(event.exception, event.commands) diff --git a/redis/asyncio/multidb/failover.py b/redis/asyncio/multidb/failover.py new file mode 100644 index 0000000000..5b9202111e --- /dev/null +++ b/redis/asyncio/multidb/failover.py @@ -0,0 +1,125 @@ +import time +from abc import ABC, abstractmethod + +from redis.asyncio.multidb.database import AsyncDatabase, Databases +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState +from redis.multidb.exception import ( + NoValidDatabaseException, + TemporaryUnavailableException, +) + +DEFAULT_FAILOVER_ATTEMPTS = 10 +DEFAULT_FAILOVER_DELAY = 12 + + +class AsyncFailoverStrategy(ABC): + @abstractmethod + async def database(self) -> AsyncDatabase: + """Select the database according to the strategy.""" + pass + + @abstractmethod + def set_databases(self, databases: Databases) -> None: + """Set the database strategy operates on.""" + pass + + +class FailoverStrategyExecutor(ABC): + @property + @abstractmethod + def failover_attempts(self) -> int: + """The number of failover attempts.""" + pass + + @property + @abstractmethod + def failover_delay(self) -> float: + """The delay between failover attempts.""" + pass + + @property + @abstractmethod + def strategy(self) -> AsyncFailoverStrategy: + """The strategy to execute.""" + pass + + @abstractmethod + async def execute(self) -> AsyncDatabase: + """Execute the failover strategy.""" + pass + + +class WeightBasedFailoverStrategy(AsyncFailoverStrategy): + """ + Failover strategy based on database weights. + """ + + def __init__(self): + self._databases = WeightedList() + + async def database(self) -> AsyncDatabase: + for database, _ in self._databases: + if database.circuit.state == CBState.CLOSED: + return database + + raise NoValidDatabaseException("No valid database available for communication") + + def set_databases(self, databases: Databases) -> None: + self._databases = databases + + +class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor): + """ + Executes given failover strategy. + """ + + def __init__( + self, + strategy: AsyncFailoverStrategy, + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, + ): + self._strategy = strategy + self._failover_attempts = failover_attempts + self._failover_delay = failover_delay + self._next_attempt_ts: int = 0 + self._failover_counter: int = 0 + + @property + def failover_attempts(self) -> int: + return self._failover_attempts + + @property + def failover_delay(self) -> float: + return self._failover_delay + + @property + def strategy(self) -> AsyncFailoverStrategy: + return self._strategy + + async def execute(self) -> AsyncDatabase: + try: + database = await self._strategy.database() + self._reset() + return database + except NoValidDatabaseException as e: + if self._next_attempt_ts == 0: + self._next_attempt_ts = time.time() + self._failover_delay + self._failover_counter += 1 + elif time.time() >= self._next_attempt_ts: + self._next_attempt_ts += self._failover_delay + self._failover_counter += 1 + + if self._failover_counter > self._failover_attempts: + self._reset() + raise e + else: + raise TemporaryUnavailableException( + "No database connections currently available. 
" + "This is a temporary condition - please retry the operation." + ) + + def _reset(self) -> None: + self._next_attempt_ts = 0 + self._failover_counter = 0 diff --git a/redis/asyncio/multidb/failure_detector.py b/redis/asyncio/multidb/failure_detector.py new file mode 100644 index 0000000000..9c6b61f591 --- /dev/null +++ b/redis/asyncio/multidb/failure_detector.py @@ -0,0 +1,38 @@ +from abc import ABC, abstractmethod + +from redis.multidb.failure_detector import FailureDetector + + +class AsyncFailureDetector(ABC): + @abstractmethod + async def register_failure(self, exception: Exception, cmd: tuple) -> None: + """Register a failure that occurred during command execution.""" + pass + + @abstractmethod + async def register_command_execution(self, cmd: tuple) -> None: + """Register a command execution.""" + pass + + @abstractmethod + def set_command_executor(self, command_executor) -> None: + """Set the command executor for this failure.""" + pass + + +class FailureDetectorAsyncWrapper(AsyncFailureDetector): + """ + Async wrapper for the failure detector. + """ + + def __init__(self, failure_detector: FailureDetector) -> None: + self._failure_detector = failure_detector + + async def register_failure(self, exception: Exception, cmd: tuple) -> None: + self._failure_detector.register_failure(exception, cmd) + + async def register_command_execution(self, cmd: tuple) -> None: + self._failure_detector.register_command_execution(cmd) + + def set_command_executor(self, command_executor) -> None: + self._failure_detector.set_command_executor(command_executor) diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py new file mode 100644 index 0000000000..dcb787f6ed --- /dev/null +++ b/redis/asyncio/multidb/healthcheck.py @@ -0,0 +1,292 @@ +import asyncio +import logging +from abc import ABC, abstractmethod +from enum import Enum +from typing import List, Optional, Tuple, Union + +from redis.asyncio import Redis +from redis.asyncio.http.http_client import DEFAULT_TIMEOUT, AsyncHTTPClientWrapper +from redis.backoff import NoBackoff +from redis.http.http_client import HttpClient +from redis.multidb.exception import UnhealthyDatabaseException +from redis.retry import Retry + +DEFAULT_HEALTH_CHECK_PROBES = 3 +DEFAULT_HEALTH_CHECK_INTERVAL = 5 +DEFAULT_HEALTH_CHECK_DELAY = 0.5 +DEFAULT_LAG_AWARE_TOLERANCE = 5000 + +logger = logging.getLogger(__name__) + + +class HealthCheck(ABC): + @abstractmethod + async def check_health(self, database) -> bool: + """Function to determine the health status.""" + pass + + +class HealthCheckPolicy(ABC): + """ + Health checks execution policy. 
+ """ + + @property + @abstractmethod + def health_check_probes(self) -> int: + """Number of probes to execute health checks.""" + pass + + @property + @abstractmethod + def health_check_delay(self) -> float: + """Delay between health check probes.""" + pass + + @abstractmethod + async def execute(self, health_checks: List[HealthCheck], database) -> bool: + """Execute health checks and return database health status.""" + pass + + +class AbstractHealthCheckPolicy(HealthCheckPolicy): + def __init__(self, health_check_probes: int, health_check_delay: float): + if health_check_probes < 1: + raise ValueError("health_check_probes must be greater than 0") + self._health_check_probes = health_check_probes + self._health_check_delay = health_check_delay + + @property + def health_check_probes(self) -> int: + return self._health_check_probes + + @property + def health_check_delay(self) -> float: + return self._health_check_delay + + @abstractmethod + async def execute(self, health_checks: List[HealthCheck], database) -> bool: + pass + + +class HealthyAllPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if all health check probes are successful. + """ + + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) + + async def execute(self, health_checks: List[HealthCheck], database) -> bool: + for health_check in health_checks: + for attempt in range(self.health_check_probes): + try: + if not await health_check.check_health(database): + return False + except Exception as e: + raise UnhealthyDatabaseException("Unhealthy database", database, e) + + if attempt < self.health_check_probes - 1: + await asyncio.sleep(self._health_check_delay) + return True + + +class HealthyMajorityPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if a majority of health check probes are successful. + """ + + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) + + async def execute(self, health_checks: List[HealthCheck], database) -> bool: + for health_check in health_checks: + if self.health_check_probes % 2 == 0: + allowed_unsuccessful_probes = self.health_check_probes / 2 + else: + allowed_unsuccessful_probes = (self.health_check_probes + 1) / 2 + + for attempt in range(self.health_check_probes): + try: + if not await health_check.check_health(database): + allowed_unsuccessful_probes -= 1 + if allowed_unsuccessful_probes <= 0: + return False + except Exception as e: + allowed_unsuccessful_probes -= 1 + if allowed_unsuccessful_probes <= 0: + raise UnhealthyDatabaseException( + "Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + await asyncio.sleep(self._health_check_delay) + return True + + +class HealthyAnyPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if at least one health check probe is successful. 
+ """ + + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) + + async def execute(self, health_checks: List[HealthCheck], database) -> bool: + is_healthy = False + + for health_check in health_checks: + exception = None + + for attempt in range(self.health_check_probes): + try: + if await health_check.check_health(database): + is_healthy = True + break + else: + is_healthy = False + except Exception as e: + exception = UnhealthyDatabaseException( + "Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + await asyncio.sleep(self._health_check_delay) + + if not is_healthy and not exception: + return is_healthy + elif not is_healthy and exception: + raise exception + + return is_healthy + + +class HealthCheckPolicies(Enum): + HEALTHY_ALL = HealthyAllPolicy + HEALTHY_MAJORITY = HealthyMajorityPolicy + HEALTHY_ANY = HealthyAnyPolicy + + +DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL + + +class EchoHealthCheck(HealthCheck): + """ + Health check based on ECHO command. + """ + + async def check_health(self, database) -> bool: + expected_message = ["healthcheck", b"healthcheck"] + + if isinstance(database.client, Redis): + actual_message = await database.client.execute_command( + "ECHO", "healthcheck" + ) + return actual_message in expected_message + else: + # For a cluster checks if all nodes are healthy. + all_nodes = database.client.get_nodes() + for node in all_nodes: + actual_message = await node.execute_command("ECHO", "healthcheck") + + if actual_message not in expected_message: + return False + + return True + + +class LagAwareHealthCheck(HealthCheck): + """ + Health check available for Redis Enterprise deployments. + Verify via REST API that the database is healthy based on different lags. + """ + + def __init__( + self, + rest_api_port: int = 9443, + lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE, + timeout: float = DEFAULT_TIMEOUT, + auth_basic: Optional[Tuple[str, str]] = None, + verify_tls: bool = True, + # TLS verification (server) options + ca_file: Optional[str] = None, + ca_path: Optional[str] = None, + ca_data: Optional[Union[str, bytes]] = None, + # Mutual TLS (client cert) options + client_cert_file: Optional[str] = None, + client_key_file: Optional[str] = None, + client_key_password: Optional[str] = None, + ): + """ + Initialize LagAwareHealthCheck with the specified parameters. 
+
+        Args:
+            rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
+            lag_aware_tolerance: Tolerance in lag between databases in MS (default: 5000)
+            timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
+            auth_basic: Tuple of (username, password) for basic authentication
+            verify_tls: Whether to verify TLS certificates (default: True)
+            ca_file: Path to CA certificate file for TLS verification
+            ca_path: Path to CA certificates directory for TLS verification
+            ca_data: CA certificate data as string or bytes
+            client_cert_file: Path to client certificate file for mutual TLS
+            client_key_file: Path to client private key file for mutual TLS
+            client_key_password: Password for encrypted client private key
+        """
+        self._http_client = AsyncHTTPClientWrapper(
+            HttpClient(
+                timeout=timeout,
+                auth_basic=auth_basic,
+                retry=Retry(NoBackoff(), retries=0),
+                verify_tls=verify_tls,
+                ca_file=ca_file,
+                ca_path=ca_path,
+                ca_data=ca_data,
+                client_cert_file=client_cert_file,
+                client_key_file=client_key_file,
+                client_key_password=client_key_password,
+            )
+        )
+        self._rest_api_port = rest_api_port
+        self._lag_aware_tolerance = lag_aware_tolerance
+
+    async def check_health(self, database) -> bool:
+        if database.health_check_url is None:
+            raise ValueError(
+                "Database health check URL is not set. Please check DatabaseConfig for the current database."
+            )
+
+        if isinstance(database.client, Redis):
+            db_host = database.client.get_connection_kwargs()["host"]
+        else:
+            db_host = database.client.startup_nodes[0].host
+
+        base_url = f"{database.health_check_url}:{self._rest_api_port}"
+        self._http_client.client.base_url = base_url
+
+        # Find the bdb matching the current database host
+        matching_bdb = None
+        for bdb in await self._http_client.get("/v1/bdbs"):
+            for endpoint in bdb["endpoints"]:
+                if endpoint["dns_name"] == db_host:
+                    matching_bdb = bdb
+                    break
+
+                # In case the host was set as a public IP
+                for addr in endpoint["addr"]:
+                    if addr == db_host:
+                        matching_bdb = bdb
+                        break
+
+        if matching_bdb is None:
+            logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
+            raise ValueError("Could not find a matching bdb")
+
+        url = (
+            f"/v1/bdbs/{matching_bdb['uid']}/availability"
+            f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
+        )
+        await self._http_client.get(url, expect_json=False)
+
+        # Status is validated by the HTTP client; a failing status raises HttpError
+        return True
diff --git a/redis/background.py b/redis/background.py
new file mode 100644
index 0000000000..7d0eead11e
--- /dev/null
+++ b/redis/background.py
@@ -0,0 +1,204 @@
+import asyncio
+import threading
+from typing import Any, Callable, Coroutine
+
+
+class BackgroundScheduler:
+    """
+    Schedules background task execution either in a separate thread or in the running event loop.
+    """
+
+    def __init__(self):
+        self._next_timer = None
+        self._event_loops = []
+        self._lock = threading.Lock()
+        self._stopped = False
+
+    def __del__(self):
+        self.stop()
+
+    def stop(self):
+        """
+        Stop all scheduled tasks and clean up resources.
+        """
+        with self._lock:
+            if self._stopped:
+                return
+            self._stopped = True
+
+            if self._next_timer:
+                self._next_timer.cancel()
+                self._next_timer = None
+
+            # Stop all event loops
+            for loop in self._event_loops:
+                if loop.is_running():
+                    loop.call_soon_threadsafe(loop.stop)
+
+            self._event_loops.clear()
+
+    def run_once(self, delay: float, callback: Callable, *args):
+        """
+        Runs a callable task once after a certain delay in seconds.
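+
+        Illustrative sketch (the callback and its argument are assumptions):
+
+            scheduler = BackgroundScheduler()
+            scheduler.run_once(5.0, print, "fired once, five seconds later")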
+ """ + with self._lock: + if self._stopped: + return + + # Run loop in a separate thread to unblock main thread. + loop = asyncio.new_event_loop() + + with self._lock: + self._event_loops.append(loop) + + thread = threading.Thread( + target=_start_event_loop_in_thread, + args=(loop, self._call_later, delay, callback, *args), + daemon=True, + ) + thread.start() + + def run_recurring(self, interval: float, callback: Callable, *args): + """ + Runs recurring callable task with given interval in seconds. + """ + with self._lock: + if self._stopped: + return + + # Run loop in a separate thread to unblock main thread. + loop = asyncio.new_event_loop() + + with self._lock: + self._event_loops.append(loop) + + thread = threading.Thread( + target=_start_event_loop_in_thread, + args=(loop, self._call_later_recurring, interval, callback, *args), + daemon=True, + ) + thread.start() + + async def run_recurring_async( + self, interval: float, coro: Callable[..., Coroutine[Any, Any, Any]], *args + ): + """ + Runs recurring coroutine with given interval in seconds in the current event loop. + To be used only from an async context. No additional threads are created. + """ + with self._lock: + if self._stopped: + return + + loop = asyncio.get_running_loop() + wrapped = _async_to_sync_wrapper(loop, coro, *args) + + def tick(): + with self._lock: + if self._stopped: + return + # Schedule the coroutine + wrapped() + # Schedule next tick + self._next_timer = loop.call_later(interval, tick) + + # Schedule first tick + self._next_timer = loop.call_later(interval, tick) + + def _call_later( + self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args + ): + with self._lock: + if self._stopped: + return + self._next_timer = loop.call_later(delay, callback, *args) + + def _call_later_recurring( + self, + loop: asyncio.AbstractEventLoop, + interval: float, + callback: Callable, + *args, + ): + with self._lock: + if self._stopped: + return + self._call_later( + loop, interval, self._execute_recurring, loop, interval, callback, *args + ) + + def _execute_recurring( + self, + loop: asyncio.AbstractEventLoop, + interval: float, + callback: Callable, + *args, + ): + """ + Executes recurring callable task with given interval in seconds. + """ + with self._lock: + if self._stopped: + return + + try: + callback(*args) + except Exception: + # Silently ignore exceptions during shutdown + pass + + with self._lock: + if self._stopped: + return + + self._call_later( + loop, interval, self._execute_recurring, loop, interval, callback, *args + ) + + +def _start_event_loop_in_thread( + event_loop: asyncio.AbstractEventLoop, call_soon_cb: Callable, *args +): + """ + Starts event loop in a thread and schedule callback as soon as event loop is ready. + Used to be able to schedule tasks using loop.call_later. + + :param event_loop: + :return: + """ + asyncio.set_event_loop(event_loop) + event_loop.call_soon(call_soon_cb, event_loop, *args) + try: + event_loop.run_forever() + finally: + try: + # Clean up pending tasks + pending = asyncio.all_tasks(event_loop) + for task in pending: + task.cancel() + # Run loop once more to process cancellations + event_loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True) + ) + except Exception: + pass + finally: + event_loop.close() + + +def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs): + """ + Wraps an asynchronous function so it can be used with loop.call_later. + + :param loop: The event loop in which the coroutine will be executed. 
+ :param coro_func: The coroutine function to wrap. + :param args: Positional arguments to pass to the coroutine function. + :param kwargs: Keyword arguments to pass to the coroutine function. + :return: A regular function suitable for loop.call_later. + """ + + def wrapped(): + # Schedule the coroutine in the event loop + asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop) + + return wrapped diff --git a/redis/client.py b/redis/client.py index cf4d77950f..a29d310742 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1271,6 +1271,8 @@ def run_in_thread( sleep_time: float = 0.0, daemon: bool = False, exception_handler: Optional[Callable] = None, + pubsub=None, + sharded_pubsub: bool = False, ) -> "PubSubWorkerThread": for channel, handler in self.channels.items(): if handler is None: @@ -1284,8 +1286,13 @@ def run_in_thread( f"Shard Channel: '{s_channel}' has no handler registered" ) + pubsub = self if pubsub is None else pubsub thread = PubSubWorkerThread( - self, sleep_time, daemon=daemon, exception_handler=exception_handler + pubsub, + sleep_time, + daemon=daemon, + exception_handler=exception_handler, + sharded_pubsub=sharded_pubsub, ) thread.start() return thread @@ -1300,12 +1307,14 @@ def __init__( exception_handler: Union[ Callable[[Exception, "PubSub", "PubSubWorkerThread"], None], None ] = None, + sharded_pubsub: bool = False, ): super().__init__() self.daemon = daemon self.pubsub = pubsub self.sleep_time = sleep_time self.exception_handler = exception_handler + self.sharded_pubsub = sharded_pubsub self._running = threading.Event() def run(self) -> None: @@ -1316,7 +1325,14 @@ def run(self) -> None: sleep_time = self.sleep_time while self._running.is_set(): try: - pubsub.get_message(ignore_subscribe_messages=True, timeout=sleep_time) + if not self.sharded_pubsub: + pubsub.get_message( + ignore_subscribe_messages=True, timeout=sleep_time + ) + else: + pubsub.get_sharded_message( + ignore_subscribe_messages=True, timeout=sleep_time + ) except BaseException as e: if self.exception_handler is None: raise diff --git a/redis/cluster.py b/redis/cluster.py index 839721edf1..1d4a3e0d0c 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -695,6 +695,7 @@ def __init__( self._event_dispatcher = EventDispatcher() else: self._event_dispatcher = event_dispatcher + self.startup_nodes = startup_nodes self.nodes_manager = NodesManager( startup_nodes=startup_nodes, from_url=from_url, @@ -3164,7 +3165,8 @@ def _reinitialize_on_error(self, error): self._nodes_manager.initialize() self.reinitialize_counter = 0 else: - self._nodes_manager.update_moved_exception(error) + if isinstance(error, AskError): + self._nodes_manager.update_moved_exception(error) self._executing = False diff --git a/redis/data_structure.py b/redis/data_structure.py new file mode 100644 index 0000000000..0571e223ad --- /dev/null +++ b/redis/data_structure.py @@ -0,0 +1,81 @@ +import threading +from typing import Any, Generic, List, TypeVar + +from redis.typing import Number + +T = TypeVar("T") + + +class WeightedList(Generic[T]): + """ + Thread-safe weighted list. 
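+
+    Items are kept sorted by weight in descending order. Illustrative sketch:
+
+        databases = WeightedList()
+        databases.add("replica", weight=0.5)
+        databases.add("primary", weight=1.0)
+        databases.get_top_n(1)     # [("primary", 1.0)]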
+ """ + + def __init__(self): + self._items: List[tuple[Any, Number]] = [] + self._lock = threading.RLock() + + def add(self, item: Any, weight: float) -> None: + """Add item with weight, maintaining sorted order""" + with self._lock: + # Find insertion point using binary search + left, right = 0, len(self._items) + while left < right: + mid = (left + right) // 2 + if self._items[mid][1] < weight: + right = mid + else: + left = mid + 1 + + self._items.insert(left, (item, weight)) + + def remove(self, item): + """Remove first occurrence of item""" + with self._lock: + for i, (stored_item, weight) in enumerate(self._items): + if stored_item == item: + self._items.pop(i) + return weight + raise ValueError("Item not found") + + def get_by_weight_range( + self, min_weight: float, max_weight: float + ) -> List[tuple[Any, Number]]: + """Get all items within weight range""" + with self._lock: + result = [] + for item, weight in self._items: + if min_weight <= weight <= max_weight: + result.append((item, weight)) + return result + + def get_top_n(self, n: int) -> List[tuple[Any, Number]]: + """Get top N the highest weighted items""" + with self._lock: + return [(item, weight) for item, weight in self._items[:n]] + + def update_weight(self, item, new_weight: float): + with self._lock: + """Update weight of an item""" + old_weight = self.remove(item) + self.add(item, new_weight) + return old_weight + + def __iter__(self): + """Iterate in descending weight order""" + with self._lock: + items_copy = ( + self._items.copy() + ) # Create snapshot as lock released after each 'yield' + + for item, weight in items_copy: + yield item, weight + + def __len__(self): + with self._lock: + return len(self._items) + + def __getitem__(self, index) -> tuple[Any, Number]: + with self._lock: + item, weight = self._items[index] + return item, weight diff --git a/redis/event.py b/redis/event.py index b86c66b082..03c72c6370 100644 --- a/redis/event.py +++ b/redis/event.py @@ -2,7 +2,7 @@ import threading from abc import ABC, abstractmethod from enum import Enum -from typing import List, Optional, Union +from typing import Dict, List, Optional, Type, Union from redis.auth.token import TokenInterface from redis.credentials import CredentialProvider, StreamingCredentialProvider @@ -42,6 +42,17 @@ def dispatch(self, event: object): async def dispatch_async(self, event: object): pass + @abstractmethod + def register_listeners( + self, + mappings: Dict[ + Type[object], + List[Union[EventListenerInterface, AsyncEventListenerInterface]], + ], + ): + """Register additional listeners.""" + pass + class EventException(Exception): """ @@ -56,11 +67,18 @@ def __init__(self, exception: Exception, event: object): class EventDispatcher(EventDispatcherInterface): # TODO: Make dispatcher to accept external mappings. - def __init__(self): + def __init__( + self, + event_listeners: Optional[ + Dict[Type[object], List[EventListenerInterface]] + ] = None, + ): """ - Mapping should be extended for any new events or listeners to be added. + Dispatcher that dispatches events to listeners associated with given event. 
""" - self._event_listeners_mapping = { + self._event_listeners_mapping: Dict[ + Type[object], List[EventListenerInterface] + ] = { AfterConnectionReleasedEvent: [ ReAuthConnectionListener(), ], @@ -77,17 +95,47 @@ def __init__(self): ], } + self._lock = threading.Lock() + self._async_lock = None + + if event_listeners: + self.register_listeners(event_listeners) + def dispatch(self, event: object): - listeners = self._event_listeners_mapping.get(type(event)) + with self._lock: + listeners = self._event_listeners_mapping.get(type(event), []) - for listener in listeners: - listener.listen(event) + for listener in listeners: + listener.listen(event) async def dispatch_async(self, event: object): - listeners = self._event_listeners_mapping.get(type(event)) + if self._async_lock is None: + self._async_lock = asyncio.Lock() + + async with self._async_lock: + listeners = self._event_listeners_mapping.get(type(event), []) - for listener in listeners: - await listener.listen(event) + for listener in listeners: + await listener.listen(event) + + def register_listeners( + self, + mappings: Dict[ + Type[object], + List[Union[EventListenerInterface, AsyncEventListenerInterface]], + ], + ): + with self._lock: + for event_type in mappings: + if event_type in self._event_listeners_mapping: + self._event_listeners_mapping[event_type] = list( + set( + self._event_listeners_mapping[event_type] + + mappings[event_type] + ) + ) + else: + self._event_listeners_mapping[event_type] = mappings[event_type] class AfterConnectionReleasedEvent: @@ -226,6 +274,32 @@ def credential_provider(self) -> Union[CredentialProvider, None]: return self._credential_provider +class OnCommandsFailEvent: + """ + Event fired whenever a command fails during the execution. + """ + + def __init__( + self, + commands: tuple, + exception: Exception, + ): + self._commands = commands + self._exception = exception + + @property + def commands(self) -> tuple: + return self._commands + + @property + def exception(self) -> Exception: + return self._exception + + +class AsyncOnCommandsFailEvent(OnCommandsFailEvent): + pass + + class ReAuthConnectionListener(EventListenerInterface): """ Listener that performs re-authentication of given connection. 
diff --git a/redis/http/__init__.py b/redis/http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/http/http_client.py b/redis/http/http_client.py new file mode 100644 index 0000000000..7d9d5c4ad4 --- /dev/null +++ b/redis/http/http_client.py @@ -0,0 +1,425 @@ +from __future__ import annotations + +import base64 +import gzip +import json +import ssl +import zlib +from dataclasses import dataclass +from typing import Any, Dict, Mapping, Optional, Tuple, Union +from urllib.error import HTTPError, URLError +from urllib.parse import urlencode, urljoin +from urllib.request import Request, urlopen + +__all__ = ["HttpClient", "HttpResponse", "HttpError", "DEFAULT_TIMEOUT"] + +from redis.backoff import ExponentialWithJitterBackoff +from redis.retry import Retry +from redis.utils import dummy_fail + +DEFAULT_USER_AGENT = "HttpClient/1.0 (+https://example.invalid)" +DEFAULT_TIMEOUT = 30.0 +RETRY_STATUS_CODES = {429, 500, 502, 503, 504} + + +@dataclass +class HttpResponse: + status: int + headers: Dict[str, str] + url: str + content: bytes + + def text(self, encoding: Optional[str] = None) -> str: + enc = encoding or self._get_encoding() + return self.content.decode(enc, errors="replace") + + def json(self) -> Any: + return json.loads(self.text(encoding=self._get_encoding())) + + def _get_encoding(self) -> str: + # Try to infer encoding from headers; default to utf-8 + ctype = self.headers.get("content-type", "") + # Example: application/json; charset=utf-8 + for part in ctype.split(";"): + p = part.strip() + if p.lower().startswith("charset="): + return p.split("=", 1)[1].strip() or "utf-8" + return "utf-8" + + +class HttpError(Exception): + def __init__(self, status: int, url: str, message: Optional[str] = None): + self.status = status + self.url = url + self.message = message or f"HTTP {status} for {url}" + super().__init__(self.message) + + +class HttpClient: + """ + A lightweight HTTP client for REST API calls. + """ + + def __init__( + self, + base_url: str = "", + headers: Optional[Mapping[str, str]] = None, + timeout: float = DEFAULT_TIMEOUT, + retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ), + verify_tls: bool = True, + # TLS verification (server) options + ca_file: Optional[str] = None, + ca_path: Optional[str] = None, + ca_data: Optional[Union[str, bytes]] = None, + # Mutual TLS (client cert) options + client_cert_file: Optional[str] = None, + client_key_file: Optional[str] = None, + client_key_password: Optional[str] = None, + auth_basic: Optional[Tuple[str, str]] = None, # (username, password) + user_agent: str = DEFAULT_USER_AGENT, + ) -> None: + """ + Initialize a new HTTP client instance. + + Args: + base_url: Base URL for all requests. Will be prefixed to all paths. + headers: Default headers to include in all requests. + timeout: Default timeout in seconds for requests. + retry: Retry configuration for failed requests. + verify_tls: Whether to verify TLS certificates. + ca_file: Path to CA certificate file for TLS verification. + ca_path: Path to a directory containing CA certificates. + ca_data: CA certificate data as string or bytes. + client_cert_file: Path to client certificate for mutual TLS. + client_key_file: Path to a client private key for mutual TLS. + client_key_password: Password for an encrypted client private key. + auth_basic: Tuple of (username, password) for HTTP basic auth. + user_agent: User-Agent header value for requests. 
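+
+        A minimal usage sketch (the URL and credentials are illustrative
+        assumptions):
+
+            client = HttpClient(
+                base_url="https://cluster.example.invalid:9443/",
+                auth_basic=("admin", "secret"),
+                timeout=10.0,
+            )
+            bdbs = client.get("/v1/bdbs")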
+ + The client supports both regular HTTPS with server verification and mutual TLS + authentication. For server verification, provide CA certificate information via + ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client + certificate and key via client_cert_file and client_key_file. + """ + self.base_url = ( + base_url.rstrip() + "/" + if base_url and not base_url.endswith("/") + else base_url + ) + self._default_headers = {k.lower(): v for k, v in (headers or {}).items()} + self.timeout = timeout + self.retry = retry + self.retry.update_supported_errors((HTTPError, URLError, ssl.SSLError)) + self.verify_tls = verify_tls + + # TLS settings + self.ca_file = ca_file + self.ca_path = ca_path + self.ca_data = ca_data + self.client_cert_file = client_cert_file + self.client_key_file = client_key_file + self.client_key_password = client_key_password + + self.auth_basic = auth_basic + self.user_agent = user_agent + + # Public JSON-centric helpers + def get( + self, + path: str, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + return self._json_call( + "GET", + path, + params=params, + headers=headers, + timeout=timeout, + body=None, + expect_json=expect_json, + ) + + def delete( + self, + path: str, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + return self._json_call( + "DELETE", + path, + params=params, + headers=headers, + timeout=timeout, + body=None, + expect_json=expect_json, + ) + + def post( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + return self._json_call( + "POST", + path, + params=params, + headers=headers, + timeout=timeout, + body=self._prepare_body(json_body=json_body, data=data), + expect_json=expect_json, + ) + + def put( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + return self._json_call( + "PUT", + path, + params=params, + headers=headers, + timeout=timeout, + body=self._prepare_body(json_body=json_body, data=data), + expect_json=expect_json, + ) + + def patch( + self, + path: str, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + return self._json_call( + "PATCH", + path, + params=params, + headers=headers, + timeout=timeout, + body=self._prepare_body(json_body=json_body, data=data), + expect_json=expect_json, + ) + + # Low-level request + def request( + self, + method: str, + path: str, + params: Optional[ + Mapping[str, 
Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + body: Optional[Union[bytes, str]] = None, + timeout: Optional[float] = None, + ) -> HttpResponse: + url = self._build_url(path, params) + all_headers = self._prepare_headers(headers, body) + data = body.encode("utf-8") if isinstance(body, str) else body + + req = Request(url=url, method=method.upper(), data=data, headers=all_headers) + + context: Optional[ssl.SSLContext] = None + if url.lower().startswith("https"): + if self.verify_tls: + # Use provided CA material if any; fall back to system defaults + context = ssl.create_default_context( + cafile=self.ca_file, + capath=self.ca_path, + cadata=self.ca_data, + ) + # Load client certificate for mTLS if configured + if self.client_cert_file: + context.load_cert_chain( + certfile=self.client_cert_file, + keyfile=self.client_key_file, + password=self.client_key_password, + ) + else: + # Verification disabled + context = ssl.create_default_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + try: + return self.retry.call_with_retry( + lambda: self._make_request(req, context=context, timeout=timeout), + lambda _: dummy_fail(), + lambda error: self._is_retryable_http_error(error), + ) + except HTTPError as e: + # Read error body, build response, and decide on retry + err_body = b"" + try: + err_body = e.read() + except Exception: + pass + headers_map = {k.lower(): v for k, v in (e.headers or {}).items()} + err_body = self._maybe_decompress(err_body, headers_map) + status = getattr(e, "code", 0) or 0 + response = HttpResponse( + status=status, + headers=headers_map, + url=url, + content=err_body, + ) + return response + + def _make_request( + self, + request: Request, + context: Optional[ssl.SSLContext] = None, + timeout: Optional[float] = None, + ): + with urlopen(request, timeout=timeout or self.timeout, context=context) as resp: + raw = resp.read() + headers_map = {k.lower(): v for k, v in resp.headers.items()} + raw = self._maybe_decompress(raw, headers_map) + return HttpResponse( + status=resp.status, + headers=headers_map, + url=resp.geturl(), + content=raw, + ) + + def _is_retryable_http_error(self, error: Exception) -> bool: + if isinstance(error, HTTPError): + return self._should_retry_status(error.code) + return False + + # Internal utilities + def _json_call( + self, + method: str, + path: str, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + body: Optional[Union[bytes, str]] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + resp = self.request( + method=method, + path=path, + params=params, + headers=headers, + body=body, + timeout=timeout, + ) + if not (200 <= resp.status < 400): + raise HttpError(resp.status, resp.url, resp.text()) + if expect_json: + return resp.json() + return resp + + def _prepare_body( + self, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None + ) -> Optional[Union[bytes, str]]: + if json_body is not None and data is not None: + raise ValueError("Provide either json_body or data, not both.") + if json_body is not None: + return json.dumps(json_body, ensure_ascii=False, separators=(",", ":")) + return data + + def _build_url( + self, + path: str, + params: Optional[ + Mapping[str, Union[None, str, int, float, bool, list, tuple]] + ] = None, + ) -> str: + url = urljoin(self.base_url or "", path) + if 
params:
+            # urlencode with doseq=True supports list/tuple values
+            query = urlencode(
+                {k: v for k, v in params.items() if v is not None}, doseq=True
+            )
+            separator = "&" if ("?" in url) else "?"
+            url = f"{url}{separator}{query}" if query else url
+        return url
+
+    def _prepare_headers(
+        self, headers: Optional[Mapping[str, str]], body: Optional[Union[bytes, str]]
+    ) -> Dict[str, str]:
+        # Start with defaults
+        prepared: Dict[str, str] = {}
+        prepared.update(self._default_headers)
+
+        # Standard defaults for JSON REST usage
+        prepared.setdefault("accept", "application/json")
+        prepared.setdefault("user-agent", self.user_agent)
+        # We will send gzip accept-encoding; handle decompression manually
+        prepared.setdefault("accept-encoding", "gzip, deflate")
+
+        # If we have a string body and content-type not specified, assume JSON
+        if body is not None and isinstance(body, str):
+            prepared.setdefault("content-type", "application/json; charset=utf-8")
+
+        # Basic authentication if provided and not overridden
+        if self.auth_basic and "authorization" not in prepared:
+            user, pwd = self.auth_basic
+            token = base64.b64encode(f"{user}:{pwd}".encode("utf-8")).decode("ascii")
+            prepared["authorization"] = f"Basic {token}"
+
+        # Merge per-call headers (case-insensitive)
+        if headers:
+            for k, v in headers.items():
+                prepared[k.lower()] = v
+
+        # urllib is tolerant of header-key capitalization; return keys as prepared.
+        return prepared
+
+    def _should_retry_status(self, status: int) -> bool:
+        return status in RETRY_STATUS_CODES
+
+    def _maybe_decompress(self, content: bytes, headers: Mapping[str, str]) -> bytes:
+        if not content:
+            return content
+        encoding = (headers.get("content-encoding") or "").lower()
+        try:
+            if "gzip" in encoding:
+                return gzip.decompress(content)
+            if "deflate" in encoding:
+                # Try raw deflate, then zlib-wrapped
+                try:
+                    return zlib.decompress(content, -zlib.MAX_WBITS)
+                except zlib.error:
+                    return zlib.decompress(content)
+        except Exception:
+            # If decompression fails, return original bytes
+            return content
+        return content
diff --git a/redis/multidb/__init__.py b/redis/multidb/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py
new file mode 100644
index 0000000000..8af6cc32de
--- /dev/null
+++ b/redis/multidb/circuit.py
@@ -0,0 +1,144 @@
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Callable
+
+import pybreaker
+
+DEFAULT_GRACE_PERIOD = 60
+
+
+class State(Enum):
+    CLOSED = "closed"
+    OPEN = "open"
+    HALF_OPEN = "half-open"
+
+
+class CircuitBreaker(ABC):
+    @property
+    @abstractmethod
+    def grace_period(self) -> float:
+        """The grace period in seconds during which the circuit should be kept open."""
+        pass
+
+    @grace_period.setter
+    @abstractmethod
+    def grace_period(self, grace_period: float):
+        """Set the grace period in seconds."""
+
+    @property
+    @abstractmethod
+    def state(self) -> State:
+        """The current state of the circuit."""
+        pass
+
+    @state.setter
+    @abstractmethod
+    def state(self, state: State):
+        """Set the current state of the circuit."""
+        pass
+
+    @property
+    @abstractmethod
+    def database(self):
+        """Database associated with this circuit."""
+        pass
+
+    @database.setter
+    @abstractmethod
+    def database(self, database):
+        """Set the database associated with this circuit."""
+        pass
+
+    @abstractmethod
+    def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]):
+        """Callback
called when the state of the circuit changes.""" + pass + + +class BaseCircuitBreaker(CircuitBreaker): + """ + Base implementation of Circuit Breaker interface. + """ + + def __init__(self, cb: pybreaker.CircuitBreaker): + self._cb = cb + self._state_pb_mapper = { + State.CLOSED: self._cb.close, + State.OPEN: self._cb.open, + State.HALF_OPEN: self._cb.half_open, + } + self._database = None + + @property + def grace_period(self) -> float: + return self._cb.reset_timeout + + @grace_period.setter + def grace_period(self, grace_period: float): + self._cb.reset_timeout = grace_period + + @property + def state(self) -> State: + return State(value=self._cb.state.name) + + @state.setter + def state(self, state: State): + self._state_pb_mapper[state]() + + @property + def database(self): + return self._database + + @database.setter + def database(self, database): + self._database = database + + @abstractmethod + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + """Callback called when the state of the circuit changes.""" + pass + + +class PBListener(pybreaker.CircuitBreakerListener): + """Wrapper for callback to be compatible with pybreaker implementation.""" + + def __init__( + self, + cb: Callable[[CircuitBreaker, State, State], None], + database, + ): + """ + Initialize a PBListener instance. + + Args: + cb: Callback function that will be called when the circuit breaker state changes. + database: Database instance associated with this circuit breaker. + """ + + self._cb = cb + self._database = database + + def state_change(self, cb, old_state, new_state): + cb = PBCircuitBreakerAdapter(cb) + cb.database = self._database + old_state = State(value=old_state.name) + new_state = State(value=new_state.name) + self._cb(cb, old_state, new_state) + + +class PBCircuitBreakerAdapter(BaseCircuitBreaker): + def __init__(self, cb: pybreaker.CircuitBreaker): + """ + Initialize a PBCircuitBreakerAdapter instance. + + This adapter wraps pybreaker's CircuitBreaker implementation to make it compatible + with our CircuitBreaker interface. + + Args: + cb: A pybreaker CircuitBreaker instance to be adapted. 
+ """ + super().__init__(cb) + + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + listener = PBListener(cb, self.database) + self._cb.add_listener(listener) diff --git a/redis/multidb/client.py b/redis/multidb/client.py new file mode 100644 index 0000000000..485174fc03 --- /dev/null +++ b/redis/multidb/client.py @@ -0,0 +1,524 @@ +import logging +import threading +from concurrent.futures import as_completed +from concurrent.futures.thread import ThreadPoolExecutor +from typing import Any, Callable, List, Optional + +from redis.background import BackgroundScheduler +from redis.client import PubSubWorkerThread +from redis.commands import CoreCommands, RedisModuleCommands +from redis.multidb.circuit import CircuitBreaker +from redis.multidb.circuit import State as CBState +from redis.multidb.command_executor import DefaultCommandExecutor +from redis.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig +from redis.multidb.database import Database, Databases, SyncDatabase +from redis.multidb.exception import NoValidDatabaseException, UnhealthyDatabaseException +from redis.multidb.failure_detector import FailureDetector +from redis.multidb.healthcheck import HealthCheck, HealthCheckPolicy + +logger = logging.getLogger(__name__) + + +class MultiDBClient(RedisModuleCommands, CoreCommands): + """ + Client that operates on multiple logical Redis databases. + Should be used in Active-Active database setups. + """ + + def __init__(self, config: MultiDbConfig): + self._databases = config.databases() + self._health_checks = config.default_health_checks() + + if config.health_checks is not None: + self._health_checks.extend(config.health_checks) + + self._health_check_interval = config.health_check_interval + self._health_check_policy: HealthCheckPolicy = config.health_check_policy.value( + config.health_check_probes, config.health_check_delay + ) + self._failure_detectors = config.default_failure_detectors() + + if config.failure_detectors is not None: + self._failure_detectors.extend(config.failure_detectors) + + self._failover_strategy = ( + config.default_failover_strategy() + if config.failover_strategy is None + else config.failover_strategy + ) + self._failover_strategy.set_databases(self._databases) + self._auto_fallback_interval = config.auto_fallback_interval + self._event_dispatcher = config.event_dispatcher + self._command_retry = config.command_retry + self._command_retry.update_supported_errors((ConnectionRefusedError,)) + self.command_executor = DefaultCommandExecutor( + failure_detectors=self._failure_detectors, + databases=self._databases, + command_retry=self._command_retry, + failover_strategy=self._failover_strategy, + failover_attempts=config.failover_attempts, + failover_delay=config.failover_delay, + event_dispatcher=self._event_dispatcher, + auto_fallback_interval=self._auto_fallback_interval, + ) + self.initialized = False + self._hc_lock = threading.RLock() + self._bg_scheduler = BackgroundScheduler() + self._config = config + + def initialize(self): + """ + Perform initialization of databases to define their initial state. + """ + + def raise_exception_on_failed_hc(error): + raise error + + # Initial databases check to define initial state + self._check_databases_health(on_error=raise_exception_on_failed_hc) + + # Starts recurring health checks on the background. 
+ self._bg_scheduler.run_recurring( + self._health_check_interval, + self._check_databases_health, + ) + + is_active_db_found = False + + for database, weight in self._databases: + # Set on state changed callback for each circuit. + database.circuit.on_state_changed(self._on_circuit_state_change_callback) + + # Set states according to a weights and circuit state + if database.circuit.state == CBState.CLOSED and not is_active_db_found: + self.command_executor.active_database = database + is_active_db_found = True + + if not is_active_db_found: + raise NoValidDatabaseException( + "Initial connection failed - no active database found" + ) + + self.initialized = True + + def get_databases(self) -> Databases: + """ + Returns a sorted (by weight) list of all databases. + """ + return self._databases + + def set_active_database(self, database: SyncDatabase) -> None: + """ + Promote one of the existing databases to become an active. + """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError("Given database is not a member of database list") + + self._check_db_health(database) + + if database.circuit.state == CBState.CLOSED: + highest_weighted_db, _ = self._databases.get_top_n(1)[0] + self.command_executor.active_database = database + return + + raise NoValidDatabaseException( + "Cannot set active database, database is unhealthy" + ) + + def add_database(self, database: SyncDatabase): + """ + Adds a new database to the database list. + """ + for existing_db, _ in self._databases: + if existing_db == database: + raise ValueError("Given database already exists") + + self._check_db_health(database) + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.add(database, database.weight) + self._change_active_database(database, highest_weighted_db) + + def _change_active_database( + self, new_database: SyncDatabase, highest_weight_database: SyncDatabase + ): + if ( + new_database.weight > highest_weight_database.weight + and new_database.circuit.state == CBState.CLOSED + ): + self.command_executor.active_database = new_database + + def remove_database(self, database: Database): + """ + Removes a database from the database list. + """ + weight = self._databases.remove(database) + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + + if ( + highest_weight <= weight + and highest_weighted_db.circuit.state == CBState.CLOSED + ): + self.command_executor.active_database = highest_weighted_db + + def update_database_weight(self, database: SyncDatabase, weight: float): + """ + Updates a database from the database list. + """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError("Given database is not a member of database list") + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.update_weight(database, weight) + database.weight = weight + self._change_active_database(database, highest_weighted_db) + + def add_failure_detector(self, failure_detector: FailureDetector): + """ + Adds a new failure detector to the database. + """ + self._failure_detectors.append(failure_detector) + + def add_health_check(self, healthcheck: HealthCheck): + """ + Adds a new health check to the database. 
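+
+        Illustrative sketch (``client`` is an assumed MultiDBClient instance):
+
+            from redis.multidb.healthcheck import EchoHealthCheck
+
+            client.add_health_check(EchoHealthCheck())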
+ """ + with self._hc_lock: + self._health_checks.append(healthcheck) + + def execute_command(self, *args, **options): + """ + Executes a single command and return its result. + """ + if not self.initialized: + self.initialize() + + return self.command_executor.execute_command(*args, **options) + + def pipeline(self): + """ + Enters into pipeline mode of the client. + """ + return Pipeline(self) + + def transaction(self, func: Callable[["Pipeline"], None], *watches, **options): + """ + Executes callable as transaction. + """ + if not self.initialized: + self.initialize() + + return self.command_executor.execute_transaction(func, *watches, *options) + + def pubsub(self, **kwargs): + """ + Return a Publish/Subscribe object. With this object, you can + subscribe to channels and listen for messages that get published to + them. + """ + if not self.initialized: + self.initialize() + + return PubSub(self, **kwargs) + + def _check_db_health(self, database: SyncDatabase) -> bool: + """ + Runs health checks on the given database until first failure. + """ + # Health check will setup circuit state + is_healthy = self._health_check_policy.execute(self._health_checks, database) + + if not is_healthy: + if database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + return is_healthy + elif is_healthy and database.circuit.state != CBState.CLOSED: + database.circuit.state = CBState.CLOSED + + return is_healthy + + def _check_databases_health(self, on_error: Callable[[Exception], None] = None): + """ + Runs health checks as a recurring task. + Runs health checks against all databases. + """ + with ThreadPoolExecutor(max_workers=len(self._databases)) as executor: + # Submit all health checks + futures = { + executor.submit(self._check_db_health, database) + for database, _ in self._databases + } + + try: + for future in as_completed( + futures, timeout=self._health_check_interval + ): + try: + future.result() + except UnhealthyDatabaseException as e: + unhealthy_db = e.database + unhealthy_db.circuit.state = CBState.OPEN + + logger.exception( + "Health check failed, due to exception", + exc_info=e.original_exception, + ) + + if on_error: + on_error(e.original_exception) + except TimeoutError: + raise TimeoutError( + "Health check execution exceeds health_check_interval" + ) + + def _on_circuit_state_change_callback( + self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState + ): + if new_state == CBState.HALF_OPEN: + self._check_db_health(circuit.database) + return + + if old_state == CBState.CLOSED and new_state == CBState.OPEN: + self._bg_scheduler.run_once( + DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit + ) + + def close(self): + self.command_executor.active_database.client.close() + + +def _half_open_circuit(circuit: CircuitBreaker): + circuit.state = CBState.HALF_OPEN + + +class Pipeline(RedisModuleCommands, CoreCommands): + """ + Pipeline implementation for multiple logical Redis databases. 
+ """ + + def __init__(self, client: MultiDBClient): + self._command_stack = [] + self._client = client + + def __enter__(self) -> "Pipeline": + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.reset() + + def __del__(self): + try: + self.reset() + except Exception: + pass + + def __len__(self) -> int: + return len(self._command_stack) + + def __bool__(self) -> bool: + """Pipeline instances should always evaluate to True""" + return True + + def reset(self) -> None: + self._command_stack = [] + + def close(self) -> None: + """Close the pipeline""" + self.reset() + + def pipeline_execute_command(self, *args, **options) -> "Pipeline": + """ + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. + """ + self._command_stack.append((args, options)) + return self + + def execute_command(self, *args, **kwargs): + """Adds a command to the stack""" + return self.pipeline_execute_command(*args, **kwargs) + + def execute(self) -> List[Any]: + """Execute all the commands in the current pipeline""" + if not self._client.initialized: + self._client.initialize() + + try: + return self._client.command_executor.execute_pipeline( + tuple(self._command_stack) + ) + finally: + self.reset() + + +class PubSub: + """ + PubSub object for multi database client. + """ + + def __init__(self, client: MultiDBClient, **kwargs): + """Initialize the PubSub object for a multi-database client. + + Args: + client: MultiDBClient instance to use for pub/sub operations + **kwargs: Additional keyword arguments to pass to the underlying pubsub implementation + """ + + self._client = client + self._client.command_executor.pubsub(**kwargs) + + def __enter__(self) -> "PubSub": + return self + + def __del__(self) -> None: + try: + # if this object went out of scope prior to shutting down + # subscriptions, close the connection manually before + # returning it to the connection pool + self.reset() + except Exception: + pass + + def reset(self) -> None: + return self._client.command_executor.execute_pubsub_method("reset") + + def close(self) -> None: + self.reset() + + @property + def subscribed(self) -> bool: + return self._client.command_executor.active_pubsub.subscribed + + def execute_command(self, *args): + return self._client.command_executor.execute_pubsub_method( + "execute_command", *args + ) + + def psubscribe(self, *args, **kwargs): + """ + Subscribe to channel patterns. Patterns supplied as keyword arguments + expect a pattern name as the key and a callable as the value. A + pattern's callable will be invoked automatically when a message is + received on that pattern rather than producing a message via + ``listen()``. + """ + return self._client.command_executor.execute_pubsub_method( + "psubscribe", *args, **kwargs + ) + + def punsubscribe(self, *args): + """ + Unsubscribe from the supplied patterns. If empty, unsubscribe from + all patterns. + """ + return self._client.command_executor.execute_pubsub_method( + "punsubscribe", *args + ) + + def subscribe(self, *args, **kwargs): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. 
A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. + """ + return self._client.command_executor.execute_pubsub_method( + "subscribe", *args, **kwargs + ) + + def unsubscribe(self, *args): + """ + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels + """ + return self._client.command_executor.execute_pubsub_method("unsubscribe", *args) + + def ssubscribe(self, *args, **kwargs): + """ + Subscribes the client to the specified shard channels. + Channels supplied as keyword arguments expect a channel name as the key + and a callable as the value. A channel's callable will be invoked automatically + when a message is received on that channel rather than producing a message via + ``listen()`` or ``get_sharded_message()``. + """ + return self._client.command_executor.execute_pubsub_method( + "ssubscribe", *args, **kwargs + ) + + def sunsubscribe(self, *args): + """ + Unsubscribe from the supplied shard_channels. If empty, unsubscribe from + all shard_channels + """ + return self._client.command_executor.execute_pubsub_method( + "sunsubscribe", *args + ) + + def get_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number, or None, to wait indefinitely. + """ + return self._client.command_executor.execute_pubsub_method( + "get_message", + ignore_subscribe_messages=ignore_subscribe_messages, + timeout=timeout, + ) + + def get_sharded_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): + """ + Get the next message if one is available in a sharded channel, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number, or None, to wait indefinitely. 
+ """ + return self._client.command_executor.execute_pubsub_method( + "get_sharded_message", + ignore_subscribe_messages=ignore_subscribe_messages, + timeout=timeout, + ) + + def run_in_thread( + self, + sleep_time: float = 0.0, + daemon: bool = False, + exception_handler: Optional[Callable] = None, + sharded_pubsub: bool = False, + ) -> "PubSubWorkerThread": + return self._client.command_executor.execute_pubsub_run( + sleep_time, + daemon=daemon, + exception_handler=exception_handler, + pubsub=self, + sharded_pubsub=sharded_pubsub, + ) diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py new file mode 100644 index 0000000000..f8e6171bc8 --- /dev/null +++ b/redis/multidb/command_executor.py @@ -0,0 +1,350 @@ +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import Any, Callable, List, Optional + +from redis.client import Pipeline, PubSub, PubSubWorkerThread +from redis.event import EventDispatcherInterface, OnCommandsFailEvent +from redis.multidb.circuit import State as CBState +from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.multidb.database import Database, Databases, SyncDatabase +from redis.multidb.event import ( + ActiveDatabaseChanged, + CloseConnectionOnActiveDatabaseChanged, + RegisterCommandFailure, + ResubscribeOnActiveDatabaseChanged, +) +from redis.multidb.failover import ( + DEFAULT_FAILOVER_ATTEMPTS, + DEFAULT_FAILOVER_DELAY, + DefaultFailoverStrategyExecutor, + FailoverStrategy, + FailoverStrategyExecutor, +) +from redis.multidb.failure_detector import FailureDetector +from redis.retry import Retry + + +class CommandExecutor(ABC): + @property + @abstractmethod + def auto_fallback_interval(self) -> float: + """Returns auto-fallback interval.""" + pass + + @auto_fallback_interval.setter + @abstractmethod + def auto_fallback_interval(self, auto_fallback_interval: float) -> None: + """Sets auto-fallback interval.""" + pass + + +class BaseCommandExecutor(CommandExecutor): + def __init__( + self, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + ): + self._auto_fallback_interval = auto_fallback_interval + self._next_fallback_attempt: datetime + + @property + def auto_fallback_interval(self) -> float: + return self._auto_fallback_interval + + @auto_fallback_interval.setter + def auto_fallback_interval(self, auto_fallback_interval: int) -> None: + self._auto_fallback_interval = auto_fallback_interval + + def _schedule_next_fallback(self) -> None: + if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: + return + + self._next_fallback_attempt = datetime.now() + timedelta( + seconds=self._auto_fallback_interval + ) + + +class SyncCommandExecutor(CommandExecutor): + @property + @abstractmethod + def databases(self) -> Databases: + """Returns a list of databases.""" + pass + + @property + @abstractmethod + def failure_detectors(self) -> List[FailureDetector]: + """Returns a list of failure detectors.""" + pass + + @abstractmethod + def add_failure_detector(self, failure_detector: FailureDetector) -> None: + """Adds a new failure detector to the list of failure detectors.""" + pass + + @property + @abstractmethod + def active_database(self) -> Optional[Database]: + """Returns currently active database.""" + pass + + @active_database.setter + @abstractmethod + def active_database(self, database: SyncDatabase) -> None: + """Sets the currently active database.""" + pass + + @property + @abstractmethod + def active_pubsub(self) -> Optional[PubSub]: + """Returns 
currently active pubsub.""" + pass + + @active_pubsub.setter + @abstractmethod + def active_pubsub(self, pubsub: PubSub) -> None: + """Sets currently active pubsub.""" + pass + + @property + @abstractmethod + def failover_strategy_executor(self) -> FailoverStrategyExecutor: + """Returns failover strategy executor.""" + pass + + @property + @abstractmethod + def command_retry(self) -> Retry: + """Returns command retry object.""" + pass + + @abstractmethod + def pubsub(self, **kwargs): + """Initializes a PubSub object on a currently active database""" + pass + + @abstractmethod + def execute_command(self, *args, **options): + """Executes a command and returns the result.""" + pass + + @abstractmethod + def execute_pipeline(self, command_stack: tuple): + """Executes a stack of commands in pipeline.""" + pass + + @abstractmethod + def execute_transaction( + self, transaction: Callable[[Pipeline], None], *watches, **options + ): + """Executes a transaction block wrapped in callback.""" + pass + + @abstractmethod + def execute_pubsub_method(self, method_name: str, *args, **kwargs): + """Executes a given method on active pub/sub.""" + pass + + @abstractmethod + def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + """Executes pub/sub run in a thread.""" + pass + + +class DefaultCommandExecutor(SyncCommandExecutor, BaseCommandExecutor): + def __init__( + self, + failure_detectors: List[FailureDetector], + databases: Databases, + command_retry: Retry, + failover_strategy: FailoverStrategy, + event_dispatcher: EventDispatcherInterface, + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + ): + """ + Initialize the DefaultCommandExecutor instance. 
+ + Args: + failure_detectors: List of failure detector instances to monitor database health + databases: Collection of available databases to execute commands on + command_retry: Retry policy for failed command execution + failover_strategy: Strategy for handling database failover + event_dispatcher: Interface for dispatching events + failover_attempts: Number of failover attempts + failover_delay: Delay between failover attempts + auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database + """ + super().__init__(auto_fallback_interval) + + for fd in failure_detectors: + fd.set_command_executor(command_executor=self) + + self._databases = databases + self._failure_detectors = failure_detectors + self._command_retry = command_retry + self._failover_strategy_executor = DefaultFailoverStrategyExecutor( + failover_strategy, failover_attempts, failover_delay + ) + self._event_dispatcher = event_dispatcher + self._active_database: Optional[Database] = None + self._active_pubsub: Optional[PubSub] = None + self._active_pubsub_kwargs = {} + self._setup_event_dispatcher() + self._schedule_next_fallback() + + @property + def databases(self) -> Databases: + return self._databases + + @property + def failure_detectors(self) -> List[FailureDetector]: + return self._failure_detectors + + def add_failure_detector(self, failure_detector: FailureDetector) -> None: + self._failure_detectors.append(failure_detector) + + @property + def command_retry(self) -> Retry: + return self._command_retry + + @property + def active_database(self) -> Optional[SyncDatabase]: + return self._active_database + + @active_database.setter + def active_database(self, database: SyncDatabase) -> None: + old_active = self._active_database + self._active_database = database + + if old_active is not None and old_active is not database: + self._event_dispatcher.dispatch( + ActiveDatabaseChanged( + old_active, + self._active_database, + self, + **self._active_pubsub_kwargs, + ) + ) + + @property + def active_pubsub(self) -> Optional[PubSub]: + return self._active_pubsub + + @active_pubsub.setter + def active_pubsub(self, pubsub: PubSub) -> None: + self._active_pubsub = pubsub + + @property + def failover_strategy_executor(self) -> FailoverStrategyExecutor: + return self._failover_strategy_executor + + def execute_command(self, *args, **options): + def callback(): + response = self._active_database.client.execute_command(*args, **options) + self._register_command_execution(args) + return response + + return self._execute_with_failure_detection(callback, args) + + def execute_pipeline(self, command_stack: tuple): + def callback(): + with self._active_database.client.pipeline() as pipe: + for command, options in command_stack: + pipe.execute_command(*command, **options) + + response = pipe.execute() + self._register_command_execution(command_stack) + return response + + return self._execute_with_failure_detection(callback, command_stack) + + def execute_transaction( + self, transaction: Callable[[Pipeline], None], *watches, **options + ): + def callback(): + response = self._active_database.client.transaction( + transaction, *watches, **options + ) + self._register_command_execution(()) + return response + + return self._execute_with_failure_detection(callback) + + def pubsub(self, **kwargs): + def callback(): + if self._active_pubsub is None: + self._active_pubsub = self._active_database.client.pubsub(**kwargs) + self._active_pubsub_kwargs = kwargs + return None + + return 
self._execute_with_failure_detection(callback) + + def execute_pubsub_method(self, method_name: str, *args, **kwargs): + def callback(): + method = getattr(self.active_pubsub, method_name) + response = method(*args, **kwargs) + self._register_command_execution(args) + return response + + return self._execute_with_failure_detection(callback, *args) + + def execute_pubsub_run(self, sleep_time, **kwargs) -> "PubSubWorkerThread": + def callback(): + return self._active_pubsub.run_in_thread(sleep_time, **kwargs) + + return self._execute_with_failure_detection(callback) + + def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()): + """ + Execute a command execution callback with failure detection. + """ + + def wrapper(): + # On each retry we need to check the active database as it might have changed. + self._check_active_database() + return callback() + + return self._command_retry.call_with_retry( + lambda: wrapper(), + lambda error: self._on_command_fail(error, *cmds), + ) + + def _on_command_fail(self, error, *args): + self._event_dispatcher.dispatch(OnCommandsFailEvent(args, error)) + + def _check_active_database(self): + """ + Checks if the active database needs to be updated. + """ + if ( + self._active_database is None + or self._active_database.circuit.state != CBState.CLOSED + or ( + self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL + and self._next_fallback_attempt <= datetime.now() + ) + ): + self.active_database = self._failover_strategy_executor.execute() + self._schedule_next_fallback() + + def _register_command_execution(self, cmd: tuple): + for detector in self._failure_detectors: + detector.register_command_execution(cmd) + + def _setup_event_dispatcher(self): + """ + Registers the necessary event listeners. + """ + failure_listener = RegisterCommandFailure(self._failure_detectors) + resubscribe_listener = ResubscribeOnActiveDatabaseChanged() + close_connection_listener = CloseConnectionOnActiveDatabaseChanged() + self._event_dispatcher.register_listeners( + { + OnCommandsFailEvent: [failure_listener], + ActiveDatabaseChanged: [ + close_connection_listener, + resubscribe_listener, + ], + } + )
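Worth noting: every public execute_* method above funnels through _execute_with_failure_detection, so retry policy and failure bookkeeping live in one place. A minimal standalone sketch of the same pattern, built only on the public Retry.call_with_retry(do, fail) hook (run_command and report_failure are hypothetical stand-ins for the executor's callback and failure handler):

from redis.backoff import ExponentialWithJitterBackoff
from redis.retry import Retry

retry = Retry(backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3)


def run_command():
    # The real executor re-checks the active database here before each attempt.
    return "OK"


def report_failure(error: Exception) -> None:
    # The real executor dispatches OnCommandsFailEvent so detectors can count it.
    print(f"command failed: {error!r}")


result = retry.call_with_retry(run_command, report_failure)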
diff --git a/redis/multidb/config.py b/redis/multidb/config.py new file mode 100644 index 0000000000..4586263748 --- /dev/null +++ b/redis/multidb/config.py @@ -0,0 +1,207 @@ +from dataclasses import dataclass, field +from typing import List, Type, Union + +import pybreaker +from typing_extensions import Optional + +from redis import ConnectionPool, Redis, RedisCluster +from redis.backoff import ExponentialWithJitterBackoff, NoBackoff +from redis.data_structure import WeightedList +from redis.event import EventDispatcher, EventDispatcherInterface +from redis.multidb.circuit import ( + DEFAULT_GRACE_PERIOD, + CircuitBreaker, + PBCircuitBreakerAdapter, +) +from redis.multidb.database import Database, Databases +from redis.multidb.failover import ( + DEFAULT_FAILOVER_ATTEMPTS, + DEFAULT_FAILOVER_DELAY, + FailoverStrategy, + WeightBasedFailoverStrategy, +) +from redis.multidb.failure_detector import ( + DEFAULT_FAILURE_RATE_THRESHOLD, + DEFAULT_FAILURES_DETECTION_WINDOW, + DEFAULT_MIN_NUM_FAILURES, + CommandFailureDetector, + FailureDetector, +) +from redis.multidb.healthcheck import ( + DEFAULT_HEALTH_CHECK_DELAY, + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_HEALTH_CHECK_POLICY, + DEFAULT_HEALTH_CHECK_PROBES, + EchoHealthCheck, + HealthCheck, + HealthCheckPolicies, +) +from redis.retry import Retry + +DEFAULT_AUTO_FALLBACK_INTERVAL = 120 + + +def default_event_dispatcher() -> EventDispatcherInterface: + return EventDispatcher() + + +@dataclass +class DatabaseConfig: + """ + Dataclass representing the configuration for a database connection. + + This class is used to store configuration settings for a database connection, + including client options, connection sourcing details, circuit breaker settings, + and cluster-specific properties. It provides a structure for defining these + attributes and allows for the creation of customized configurations for various + database setups. + + Attributes: + weight (float): Weight of the database, used to determine the active one. + client_kwargs (dict): Additional parameters for the database client connection. + from_url (Optional[str]): Redis URL used to connect to the database. + from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. + circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. + grace_period (float): Grace period after which we check whether the circuit can be closed again. + health_check_url (Optional[str]): URL for health checks. The cluster FQDN is typically used + on public Redis Enterprise endpoints. + + Methods: + default_circuit_breaker: + Generates and returns a default CircuitBreaker instance adapted for use. + """ + + weight: float = 1.0 + client_kwargs: dict = field(default_factory=dict) + from_url: Optional[str] = None + from_pool: Optional[ConnectionPool] = None + circuit: Optional[CircuitBreaker] = None + grace_period: float = DEFAULT_GRACE_PERIOD + health_check_url: Optional[str] = None + + def default_circuit_breaker(self) -> CircuitBreaker: + circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) + return PBCircuitBreakerAdapter(circuit_breaker)
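As a rough usage sketch (the URL below is a placeholder), a DatabaseConfig needs little more than a connection source and a weight; a pybreaker-backed circuit is generated on demand:

from redis.multidb.config import DatabaseConfig

db_config = DatabaseConfig(
    from_url="redis://primary.example.com:6379",  # placeholder endpoint
    weight=1.0,
    grace_period=60.0,  # seconds before an open circuit may be probed again
)
circuit = db_config.default_circuit_breaker()  # PBCircuitBreakerAdapter around pybreaker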
+ + +@dataclass +class MultiDbConfig: + """ + Configuration class for managing multiple database connections in a resilient and fail-safe manner. + + Attributes: + databases_config: A list of database configurations. + client_class: The client class used to manage database connections. + command_retry: Retry strategy for executing database commands. + failure_detectors: Optional list of additional failure detectors for monitoring database failures. + min_num_failures: Minimum number of failures required to trigger failover + failure_rate_threshold: Fraction of failed commands required to trigger failover + failures_detection_window: Time interval for tracking database failures. + health_checks: Optional list of additional health checks performed on databases. + health_check_interval: Time interval for executing health checks. + health_check_probes: Number of attempts to evaluate the health of a database. + health_check_delay: Delay between health check attempts. + health_check_policy: Policy for determining database health based on health checks. + failover_strategy: Optional strategy for handling database failover scenarios. + failover_attempts: Number of retries allowed for failover operations. + failover_delay: Delay between failover attempts. + auto_fallback_interval: Time interval between automatic fallback attempts. + event_dispatcher: Interface for dispatching events related to database operations. + + Methods: + databases: + Retrieves a collection of database clients managed by weighted configurations. + Initializes database clients based on the provided configuration and removes + redundant retry objects for lower-level clients to rely on global retry logic. + + default_failure_detectors: + Returns the default list of failure detectors used to monitor database failures. + + default_health_checks: + Returns the default list of health checks used to monitor database health + with specific retry and backoff strategies. + + default_failover_strategy: + Provides the default failover strategy used for handling failover scenarios + with defined retry and backoff configurations. + """ + + databases_config: List[DatabaseConfig] + client_class: Type[Union[Redis, RedisCluster]] = Redis + command_retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ) + failure_detectors: Optional[List[FailureDetector]] = None + min_num_failures: int = DEFAULT_MIN_NUM_FAILURES + failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD + failures_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW + health_checks: Optional[List[HealthCheck]] = None + health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL + health_check_probes: int = DEFAULT_HEALTH_CHECK_PROBES + health_check_delay: float = DEFAULT_HEALTH_CHECK_DELAY + health_check_policy: HealthCheckPolicies = DEFAULT_HEALTH_CHECK_POLICY + failover_strategy: Optional[FailoverStrategy] = None + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS + failover_delay: float = DEFAULT_FAILOVER_DELAY + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL + event_dispatcher: EventDispatcherInterface = field( + default_factory=default_event_dispatcher + ) + + def databases(self) -> Databases: + databases = WeightedList() + + for database_config in self.databases_config: + # The retry object is not used in the lower-level clients, so we can safely remove it. + # We rely on command_retry for global retries. + database_config.client_kwargs.update( + {"retry": Retry(retries=0, backoff=NoBackoff())} + ) + + if database_config.from_url: + client = self.client_class.from_url( + database_config.from_url, **database_config.client_kwargs + ) + elif database_config.from_pool: + database_config.from_pool.set_retry( + Retry(retries=0, backoff=NoBackoff()) + ) + client = self.client_class.from_pool( + connection_pool=database_config.from_pool + ) + else: + client = self.client_class(**database_config.client_kwargs) + + circuit = ( + database_config.default_circuit_breaker() + if database_config.circuit is None + else database_config.circuit + ) + databases.add( + Database( + client=client, + circuit=circuit, + weight=database_config.weight, + health_check_url=database_config.health_check_url, + ), + database_config.weight, + ) + + return databases + + def default_failure_detectors(self) -> List[FailureDetector]: + return [ + CommandFailureDetector( + min_num_failures=self.min_num_failures, + failure_rate_threshold=self.failure_rate_threshold, + failure_detection_window=self.failures_detection_window, + ), + ] + + def default_health_checks(self) -> List[HealthCheck]: + return [ + EchoHealthCheck(), + ] + + def default_failover_strategy(self) -> FailoverStrategy: + return WeightBasedFailoverStrategy()
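Tying the pieces together, a hedged end-to-end sketch (the endpoints are placeholders; databases() yields the WeightedList the command executor consumes):

from redis.multidb.config import DatabaseConfig, MultiDbConfig

config = MultiDbConfig(
    databases_config=[
        DatabaseConfig(from_url="redis://primary.example.com:6379", weight=1.0),
        DatabaseConfig(from_url="redis://standby.example.com:6379", weight=0.5),
    ],
)
databases = config.databases()  # per-client retries are disabled; command_retry applies globally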
diff --git a/redis/multidb/database.py b/redis/multidb/database.py new file mode 100644 index 0000000000..d46de99e2d --- /dev/null +++ b/redis/multidb/database.py @@ -0,0 +1,130 @@ +from abc import ABC, abstractmethod +from typing import Optional, Union + +import redis +from redis import RedisCluster +from redis.data_structure import WeightedList +from redis.multidb.circuit import CircuitBreaker +from redis.typing import Number + + +class AbstractDatabase(ABC): + @property + @abstractmethod + def weight(self) -> float: + """The weight of this database compared to others. Used to determine which database to fail over to.""" + pass + + @weight.setter + @abstractmethod + def weight(self, weight: float): + """Set the weight of this database compared to others.""" + pass + + @property + @abstractmethod + def health_check_url(self) -> Optional[str]: + """Health check URL associated with the current database.""" + pass + + @health_check_url.setter + @abstractmethod + def health_check_url(self, health_check_url: Optional[str]): + """Set the health check URL associated with the current database.""" + pass + + +class BaseDatabase(AbstractDatabase): + def __init__( + self, + weight: float, + health_check_url: Optional[str] = None, + ): + self._weight = weight + self._health_check_url = health_check_url + + @property + def weight(self) -> float: + return self._weight + + @weight.setter + def weight(self, weight: float): + self._weight = weight + + @property + def health_check_url(self) -> Optional[str]: + return self._health_check_url + + @health_check_url.setter + def health_check_url(self, health_check_url: Optional[str]): + self._health_check_url = health_check_url + + +class SyncDatabase(AbstractDatabase): + """Database with an underlying synchronous redis client.""" + + @property + @abstractmethod + def client(self) -> Union[redis.Redis, RedisCluster]: + """The underlying redis client.""" + pass + + @client.setter + @abstractmethod + def client(self, client: Union[redis.Redis, RedisCluster]): + """Set the underlying redis client.""" + pass + + @property + @abstractmethod + def circuit(self) -> CircuitBreaker: + """Circuit breaker for the current database.""" + pass + + @circuit.setter + @abstractmethod + def circuit(self, circuit: CircuitBreaker): + """Set the circuit breaker for the current database.""" + pass + + +Databases = WeightedList[tuple[SyncDatabase, Number]] + + +class Database(BaseDatabase, SyncDatabase): + def __init__( + self, + client: Union[redis.Redis, RedisCluster], + circuit: CircuitBreaker, + weight: float, + health_check_url: Optional[str] = None, + ): + """ + Initialize a new Database instance. + + Args: + client: Underlying Redis client instance for database operations + circuit: Circuit breaker for handling database failures + weight: Weight value used for database failover prioritization + health_check_url: Health check URL associated with the current database + """ + self._client = client + self._cb = circuit + self._cb.database = self + super().__init__(weight, health_check_url) + + @property + def client(self) -> Union[redis.Redis, RedisCluster]: + return self._client + + @client.setter + def client(self, client: Union[redis.Redis, RedisCluster]): + self._client = client + + @property + def circuit(self) -> CircuitBreaker: + return self._cb + + @circuit.setter + def circuit(self, circuit: CircuitBreaker): + self._cb = circuit
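For illustration, a Database can also be assembled by hand with an explicit pybreaker-backed circuit (the reset_timeout value here is arbitrary):

import pybreaker

from redis import Redis
from redis.multidb.circuit import PBCircuitBreakerAdapter
from redis.multidb.database import Database

circuit = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=60))
db = Database(client=Redis(), circuit=circuit, weight=1.0)
# Database.__init__ wires the breaker back onto its database (circuit.database = db).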
+ """ + + def __init__( + self, + old_database: SyncDatabase, + new_database: SyncDatabase, + command_executor, + **kwargs, + ): + self._old_database = old_database + self._new_database = new_database + self._command_executor = command_executor + self._kwargs = kwargs + + @property + def old_database(self) -> SyncDatabase: + return self._old_database + + @property + def new_database(self) -> SyncDatabase: + return self._new_database + + @property + def command_executor(self): + return self._command_executor + + @property + def kwargs(self): + return self._kwargs + + +class ResubscribeOnActiveDatabaseChanged(EventListenerInterface): + """ + Re-subscribe the currently active pub / sub to a new active database. + """ + + def listen(self, event: ActiveDatabaseChanged): + old_pubsub = event.command_executor.active_pubsub + + if old_pubsub is not None: + # Re-assign old channels and patterns so they will be automatically subscribed on connection. + new_pubsub = event.new_database.client.pubsub(**event.kwargs) + new_pubsub.channels = old_pubsub.channels + new_pubsub.patterns = old_pubsub.patterns + new_pubsub.shard_channels = old_pubsub.shard_channels + new_pubsub.on_connect(None) + event.command_executor.active_pubsub = new_pubsub + old_pubsub.close() + + +class CloseConnectionOnActiveDatabaseChanged(EventListenerInterface): + """ + Close connection to the old active database. + """ + + def listen(self, event: ActiveDatabaseChanged): + event.old_database.client.close() + + if isinstance(event.old_database.client, Redis): + event.old_database.client.connection_pool.update_active_connections_for_reconnect() + event.old_database.client.connection_pool.disconnect() + else: + for node in event.old_database.client.nodes_manager.nodes_cache.values(): + node.redis_connection.connection_pool.update_active_connections_for_reconnect() + node.redis_connection.connection_pool.disconnect() + + +class RegisterCommandFailure(EventListenerInterface): + """ + Event listener that registers command failures and passing it to the failure detectors. 
+ """ + + def __init__(self, failure_detectors: List[FailureDetector]): + self._failure_detectors = failure_detectors + + def listen(self, event: OnCommandsFailEvent) -> None: + for failure_detector in self._failure_detectors: + failure_detector.register_failure(event.exception, event.commands) diff --git a/redis/multidb/exception.py b/redis/multidb/exception.py new file mode 100644 index 0000000000..8c08c4b540 --- /dev/null +++ b/redis/multidb/exception.py @@ -0,0 +1,17 @@ +class NoValidDatabaseException(Exception): + pass + + +class UnhealthyDatabaseException(Exception): + """Exception raised when a database is unhealthy due to an underlying exception.""" + + def __init__(self, message, database, original_exception): + super().__init__(message) + self.database = database + self.original_exception = original_exception + + +class TemporaryUnavailableException(Exception): + """Exception raised when all databases in setup are temporary unavailable.""" + + pass diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py new file mode 100644 index 0000000000..c660eddbd3 --- /dev/null +++ b/redis/multidb/failover.py @@ -0,0 +1,125 @@ +import time +from abc import ABC, abstractmethod + +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState +from redis.multidb.database import Databases, SyncDatabase +from redis.multidb.exception import ( + NoValidDatabaseException, + TemporaryUnavailableException, +) + +DEFAULT_FAILOVER_ATTEMPTS = 10 +DEFAULT_FAILOVER_DELAY = 12 + + +class FailoverStrategy(ABC): + @abstractmethod + def database(self) -> SyncDatabase: + """Select the database according to the strategy.""" + pass + + @abstractmethod + def set_databases(self, databases: Databases) -> None: + """Set the database strategy operates on.""" + pass + + +class FailoverStrategyExecutor(ABC): + @property + @abstractmethod + def failover_attempts(self) -> int: + """The number of failover attempts.""" + pass + + @property + @abstractmethod + def failover_delay(self) -> float: + """The delay between failover attempts.""" + pass + + @property + @abstractmethod + def strategy(self) -> FailoverStrategy: + """The strategy to execute.""" + pass + + @abstractmethod + def execute(self) -> SyncDatabase: + """Execute the failover strategy.""" + pass + + +class WeightBasedFailoverStrategy(FailoverStrategy): + """ + Failover strategy based on database weights. + """ + + def __init__(self) -> None: + self._databases = WeightedList() + + def database(self) -> SyncDatabase: + for database, _ in self._databases: + if database.circuit.state == CBState.CLOSED: + return database + + raise NoValidDatabaseException("No valid database available for communication") + + def set_databases(self, databases: Databases) -> None: + self._databases = databases + + +class DefaultFailoverStrategyExecutor(FailoverStrategyExecutor): + """ + Executes given failover strategy. 
+ """ + + def __init__( + self, + strategy: FailoverStrategy, + failover_attempts: int = DEFAULT_FAILOVER_ATTEMPTS, + failover_delay: float = DEFAULT_FAILOVER_DELAY, + ): + self._strategy = strategy + self._failover_attempts = failover_attempts + self._failover_delay = failover_delay + self._next_attempt_ts: int = 0 + self._failover_counter: int = 0 + + @property + def failover_attempts(self) -> int: + return self._failover_attempts + + @property + def failover_delay(self) -> float: + return self._failover_delay + + @property + def strategy(self) -> FailoverStrategy: + return self._strategy + + def execute(self) -> SyncDatabase: + try: + database = self._strategy.database() + self._reset() + return database + except NoValidDatabaseException as e: + if self._next_attempt_ts == 0: + self._next_attempt_ts = time.time() + self._failover_delay + self._failover_counter += 1 + elif time.time() >= self._next_attempt_ts: + self._next_attempt_ts += self._failover_delay + self._failover_counter += 1 + + if self._failover_counter > self._failover_attempts: + self._reset() + raise e + else: + raise TemporaryUnavailableException( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ) + + def _reset(self) -> None: + self._next_attempt_ts = 0 + self._failover_counter = 0 diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py new file mode 100644 index 0000000000..f1be28788e --- /dev/null +++ b/redis/multidb/failure_detector.py @@ -0,0 +1,104 @@ +import math +import threading +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import List, Type + +from typing_extensions import Optional + +from redis.multidb.circuit import State as CBState + +DEFAULT_MIN_NUM_FAILURES = 1000 +DEFAULT_FAILURE_RATE_THRESHOLD = 0.1 +DEFAULT_FAILURES_DETECTION_WINDOW = 2 + + +class FailureDetector(ABC): + @abstractmethod + def register_failure(self, exception: Exception, cmd: tuple) -> None: + """Register a failure that occurred during command execution.""" + pass + + @abstractmethod + def register_command_execution(self, cmd: tuple) -> None: + """Register a command execution.""" + pass + + @abstractmethod + def set_command_executor(self, command_executor) -> None: + """Set the command executor for this failure.""" + pass + + +class CommandFailureDetector(FailureDetector): + """ + Detects a failure based on a threshold of failed commands during a specific period of time. + """ + + def __init__( + self, + min_num_failures: int = DEFAULT_MIN_NUM_FAILURES, + failure_rate_threshold: float = DEFAULT_FAILURE_RATE_THRESHOLD, + failure_detection_window: float = DEFAULT_FAILURES_DETECTION_WINDOW, + error_types: Optional[List[Type[Exception]]] = None, + ) -> None: + """ + Initialize a new CommandFailureDetector instance. + + Args: + min_num_failures: Minimal count of failures required for failover + failure_rate_threshold: Percentage of failures required for failover + failure_detection_window: Time interval for executing health checks. + error_types: Optional list of exception types to trigger failover. If None, all exceptions are counted. + + The detector tracks command failures within a sliding time window. When the number of failures + exceeds the threshold within the specified duration, it triggers failure detection. 
+ """ + self._command_executor = None + self._min_num_failures = min_num_failures + self._failure_rate_threshold = failure_rate_threshold + self._failure_detection_window = failure_detection_window + self._error_types = error_types + self._commands_executed: int = 0 + self._start_time: datetime = datetime.now() + self._end_time: datetime = self._start_time + timedelta( + seconds=self._failure_detection_window + ) + self._failures_count: int = 0 + self._lock = threading.RLock() + + def register_failure(self, exception: Exception, cmd: tuple) -> None: + with self._lock: + if self._error_types: + if type(exception) in self._error_types: + self._failures_count += 1 + else: + self._failures_count += 1 + + self._check_threshold() + + def set_command_executor(self, command_executor) -> None: + self._command_executor = command_executor + + def register_command_execution(self, cmd: tuple) -> None: + with self._lock: + if not self._start_time < datetime.now() < self._end_time: + self._reset() + + self._commands_executed += 1 + + def _check_threshold(self): + if self._failures_count >= self._min_num_failures and self._failures_count >= ( + math.ceil(self._commands_executed * self._failure_rate_threshold) + ): + self._command_executor.active_database.circuit.state = CBState.OPEN + self._reset() + + def _reset(self) -> None: + with self._lock: + self._start_time = datetime.now() + self._end_time = self._start_time + timedelta( + seconds=self._failure_detection_window + ) + self._failures_count = 0 + self._commands_executed = 0 diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py new file mode 100644 index 0000000000..5deda82f24 --- /dev/null +++ b/redis/multidb/healthcheck.py @@ -0,0 +1,289 @@ +import logging +from abc import ABC, abstractmethod +from enum import Enum +from time import sleep +from typing import List, Optional, Tuple, Union + +from redis import Redis +from redis.backoff import NoBackoff +from redis.http.http_client import DEFAULT_TIMEOUT, HttpClient +from redis.multidb.exception import UnhealthyDatabaseException +from redis.retry import Retry + +DEFAULT_HEALTH_CHECK_PROBES = 3 +DEFAULT_HEALTH_CHECK_INTERVAL = 5 +DEFAULT_HEALTH_CHECK_DELAY = 0.5 +DEFAULT_LAG_AWARE_TOLERANCE = 5000 + +logger = logging.getLogger(__name__) + + +class HealthCheck(ABC): + @abstractmethod + def check_health(self, database) -> bool: + """Function to determine the health status.""" + pass + + +class HealthCheckPolicy(ABC): + """ + Health checks execution policy. 
+ """ + + @property + @abstractmethod + def health_check_probes(self) -> int: + """Number of probes to execute health checks.""" + pass + + @property + @abstractmethod + def health_check_delay(self) -> float: + """Delay between health check probes.""" + pass + + @abstractmethod + def execute(self, health_checks: List[HealthCheck], database) -> bool: + """Execute health checks and return database health status.""" + pass + + +class AbstractHealthCheckPolicy(HealthCheckPolicy): + def __init__(self, health_check_probes: int, health_check_delay: float): + if health_check_probes < 1: + raise ValueError("health_check_probes must be greater than 0") + self._health_check_probes = health_check_probes + self._health_check_delay = health_check_delay + + @property + def health_check_probes(self) -> int: + return self._health_check_probes + + @property + def health_check_delay(self) -> float: + return self._health_check_delay + + @abstractmethod + def execute(self, health_checks: List[HealthCheck], database) -> bool: + pass + + +class HealthyAllPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if all health check probes are successful. + """ + + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) + + def execute(self, health_checks: List[HealthCheck], database) -> bool: + for health_check in health_checks: + for attempt in range(self.health_check_probes): + try: + if not health_check.check_health(database): + return False + except Exception as e: + raise UnhealthyDatabaseException("Unhealthy database", database, e) + + if attempt < self.health_check_probes - 1: + sleep(self._health_check_delay) + return True + + +class HealthyMajorityPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if a majority of health check probes are successful. + """ + + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) + + def execute(self, health_checks: List[HealthCheck], database) -> bool: + for health_check in health_checks: + if self.health_check_probes % 2 == 0: + allowed_unsuccessful_probes = self.health_check_probes / 2 + else: + allowed_unsuccessful_probes = (self.health_check_probes + 1) / 2 + + for attempt in range(self.health_check_probes): + try: + if not health_check.check_health(database): + allowed_unsuccessful_probes -= 1 + if allowed_unsuccessful_probes <= 0: + return False + except Exception as e: + allowed_unsuccessful_probes -= 1 + if allowed_unsuccessful_probes <= 0: + raise UnhealthyDatabaseException( + "Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + sleep(self._health_check_delay) + return True + + +class HealthyAnyPolicy(AbstractHealthCheckPolicy): + """ + Policy that returns True if at least one health check probe is successful. 
+ """ + + def __init__(self, health_check_probes: int, health_check_delay: float): + super().__init__(health_check_probes, health_check_delay) + + def execute(self, health_checks: List[HealthCheck], database) -> bool: + is_healthy = False + + for health_check in health_checks: + exception = None + + for attempt in range(self.health_check_probes): + try: + if health_check.check_health(database): + is_healthy = True + break + else: + is_healthy = False + except Exception as e: + exception = UnhealthyDatabaseException( + "Unhealthy database", database, e + ) + + if attempt < self.health_check_probes - 1: + sleep(self._health_check_delay) + + if not is_healthy and not exception: + return is_healthy + elif not is_healthy and exception: + raise exception + + return is_healthy + + +class HealthCheckPolicies(Enum): + HEALTHY_ALL = HealthyAllPolicy + HEALTHY_MAJORITY = HealthyMajorityPolicy + HEALTHY_ANY = HealthyAnyPolicy + + +DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL + + +class EchoHealthCheck(HealthCheck): + """ + Health check based on ECHO command. + """ + + def check_health(self, database) -> bool: + expected_message = ["healthcheck", b"healthcheck"] + + if isinstance(database.client, Redis): + actual_message = database.client.execute_command("ECHO", "healthcheck") + return actual_message in expected_message + else: + # For a cluster checks if all nodes are healthy. + all_nodes = database.client.get_nodes() + for node in all_nodes: + actual_message = node.redis_connection.execute_command( + "ECHO", "healthcheck" + ) + + if actual_message not in expected_message: + return False + + return True + + +class LagAwareHealthCheck(HealthCheck): + """ + Health check available for Redis Enterprise deployments. + Verify via REST API that the database is healthy based on different lags. + """ + + def __init__( + self, + rest_api_port: int = 9443, + lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE, + timeout: float = DEFAULT_TIMEOUT, + auth_basic: Optional[Tuple[str, str]] = None, + verify_tls: bool = True, + # TLS verification (server) options + ca_file: Optional[str] = None, + ca_path: Optional[str] = None, + ca_data: Optional[Union[str, bytes]] = None, + # Mutual TLS (client cert) options + client_cert_file: Optional[str] = None, + client_key_file: Optional[str] = None, + client_key_password: Optional[str] = None, + ): + """ + Initialize LagAwareHealthCheck with the specified parameters. 
+ + Args: + rest_api_port: Port number for Redis Enterprise REST API (default: 9443) + lag_aware_tolerance: Tolerable lag between databases in ms (default: DEFAULT_LAG_AWARE_TOLERANCE) + timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT) + auth_basic: Tuple of (username, password) for basic authentication + verify_tls: Whether to verify TLS certificates (default: True) + ca_file: Path to CA certificate file for TLS verification + ca_path: Path to CA certificates directory for TLS verification + ca_data: CA certificate data as string or bytes + client_cert_file: Path to client certificate file for mutual TLS + client_key_file: Path to client private key file for mutual TLS + client_key_password: Password for encrypted client private key + """ + self._http_client = HttpClient( + timeout=timeout, + auth_basic=auth_basic, + retry=Retry(NoBackoff(), retries=0), + verify_tls=verify_tls, + ca_file=ca_file, + ca_path=ca_path, + ca_data=ca_data, + client_cert_file=client_cert_file, + client_key_file=client_key_file, + client_key_password=client_key_password, + ) + self._rest_api_port = rest_api_port + self._lag_aware_tolerance = lag_aware_tolerance + + def check_health(self, database) -> bool: + if database.health_check_url is None: + raise ValueError( + "Database health check url is not set. Please check DatabaseConfig for the current database." + ) + + if isinstance(database.client, Redis): + db_host = database.client.get_connection_kwargs()["host"] + else: + db_host = database.client.startup_nodes[0].host + + base_url = f"{database.health_check_url}:{self._rest_api_port}" + self._http_client.base_url = base_url + + # Find the bdb matching the current database host + matching_bdb = None + for bdb in self._http_client.get("/v1/bdbs"): + for endpoint in bdb["endpoints"]: + if endpoint["dns_name"] == db_host: + matching_bdb = bdb + break + + # In case the host was set as a public IP + for addr in endpoint["addr"]: + if addr == db_host: + matching_bdb = bdb + break + + if matching_bdb is None: + logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb") + raise ValueError("Could not find a matching bdb") + + url = ( + f"/v1/bdbs/{matching_bdb['uid']}/availability" + f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}" + ) + self._http_client.get(url, expect_json=False) + + # The status is checked in the HTTP client; otherwise an HttpError is raised. + return True diff --git a/redis/retry.py b/redis/retry.py index 75778635e8..225e431eb2 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -1,7 +1,17 @@ import abc import socket from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Tuple, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Iterable, + Optional, + Tuple, + Type, + TypeVar, +) from redis.exceptions import ConnectionError, TimeoutError @@ -91,6 +101,7 @@ def call_with_retry( self, do: Callable[[], T], fail: Callable[[Exception], Any], + is_retryable: Optional[Callable[[Exception], bool]] = None, ) -> T: """ Execute an operation that might fail and returns its result, or @@ -104,6 +115,8 @@ try: return do() except self._supported_errors as error: + if is_retryable and not is_retryable(error): + raise failures += 1 fail(error) if self._retries >= 0 and failures > self._retries: diff --git a/redis/utils.py b/redis/utils.py index 79c23c8bda..5ae8fb25fc 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -312,3 +312,17 @@ def truncate_text(txt, max_length=100): return
textwrap.shorten( text=txt, width=max_length, placeholder="...", break_long_words=True ) + + +def dummy_fail(): + """ + Fake function for a Retry object if you don't need to handle each failure. + """ + pass + + +async def dummy_fail_async(): + """ + Async fake function for a Retry object if you don't need to handle each failure. + """ + pass diff --git a/tasks.py b/tasks.py index 20f9f245aa..d63bd8c92d 100644 --- a/tasks.py +++ b/tasks.py @@ -58,11 +58,11 @@ def standalone_tests( if uvloop: run( - f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --ignore=tests/test_scenario --cov=./ --cov-report=xml:coverage_resp{protocol}_uvloop.xml -m 'not onlycluster{extra_markers}' --uvloop --junit-xml=standalone-resp{protocol}-uvloop-results.xml" + f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --ignore=tests/test_scenario --ignore=tests/test_asyncio/test_scenario --cov=./ --cov-report=xml:coverage_resp{protocol}_uvloop.xml -m 'not onlycluster{extra_markers}' --uvloop --junit-xml=standalone-resp{protocol}-uvloop-results.xml" ) else: run( - f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --ignore=tests/test_scenario --cov=./ --cov-report=xml:coverage_resp{protocol}.xml -m 'not onlycluster{extra_markers}' --junit-xml=standalone-resp{protocol}-results.xml" + f"pytest {profile_arg} --protocol={protocol} {redis_mod_url} --ignore=tests/test_scenario --ignore=tests/test_asyncio/test_scenario --cov=./ --cov-report=xml:coverage_resp{protocol}.xml -m 'not onlycluster{extra_markers}' --junit-xml=standalone-resp{protocol}-results.xml" ) @@ -74,11 +74,11 @@ def cluster_tests(c, uvloop=False, protocol=2, profile=False): cluster_tls_url = "rediss://localhost:27379/0" if uvloop: run( - f"pytest {profile_arg} --protocol={protocol} --ignore=tests/test_scenario --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}_uvloop.xml -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-uvloop-results.xml --uvloop" + f"pytest {profile_arg} --protocol={protocol} --ignore=tests/test_scenario --ignore=tests/test_asyncio/test_scenario --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}_uvloop.xml -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-uvloop-results.xml --uvloop" ) else: run( - f"pytest {profile_arg} --protocol={protocol} --ignore=tests/test_scenario --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}.xml -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-results.xml" + f"pytest {profile_arg} --protocol={protocol} --ignore=tests/test_scenario --ignore=tests/test_asyncio/test_scenario --cov=./ --cov-report=xml:coverage_cluster_resp{protocol}.xml -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --redis-ssl-url={cluster_tls_url} --junit-xml=cluster-resp{protocol}-results.xml" ) diff --git a/tests/conftest.py b/tests/conftest.py index 7eaccb1acb..af2681732b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,7 @@ ) from redis.connection import Connection, ConnectionInterface, SSLConnection, parse_url from redis.credentials import CredentialProvider +from redis.event import EventDispatcherInterface from redis.exceptions import RedisClusterException from redis.retry import Retry from tests.ssl_utils import get_tls_certificates @@ -582,6 +583,12 @@ def mock_connection() -> 
ConnectionInterface: return mock_connection +@pytest.fixture() +def mock_ed() -> EventDispatcherInterface: + mock_ed = Mock(spec=EventDispatcherInterface) + return mock_ed + + @pytest.fixture() def cache_key(request) -> CacheKey: command = request.param.get("command") diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index c30220fb1d..cb3dac9604 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -114,6 +114,9 @@ def set_re_auth_token(self, token: TokenInterface): async def re_auth(self): pass + def should_reconnect(self): + return False + class TestConnectionPool: @asynccontextmanager diff --git a/tests/test_asyncio/test_multidb/__init__.py b/tests/test_asyncio/test_multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_asyncio/test_multidb/conftest.py b/tests/test_asyncio/test_multidb/conftest.py new file mode 100644 index 0000000000..0666dc527a --- /dev/null +++ b/tests/test_asyncio/test_multidb/conftest.py @@ -0,0 +1,131 @@ +from unittest.mock import Mock + +import pytest + +from redis.asyncio.multidb.config import ( + MultiDbConfig, + DatabaseConfig, + DEFAULT_AUTO_FALLBACK_INTERVAL, +) +from redis.asyncio.multidb.failover import AsyncFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import ( + HealthCheck, + DEFAULT_HEALTH_CHECK_PROBES, + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_HEALTH_CHECK_POLICY, +) +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState, CircuitBreaker +from redis.asyncio import Redis, ConnectionPool +from redis.asyncio.multidb.database import Database, Databases + + +@pytest.fixture() +def mock_client() -> Redis: + return Mock(spec=Redis) + + +@pytest.fixture() +def mock_cb() -> CircuitBreaker: + return Mock(spec=CircuitBreaker) + + +@pytest.fixture() +def mock_fd() -> AsyncFailureDetector: + return Mock(spec=AsyncFailureDetector) + + +@pytest.fixture() +def mock_fs() -> AsyncFailoverStrategy: + return Mock(spec=AsyncFailoverStrategy) + + +@pytest.fixture() +def mock_hc() -> HealthCheck: + return Mock(spec=HealthCheck) + + +@pytest.fixture() +def mock_db(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + + +@pytest.fixture() +def mock_db1(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + + +@pytest.fixture() +def mock_db2(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = 
mock_cb + return db + + +@pytest.fixture() +def mock_multi_db_config(request, mock_fd, mock_fs, mock_hc, mock_ed) -> MultiDbConfig: + hc_interval = request.param.get("hc_interval", DEFAULT_HEALTH_CHECK_INTERVAL) + auto_fallback_interval = request.param.get( + "auto_fallback_interval", DEFAULT_AUTO_FALLBACK_INTERVAL + ) + health_check_policy = request.param.get( + "health_check_policy", DEFAULT_HEALTH_CHECK_POLICY + ) + health_check_probes = request.param.get( + "health_check_probes", DEFAULT_HEALTH_CHECK_PROBES + ) + + config = MultiDbConfig( + databases_config=[Mock(spec=DatabaseConfig)], + failure_detectors=[mock_fd], + health_check_interval=hc_interval, + health_check_delay=0.05, + health_check_policy=health_check_policy, + health_check_probes=health_check_probes, + failover_strategy=mock_fs, + auto_fallback_interval=auto_fallback_interval, + event_dispatcher=mock_ed, + ) + + return config + + +def create_weighted_list(*databases) -> Databases: + dbs = WeightedList() + + for db in databases: + dbs.add(db, db.weight) + + return dbs diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py new file mode 100644 index 0000000000..230a01d64d --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -0,0 +1,575 @@ +import asyncio +from unittest.mock import patch, AsyncMock, Mock + +import pybreaker +import pytest + +from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.database import AsyncDatabase +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, HealthCheck +from redis.event import EventDispatcher, AsyncOnCommandsFailEvent +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter +from redis.multidb.exception import NoValidDatabaseException +from tests.test_asyncio.test_multidb.conftest import create_weighted_list + + +@pytest.mark.onlynoncluster +class TestMultiDbClient: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_on_successful_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command = AsyncMock(return_value="OK1") + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set("key", "value") == "OK1" + assert mock_hc.check_health.call_count == 9 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.CLOSED + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": 
CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command = AsyncMock(return_value="OK1") + + mock_hc.check_health.side_effect = [ + False, + True, + True, + True, + True, + True, + True, + ] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set("key", "value") == "OK1" + assert mock_hc.check_health.call_count == 7 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[EchoHealthCheck()], + ), + ): + mock_db.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "healthcheck", + "OK", + "error", + ] + mock_db1.client.execute_command.side_effect = [ + "healthcheck", + "OK1", + "error", + "error", + "healthcheck", + "OK1", + ] + mock_db2.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "OK2", + "error", + "error", + ] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() + + client = MultiDBClient(mock_multi_db_config) + assert await client.set("key", "value") == "OK1" + await asyncio.sleep(0.15) + assert await client.set("key", "value") == "OK2" + await asyncio.sleep(0.1) + assert await client.set("key", "value") == "OK" + await asyncio.sleep(0.1) + assert await client.set("key", "value") == "OK1" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_auto_fallback_to_highest_weight_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + 
mock_hc.check_health.side_effect = [ + True, + True, + True, + False, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + ] + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[mock_hc], + ), + ): + mock_db.client.execute_command.return_value = "OK" + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.auto_fallback_interval = 0.5 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() + + client = MultiDBClient(mock_multi_db_config) + assert await client.set("key", "value") == "OK1" + await asyncio.sleep(0.15) + assert await client.set("key", "value") == "OK2" + await asyncio.sleep(0.5) + assert await client.set("key", "value") == "OK1" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.5, "circuit": {"state": CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_command_throws_exception_on_failed_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises( + NoValidDatabaseException, + match="Initial connection failed - no active database found", + ): + await client.set("key", "value") + assert mock_hc.check_health.call_count == 9 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_database_throws_exception_on_same_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(ValueError, match="Given database already exists"): + await client.add_database(mock_db) + assert mock_hc.check_health.call_count == 9 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_database_makes_new_database_active( + self, mock_multi_db_config, mock_db, mock_db1, 
mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set("key", "value") == "OK2" + assert mock_hc.check_health.call_count == 6 + + await client.add_database(mock_db1) + assert mock_hc.check_health.call_count == 9 + + assert await client.set("key", "value") == "OK1" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_remove_highest_weighted_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set("key", "value") == "OK1" + assert mock_hc.check_health.call_count == 9 + + await client.remove_database(mock_db1) + assert await client.set("key", "value") == "OK2" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_update_database_weight_to_be_highest( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set("key", "value") == "OK1" + assert mock_hc.check_health.call_count == 9 + + await client.update_database_weight(mock_db2, 0.8) + assert mock_db2.weight == 0.8 + + assert await client.set("key", "value") == "OK2" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_new_failure_detector( + self, 
mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_multi_db_config.event_dispatcher = EventDispatcher() + mock_fd = mock_multi_db_config.failure_detectors[0] + + # Event fired if command against mock_db1 would fail + command_fail_event = AsyncOnCommandsFailEvent( + commands=("SET", "key", "value"), + exception=Exception(), + ) + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set("key", "value") == "OK1" + assert mock_hc.check_health.call_count == 9 + + # Simulate failing command events that lead to a failure detection + for i in range(5): + await mock_multi_db_config.event_dispatcher.dispatch_async( + command_fail_event + ) + + assert mock_fd.register_failure.call_count == 5 + + another_fd = Mock(spec=AsyncFailureDetector) + client.add_failure_detector(another_fd) + + # Simulate failing command events that lead to a failure detection + for i in range(5): + await mock_multi_db_config.event_dispatcher.dispatch_async( + command_fail_event + ) + + assert mock_fd.register_failure.call_count == 10 + assert another_fd.register_failure.call_count == 5 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_new_health_check( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set("key", "value") == "OK1" + assert mock_hc.check_health.call_count == 9 + + another_hc = Mock(spec=HealthCheck) + another_hc.check_health.return_value = True + + await client.add_health_check(another_hc) + await client._check_db_health(mock_db1) + + assert mock_hc.check_health.call_count == 12 + assert another_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_set_active_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + 
mock_db.client.execute_command.return_value = "OK" + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set("key", "value") == "OK1" + assert mock_hc.check_health.call_count == 9 + + await client.set_active_database(mock_db) + assert await client.set("key", "value") == "OK" + + with pytest.raises( + ValueError, match="Given database is not a member of database list" + ): + await client.set_active_database(Mock(spec=AsyncDatabase)) + + mock_hc.check_health.return_value = False + + with pytest.raises( + NoValidDatabaseException, + match="Cannot set active database, database is unhealthy", + ): + await client.set_active_database(mock_db1) diff --git a/tests/test_asyncio/test_multidb/test_command_executor.py b/tests/test_asyncio/test_multidb/test_command_executor.py new file mode 100644 index 0000000000..e0ac80a56a --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_command_executor.py @@ -0,0 +1,181 @@ +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper +from redis.event import EventDispatcher +from redis.exceptions import ConnectionError +from redis.asyncio.multidb.command_executor import DefaultCommandExecutor +from redis.asyncio.retry import Retry +from redis.backoff import NoBackoff +from redis.multidb.circuit import State as CBState +from redis.multidb.failure_detector import CommandFailureDetector +from tests.test_asyncio.test_multidb.conftest import create_weighted_list + + +@pytest.mark.onlynoncluster +class TestDefaultCommandExecutor: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_db,mock_db1,mock_db2", + [ + ( + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_on_active_database( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command = AsyncMock(return_value="OK1") + mock_db2.client.execute_command = AsyncMock(return_value="OK2") + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0), + ) + + await executor.set_active_database(mock_db1) + assert await executor.execute_command("SET", "key", "value") == "OK1" + + await executor.set_active_database(mock_db2) + assert await executor.execute_command("SET", "key", "value") == "OK2" + assert mock_ed.register_listeners.call_count == 1 + assert mock_fd.register_command_execution.call_count == 2 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_db,mock_db1,mock_db2", + [ + ( + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_automatically_select_active_database( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command = AsyncMock(return_value="OK1") + mock_db2.client.execute_command = AsyncMock(return_value="OK2") + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2]) + type(mock_fs).database = mock_selector + databases = 
create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0), + ) + + assert await executor.execute_command("SET", "key", "value") == "OK1" + mock_db1.circuit.state = CBState.OPEN + + assert await executor.execute_command("SET", "key", "value") == "OK2" + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 2 + assert mock_fd.register_command_execution.call_count == 2 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_db,mock_db1,mock_db2", + [ + ( + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_fallback_to_another_db_after_fallback_interval( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command = AsyncMock(return_value="OK1") + mock_db2.client.execute_command = AsyncMock(return_value="OK2") + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), 0), + ) + + assert await executor.execute_command("SET", "key", "value") == "OK1" + mock_db1.weight = 0.1 + await asyncio.sleep(0.15) + + assert await executor.execute_command("SET", "key", "value") == "OK2" + mock_db1.weight = 0.7 + await asyncio.sleep(0.15) + + assert await executor.execute_command("SET", "key", "value") == "OK1" + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 3 + assert mock_fd.register_command_execution.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_db,mock_db1,mock_db2", + [ + ( + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_fallback_to_another_db_after_failure_detection( + self, mock_db, mock_db1, mock_db2, mock_fs + ): + mock_db1.client.execute_command = AsyncMock( + side_effect=[ + "OK1", + ConnectionError, + ConnectionError, + ConnectionError, + "OK1", + ] + ) + mock_db2.client.execute_command = AsyncMock( + side_effect=["OK2", ConnectionError, ConnectionError, ConnectionError] + ) + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + threshold = 3 + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(threshold, 1)) + ed = EventDispatcher() + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=ed, + auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), threshold), + ) + fd.set_command_executor(command_executor=executor) + + assert await executor.execute_command("SET", "key", "value") == "OK1" + assert await executor.execute_command("SET", "key", "value") == "OK2" + assert await executor.execute_command("SET", "key", "value") == "OK1" + assert mock_selector.call_count == 3 
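+ # Once mock_db1 raises enough consecutive ConnectionErrors, the failure detector trips its circuit and the executor fails over to the next database by weight.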
diff --git a/tests/test_asyncio/test_multidb/test_config.py b/tests/test_asyncio/test_multidb/test_config.py new file mode 100644 index 0000000000..d05c7a8a12 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_config.py @@ -0,0 +1,166 @@ +from unittest.mock import Mock + +import pytest + +from redis.asyncio import ConnectionPool +from redis.asyncio.multidb.config import ( + DatabaseConfig, + MultiDbConfig, + DEFAULT_GRACE_PERIOD, + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_AUTO_FALLBACK_INTERVAL, +) +from redis.asyncio.multidb.database import Database +from redis.asyncio.multidb.failover import ( + WeightBasedFailoverStrategy, + AsyncFailoverStrategy, +) +from redis.asyncio.multidb.failure_detector import ( + FailureDetectorAsyncWrapper, + AsyncFailureDetector, +) +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, HealthCheck +from redis.asyncio.retry import Retry +from redis.multidb.circuit import CircuitBreaker + + +@pytest.mark.onlynoncluster +class TestMultiDbConfig: + def test_default_config(self): + db_configs = [ + DatabaseConfig( + client_kwargs={"host": "host1", "port": "port1"}, weight=1.0 + ), + DatabaseConfig( + client_kwargs={"host": "host2", "port": "port2"}, weight=0.9 + ), + DatabaseConfig( + client_kwargs={"host": "host3", "port": "port3"}, weight=0.8 + ), + ] + + config = MultiDbConfig(databases_config=db_configs) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.circuit.grace_period == DEFAULT_GRACE_PERIOD + assert db.client.get_retry() is not config.command_retry + i += 1 + + assert len(config.default_failure_detectors()) == 1 + assert isinstance( + config.default_failure_detectors()[0], FailureDetectorAsyncWrapper + ) + assert len(config.default_health_checks()) == 1 + assert isinstance(config.default_health_checks()[0], EchoHealthCheck) + assert config.health_check_interval == DEFAULT_HEALTH_CHECK_INTERVAL + assert isinstance( + config.default_failover_strategy(), WeightBasedFailoverStrategy + ) + assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL + assert isinstance(config.command_retry, Retry) + + def test_overridden_config(self): + grace_period = 2 + mock_connection_pools = [ + Mock(spec=ConnectionPool), + Mock(spec=ConnectionPool), + Mock(spec=ConnectionPool), + ] + mock_connection_pools[0].connection_kwargs = {} + mock_connection_pools[1].connection_kwargs = {} + mock_connection_pools[2].connection_kwargs = {} + mock_cb1 = Mock(spec=CircuitBreaker) + mock_cb1.grace_period = grace_period + mock_cb2 = Mock(spec=CircuitBreaker) + mock_cb2.grace_period = grace_period + mock_cb3 = Mock(spec=CircuitBreaker) + mock_cb3.grace_period = grace_period + mock_failure_detectors = [ + Mock(spec=AsyncFailureDetector), + Mock(spec=AsyncFailureDetector), + ] + mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] + health_check_interval = 10 + mock_failover_strategy = Mock(spec=AsyncFailoverStrategy) + auto_fallback_interval = 10 + db_configs = [ + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[0]}, + weight=1.0, + circuit=mock_cb1, + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[1]}, + weight=0.9, + circuit=mock_cb2, + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[2]}, + weight=0.8, + circuit=mock_cb3, + ), + ] + + config = MultiDbConfig( + 
databases_config=db_configs, + failure_detectors=mock_failure_detectors, + health_checks=mock_health_checks, + health_check_interval=health_check_interval, + failover_strategy=mock_failover_strategy, + auto_fallback_interval=auto_fallback_interval, + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.client.connection_pool == mock_connection_pools[i] + assert db.circuit.grace_period == grace_period + i += 1 + + assert len(config.failure_detectors) == 2 + assert config.failure_detectors[0] == mock_failure_detectors[0] + assert config.failure_detectors[1] == mock_failure_detectors[1] + assert len(config.health_checks) == 2 + assert config.health_checks[0] == mock_health_checks[0] + assert config.health_checks[1] == mock_health_checks[1] + assert config.health_check_interval == health_check_interval + assert config.failover_strategy == mock_failover_strategy + assert config.auto_fallback_interval == auto_fallback_interval + + +@pytest.mark.onlynoncluster +class TestDatabaseConfig: + def test_default_config(self): + config = DatabaseConfig( + client_kwargs={"host": "host1", "port": "port1"}, weight=1.0 + ) + + assert config.client_kwargs == {"host": "host1", "port": "port1"} + assert config.weight == 1.0 + assert isinstance(config.default_circuit_breaker(), CircuitBreaker) + + def test_overridden_config(self): + mock_connection_pool = Mock(spec=ConnectionPool) + mock_circuit = Mock(spec=CircuitBreaker) + + config = DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pool}, + weight=1.0, + circuit=mock_circuit, + ) + + assert config.client_kwargs == {"connection_pool": mock_connection_pool} + assert config.weight == 1.0 + assert config.circuit == mock_circuit diff --git a/tests/test_asyncio/test_multidb/test_failover.py b/tests/test_asyncio/test_multidb/test_failover.py new file mode 100644 index 0000000000..a34bb368c8 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_failover.py @@ -0,0 +1,169 @@ +import asyncio + +import pytest + +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState +from redis.multidb.exception import ( + NoValidDatabaseException, + TemporaryUnavailableException, +) +from redis.asyncio.multidb.failover import ( + WeightBasedFailoverStrategy, + DefaultFailoverStrategyExecutor, +) + + +@pytest.mark.onlynoncluster +class TestAsyncWeightBasedFailoverStrategy: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_db,mock_db1,mock_db2", + [ + ( + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ( + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, + ), + ], + ids=["all closed - highest weight", "highest weight - open"], + indirect=True, + ) + async def test_get_valid_database(self, mock_db, mock_db1, mock_db2): + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + + strategy = WeightBasedFailoverStrategy() + strategy.set_databases(databases) + + assert await strategy.database() == mock_db1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_db,mock_db1,mock_db2", + [ + ( + 
{"weight": 0.2, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.5, "circuit": {"state": CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_throws_exception_on_empty_databases( + self, mock_db, mock_db1, mock_db2 + ): + failover_strategy = WeightBasedFailoverStrategy() + + with pytest.raises( + NoValidDatabaseException, + match="No valid database available for communication", + ): + assert await failover_strategy.database() + + +class TestDefaultStrategyExecutor: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_db", + [ + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_execute_returns_valid_database_with_failover_attempts( + self, mock_db, mock_fs + ): + failover_attempts = 3 + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + mock_db, + ] + executor = DefaultFailoverStrategyExecutor( + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 + ) + + for i in range(failover_attempts + 1): + try: + database = await executor.execute() + assert database == mock_db + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ) + await asyncio.sleep(0.11) + pass + + assert mock_fs.database.call_count == 4 + + @pytest.mark.asyncio + async def test_execute_throws_exception_on_attempts_exceed(self, mock_fs): + failover_attempts = 3 + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + ] + executor = DefaultFailoverStrategyExecutor( + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 + ) + + with pytest.raises(NoValidDatabaseException): + for i in range(failover_attempts + 1): + try: + await executor.execute() + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ) + await asyncio.sleep(0.11) + pass + + assert mock_fs.database.call_count == 4 + + @pytest.mark.asyncio + async def test_execute_throws_exception_on_attempts_does_not_exceed_delay( + self, mock_fs + ): + failover_attempts = 3 + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + ] + executor = DefaultFailoverStrategyExecutor( + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 + ) + + with pytest.raises( + TemporaryUnavailableException, + match=( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ), + ): + for i in range(failover_attempts + 1): + try: + await executor.execute() + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." 
+ ) + if i == failover_attempts: + raise e + + assert mock_fs.database.call_count == 4 diff --git a/tests/test_asyncio/test_multidb/test_failure_detector.py b/tests/test_asyncio/test_multidb/test_failure_detector.py new file mode 100644 index 0000000000..279bda9605 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_failure_detector.py @@ -0,0 +1,128 @@ +import asyncio +from unittest.mock import Mock + +import pytest + +from redis.asyncio.multidb.command_executor import AsyncCommandExecutor +from redis.asyncio.multidb.database import Database +from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper +from redis.multidb.circuit import State as CBState +from redis.multidb.failure_detector import CommandFailureDetector + + +@pytest.mark.onlynoncluster +class TestFailureDetectorAsyncWrapper: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "min_num_failures,failure_rate_threshold,circuit_state", + [ + (2, 0.4, CBState.OPEN), + (2, 0, CBState.OPEN), + (0, 0.4, CBState.OPEN), + (3, 0.4, CBState.CLOSED), + (2, 0.41, CBState.CLOSED), + ], + ids=[ + "exceeds min num failures AND failures rate", + "exceeds min num failures AND failures rate == 0", + "min num failures == 0 AND exceeds failures rate", + "do not exceeds min num failures", + "do not exceeds failures rate", + ], + ) + async def test_failure_detector_correctly_reacts_to_failures( + self, min_num_failures, failure_rate_threshold, circuit_state + ): + fd = FailureDetectorAsyncWrapper( + CommandFailureDetector(min_num_failures, failure_rate_threshold) + ) + mock_db = Mock(spec=Database) + mock_db.circuit.state = CBState.CLOSED + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + + await fd.register_command_execution(("GET", "key")) + await fd.register_command_execution(("GET", "key")) + await fd.register_failure(Exception(), ("GET", "key")) + + await fd.register_command_execution(("GET", "key")) + await fd.register_command_execution(("GET", "key")) + await fd.register_command_execution(("GET", "key")) + await fd.register_failure(Exception(), ("GET", "key")) + + assert mock_db.circuit.state == circuit_state + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "min_num_failures,failure_rate_threshold", + [ + (3, 0.0), + (3, 0.6), + ], + ids=[ + "do not exceeds min num failures, during interval", + "do not exceeds min num failures AND failure rate, during interval", + ], + ) + async def test_failure_detector_do_not_open_circuit_on_interval_exceed( + self, min_num_failures, failure_rate_threshold + ): + fd = FailureDetectorAsyncWrapper( + CommandFailureDetector(min_num_failures, failure_rate_threshold, 0.3) + ) + mock_db = Mock(spec=Database) + mock_db.circuit.state = CBState.CLOSED + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_command_execution(("GET", "key")) + await fd.register_failure(Exception(), ("GET", "key")) + await asyncio.sleep(0.16) + await fd.register_command_execution(("GET", "key")) + await fd.register_command_execution(("GET", "key")) + await fd.register_command_execution(("GET", "key")) + await fd.register_failure(Exception(), ("GET", "key")) + await asyncio.sleep(0.16) + await fd.register_command_execution(("GET", "key")) + await fd.register_failure(Exception(), ("GET", "key")) + + assert mock_db.circuit.state == CBState.CLOSED + + # 2 more failure as last one already refreshed timer + await 
fd.register_command_execution(("GET", "key")) + await fd.register_failure(Exception(), ("GET", "key")) + await fd.register_command_execution(("GET", "key")) + await fd.register_failure(Exception(), ("GET", "key")) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + async def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed( + self, + ): + fd = FailureDetectorAsyncWrapper( + CommandFailureDetector(5, 1, error_types=[ConnectionError]) + ) + mock_db = Mock(spec=Database) + mock_db.circuit.state = CBState.CLOSED + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ("SET", "key1", "value1")) + await fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + await fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + await fd.register_failure(Exception(), ("SET", "key1", "value1")) + await fd.register_failure(Exception(), ("SET", "key1", "value1")) + + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + await fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + await fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + + assert mock_db.circuit.state == CBState.OPEN diff --git a/tests/test_asyncio/test_multidb/test_healthcheck.py b/tests/test_asyncio/test_multidb/test_healthcheck.py new file mode 100644 index 0000000000..3e7ac42cd9 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_healthcheck.py @@ -0,0 +1,401 @@ +import pytest +from mock.mock import AsyncMock, Mock + +from redis.asyncio.multidb.database import Database +from redis.asyncio.multidb.healthcheck import ( + EchoHealthCheck, + LagAwareHealthCheck, + HealthCheck, + HealthyAllPolicy, + HealthyMajorityPolicy, + HealthyAnyPolicy, +) +from redis.http.http_client import HttpError +from redis.multidb.circuit import State as CBState +from redis.exceptions import ConnectionError +from redis.multidb.exception import UnhealthyDatabaseException + + +@pytest.mark.onlynoncluster +class TestHealthyAllPolicy: + @pytest.mark.asyncio + async def test_policy_returns_true_for_all_successful_probes(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.return_value = True + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + assert await policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 3 + + @pytest.mark.asyncio + async def test_policy_returns_false_on_first_failed_probe(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [True, True, False] + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + assert not await policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 0 + + @pytest.mark.asyncio + async def test_policy_raise_unhealthy_database_exception(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [True, True, ConnectionError] + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + with 
pytest.raises(UnhealthyDatabaseException, match="Unhealthy database"): + await policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 0 + + +@pytest.mark.onlynoncluster +class TestHealthyMajorityPolicy: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "probes,hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result", + [ + (3, [True, False, False], [True, True, True], 3, 0, False), + (3, [True, True, True], [True, False, False], 3, 3, False), + (3, [True, False, True], [True, True, True], 3, 3, True), + (3, [True, True, True], [True, False, True], 3, 3, True), + (3, [True, True, False], [True, False, True], 3, 3, True), + (4, [True, True, False, False], [True, True, True, True], 4, 0, False), + (4, [True, True, True, True], [True, True, False, False], 4, 4, False), + (4, [False, True, True, True], [True, True, True, True], 4, 4, True), + (4, [True, True, True, True], [True, False, True, True], 4, 4, True), + (4, [False, True, True, True], [True, True, False, True], 4, 4, True), + ], + ids=[ + "HC1 - no majority - odd", + "HC2 - no majority - odd", + "HC1 - majority- odd", + "HC2 - majority - odd", + "HC1 + HC2 - majority - odd", + "HC1 - no majority - even", + "HC2 - no majority - even", + "HC1 - majority - even", + "HC2 - majority - even", + "HC1 + HC2 - majority - even", + ], + ) + async def test_policy_returns_true_for_majority_successful_probes( + self, + probes, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count, + expected_result, + ): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = hc1_side_effect + mock_hc2.check_health.side_effect = hc2_side_effect + mock_db = Mock(spec=Database) + + policy = HealthyMajorityPolicy(probes, 0.01) + assert await policy.execute([mock_hc1, mock_hc2], mock_db) == expected_result + assert mock_hc1.check_health.call_count == hc1_call_count + assert mock_hc2.check_health.call_count == hc2_call_count + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "probes,hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count", + [ + (3, [True, ConnectionError, ConnectionError], [True, True, True], 3, 0), + (3, [True, True, True], [True, ConnectionError, ConnectionError], 3, 3), + ( + 4, + [True, ConnectionError, ConnectionError, True], + [True, True, True, True], + 3, + 0, + ), + ( + 4, + [True, True, True, True], + [True, ConnectionError, ConnectionError, False], + 4, + 3, + ), + ], + ids=[ + "HC1 - majority- odd", + "HC2 - majority - odd", + "HC1 - majority - even", + "HC2 - majority - even", + ], + ) + async def test_policy_raise_unhealthy_database_exception_on_majority_probes_exceptions( + self, probes, hc1_side_effect, hc2_side_effect, hc1_call_count, hc2_call_count + ): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = hc1_side_effect + mock_hc2.check_health.side_effect = hc2_side_effect + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + with pytest.raises(UnhealthyDatabaseException, match="Unhealthy database"): + await policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == hc1_call_count + assert mock_hc2.check_health.call_count == hc2_call_count + + +@pytest.mark.onlynoncluster +class TestHealthyAnyPolicy: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result", + [ + 
([False, False, False], [True, True, True], 3, 0, False), + ([False, False, True], [False, False, False], 3, 3, False), + ([False, True, True], [False, False, True], 2, 3, True), + ([True, True, True], [False, True, False], 1, 2, True), + ], + ids=[ + "HC1 - no successful", + "HC2 - no successful", + "HC1 - successful", + "HC2 - successful", + ], + ) + async def test_policy_returns_true_for_any_successful_probe( + self, + hc1_side_effect, + hc2_side_effect, + hc1_call_count, + hc2_call_count, + expected_result, + ): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = hc1_side_effect + mock_hc2.check_health.side_effect = hc2_side_effect + mock_db = Mock(spec=Database) + + policy = HealthyAnyPolicy(3, 0.01) + assert await policy.execute([mock_hc1, mock_hc2], mock_db) == expected_result + assert mock_hc1.check_health.call_count == hc1_call_count + assert mock_hc2.check_health.call_count == hc2_call_count + + @pytest.mark.asyncio + async def test_policy_raise_unhealthy_database_exception_if_exception_occurs_on_failed_health_check( + self, + ): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [False, False, ConnectionError] + mock_hc2.check_health.side_effect = [True, True, True] + mock_db = Mock(spec=Database) + + policy = HealthyAnyPolicy(3, 0.01) + with pytest.raises(UnhealthyDatabaseException, match="Unhealthy database"): + await policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 0 + + +@pytest.mark.onlynoncluster +class TestEchoHealthCheck: + @pytest.mark.asyncio + async def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): + """ + A single successful ECHO reply marks the database as healthy. + """ + mock_client.execute_command = AsyncMock(side_effect=["healthcheck"]) + hc = EchoHealthCheck() + db = Database(mock_client, mock_cb, 0.9) + + assert await hc.check_health(db) + assert mock_client.execute_command.call_count == 1 + + @pytest.mark.asyncio + async def test_database_is_unhealthy_on_incorrect_echo_response( + self, mock_client, mock_cb + ): + """ + An unexpected ECHO reply marks the database as unhealthy. + """ + mock_client.execute_command = AsyncMock(side_effect=["wrong"]) + hc = EchoHealthCheck() + db = Database(mock_client, mock_cb, 0.9) + + assert not await hc.check_health(db) + assert mock_client.execute_command.call_count == 1 + + @pytest.mark.asyncio + async def test_database_close_circuit_on_successful_healthcheck( + self, mock_client, mock_cb + ): + mock_client.execute_command = AsyncMock(side_effect=["healthcheck"]) + mock_cb.state = CBState.HALF_OPEN + hc = EchoHealthCheck() + db = Database(mock_client, mock_cb, 0.9) + + assert await hc.check_health(db) + assert mock_client.execute_command.call_count == 1 + + +@pytest.mark.onlynoncluster +class TestLagAwareHealthCheck: + @pytest.mark.asyncio + async def test_database_is_healthy_when_bdb_matches_by_dns_name( + self, mock_client, mock_cb + ): + """ + Ensures the health check succeeds when /v1/bdbs contains an endpoint whose dns_name + matches the database host and the availability endpoint reports success. 
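+ The second request is expected to target /v1/bdbs/<uid>/availability with extend_check=lag and the configured tolerance in milliseconds.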
+ """ + host = "db1.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + # Mock HttpClient used inside LagAwareHealthCheck + mock_http = AsyncMock() + mock_http.get.side_effect = [ + # First call: list of bdbs + [ + { + "uid": "bdb-1", + "endpoints": [ + {"dns_name": host, "addr": ["10.0.0.1", "10.0.0.2"]}, + ], + } + ], + # Second call: availability check (no JSON expected) + None, + ] + + hc = LagAwareHealthCheck(rest_api_port=1234, lag_aware_tolerance=150) + # Inject our mocked http client + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + assert await hc.check_health(db) is True + # Base URL must be set correctly + assert hc._http_client.client.base_url == "https://healthcheck.example.com:1234" + # Calls: first to list bdbs, then to availability + assert mock_http.get.call_count == 2 + first_call = mock_http.get.call_args_list[0] + second_call = mock_http.get.call_args_list[1] + assert first_call.args[0] == "/v1/bdbs" + assert ( + second_call.args[0] + == "/v1/bdbs/bdb-1/availability?extend_check=lag&availability_lag_tolerance_ms=150" + ) + assert second_call.kwargs.get("expect_json") is False + + @pytest.mark.asyncio + async def test_database_is_healthy_when_bdb_matches_by_addr( + self, mock_client, mock_cb + ): + """ + Ensures health check succeeds when endpoint addr list contains the database host. + """ + host_ip = "203.0.113.5" + mock_client.get_connection_kwargs.return_value = {"host": host_ip} + + mock_http = AsyncMock() + mock_http.get.side_effect = [ + [ + { + "uid": "bdb-42", + "endpoints": [ + {"dns_name": "not-matching.example.com", "addr": [host_ip]}, + ], + } + ], + None, + ] + + hc = LagAwareHealthCheck() + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + assert await hc.check_health(db) is True + assert mock_http.get.call_count == 2 + assert ( + mock_http.get.call_args_list[1].args[0] + == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=5000" + ) + + @pytest.mark.asyncio + async def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): + """ + Ensures health check raises ValueError when there's no bdb matching the database host. + """ + host = "db2.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + mock_http = AsyncMock() + # Return bdbs that do not match host by dns_name nor addr + mock_http.get.return_value = [ + { + "uid": "a", + "endpoints": [{"dns_name": "other.example.com", "addr": ["10.0.0.9"]}], + }, + { + "uid": "b", + "endpoints": [ + {"dns_name": "another.example.com", "addr": ["10.0.0.10"]} + ], + }, + ] + + hc = LagAwareHealthCheck() + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + with pytest.raises(ValueError, match="Could not find a matching bdb"): + await hc.check_health(db) + + # Only the listing call should have happened + mock_http.get.assert_called_once_with("/v1/bdbs") + + @pytest.mark.asyncio + async def test_propagates_http_error_from_availability(self, mock_client, mock_cb): + """ + Ensures that any HTTP error raised by the availability endpoint is propagated. 
+ """ + host = "db3.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + mock_http = AsyncMock() + # First: list bdbs -> match by dns_name + mock_http.get.side_effect = [ + [{"uid": "bdb-err", "endpoints": [{"dns_name": host, "addr": []}]}], + # Second: availability -> raise HttpError + HttpError( + url=f"https://{host}:9443/v1/bdbs/bdb-err/availability", + status=503, + message="busy", + ), + ] + + hc = LagAwareHealthCheck() + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + with pytest.raises(HttpError, match="busy") as e: + await hc.check_health(db) + assert e.status == 503 + + # Ensure both calls were attempted + assert mock_http.get.call_count == 2 diff --git a/tests/test_asyncio/test_multidb/test_pipeline.py b/tests/test_asyncio/test_multidb/test_pipeline.py new file mode 100644 index 0000000000..48990bc62a --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_pipeline.py @@ -0,0 +1,385 @@ +import asyncio +from unittest.mock import Mock, AsyncMock, patch + +import pybreaker +import pytest + +from redis.asyncio.client import Pipeline +from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy +from redis.asyncio.multidb.healthcheck import EchoHealthCheck +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter +from tests.test_asyncio.test_multidb.conftest import create_weighted_list + + +def mock_pipe() -> Pipeline: + mock_pipe = Mock(spec=Pipeline) + mock_pipe.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_pipe.__aexit__ = AsyncMock(return_value=None) + return mock_pipe + + +class TestPipeline: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_executes_pipeline_against_correct_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + pipe = mock_pipe() + pipe.execute.return_value = ["OK1", "value1"] + mock_db1.client.pipeline.return_value = pipe + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + pipe = client.pipeline() + pipe.set("key1", "value1") + pipe.get("key1") + + assert await pipe.execute() == ["OK1", "value1"] + assert mock_hc.check_health.call_count == 9 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_pipeline_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, 
"default_health_checks", return_value=[mock_hc] + ), + ): + pipe = mock_pipe() + pipe.execute.return_value = ["OK1", "value1"] + mock_db1.client.pipeline.return_value = pipe + + mock_hc.check_health.side_effect = [ + False, + True, + True, + True, + True, + True, + True, + ] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + async with client.pipeline() as pipe: + pipe.set("key1", "value1") + pipe.get("key1") + + assert await pipe.execute() == ["OK1", "value1"] + assert mock_hc.check_health.call_count == 7 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_pipeline_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[EchoHealthCheck()], + ), + ): + mock_db.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "healthcheck", + "error", + ] + mock_db1.client.execute_command.side_effect = [ + "healthcheck", + "error", + "error", + "healthcheck", + ] + mock_db2.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "error", + "error", + ] + + pipe = mock_pipe() + pipe.execute.return_value = ["OK", "value"] + mock_db.client.pipeline.return_value = pipe + + pipe1 = mock_pipe() + pipe1.execute.return_value = ["OK1", "value"] + mock_db1.client.pipeline.return_value = pipe1 + + pipe2 = mock_pipe() + pipe2.execute.return_value = ["OK2", "value"] + mock_db2.client.pipeline.return_value = pipe2 + + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() + + client = MultiDBClient(mock_multi_db_config) + + async with client.pipeline() as pipe: + pipe.set("key1", "value") + pipe.get("key1") + + assert await pipe.execute() == ["OK1", "value"] + + await asyncio.sleep(0.15) + + async with client.pipeline() as pipe: + pipe.set("key1", "value") + pipe.get("key1") + + assert await pipe.execute() == ["OK2", "value"] + + await asyncio.sleep(0.1) + + async with client.pipeline() as pipe: + pipe.set("key1", "value") + pipe.get("key1") + + assert await pipe.execute() == ["OK", "value"] + + await asyncio.sleep(0.1) + + async with client.pipeline() as pipe: + pipe.set("key1", "value") + pipe.get("key1") + + assert await pipe.execute() == ["OK1", "value"] + + +class TestTransaction: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, 
mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_executes_transaction_against_correct_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.transaction.return_value = ["OK1", "value1"] + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + async def callback(pipe: Pipeline): + pipe.set("key1", "value1") + pipe.get("key1") + + assert await client.transaction(callback) == ["OK1", "value1"] + assert mock_hc.check_health.call_count == 9 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_transaction_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.transaction.return_value = ["OK1", "value1"] + + mock_hc.check_health.side_effect = [ + False, + True, + True, + True, + True, + True, + True, + ] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + async def callback(pipe: Pipeline): + pipe.set("key1", "value1") + pipe.get("key1") + + assert await client.transaction(callback) == ["OK1", "value1"] + assert mock_hc.check_health.call_count == 7 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_transaction_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + 
return_value=[EchoHealthCheck()], + ), + ): + mock_db.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "healthcheck", + "error", + ] + mock_db1.client.execute_command.side_effect = [ + "healthcheck", + "error", + "error", + "healthcheck", + ] + mock_db2.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "error", + "error", + ] + + mock_db.client.transaction.return_value = ["OK", "value"] + mock_db1.client.transaction.return_value = ["OK1", "value"] + mock_db2.client.transaction.return_value = ["OK2", "value"] + + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() + + client = MultiDBClient(mock_multi_db_config) + + async def callback(pipe: Pipeline): + pipe.set("key1", "value1") + pipe.get("key1") + + assert await client.transaction(callback) == ["OK1", "value"] + await asyncio.sleep(0.15) + assert await client.transaction(callback) == ["OK2", "value"] + await asyncio.sleep(0.1) + assert await client.transaction(callback) == ["OK", "value"] + await asyncio.sleep(0.1) + assert await client.transaction(callback) == ["OK1", "value"] diff --git a/tests/test_asyncio/test_scenario/__init__.py b/tests/test_asyncio/test_scenario/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py new file mode 100644 index 0000000000..803445f508 --- /dev/null +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -0,0 +1,116 @@ +import asyncio +import os +from typing import Any, AsyncGenerator + +import pytest +import pytest_asyncio + +from redis.asyncio import Redis +from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.config import ( + DatabaseConfig, + MultiDbConfig, +) +from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialBackoff +from redis.event import AsyncEventListenerInterface, EventDispatcher +from redis.multidb.failure_detector import DEFAULT_MIN_NUM_FAILURES +from tests.test_scenario.conftest import get_endpoints_config, extract_cluster_fqdn +from tests.test_scenario.fault_injector_client import FaultInjectorClient + + +class CheckActiveDatabaseChangedListener(AsyncEventListenerInterface): + def __init__(self): + self.is_changed_flag = False + + async def listen(self, event: AsyncActiveDatabaseChanged): + self.is_changed_flag = True + + +@pytest.fixture() +def fault_injector_client(): + url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") + return FaultInjectorClient(url) + + +@pytest_asyncio.fixture() +async def r_multi_db( + request, +) -> AsyncGenerator[tuple[MultiDBClient, CheckActiveDatabaseChangedListener, Any], Any]: + client_class = request.param.get("client_class", Redis) + + if client_class == Redis: + endpoint_config = get_endpoints_config("re-active-active") + else: + endpoint_config = get_endpoints_config("re-active-active-oss-cluster") + + username = endpoint_config.get("username", None) + password = endpoint_config.get("password", None) + min_num_failures = request.param.get("min_num_failures", DEFAULT_MIN_NUM_FAILURES) + command_retry = request.param.get( + "command_retry", Retry(ExponentialBackoff(cap=0.1, base=0.01), retries=10) + ) + + # Retry configuration different for health checks as initial health check require more time in case + # if infrastructure wasn't restored from the previous test. 
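+ # The 10-second default below leaves extra headroom for that slower initial check.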
+ health_check_interval = request.param.get("health_check_interval", 10) + health_checks = request.param.get("health_checks", []) + event_dispatcher = EventDispatcher() + listener = CheckActiveDatabaseChangedListener() + event_dispatcher.register_listeners( + { + AsyncActiveDatabaseChanged: [listener], + } + ) + db_configs = [] + + db_config = DatabaseConfig( + weight=1.0, + from_url=endpoint_config["endpoints"][0], + client_kwargs={ + "username": username, + "password": password, + "decode_responses": True, + }, + health_check_url=extract_cluster_fqdn(endpoint_config["endpoints"][0]), + ) + db_configs.append(db_config) + + db_config1 = DatabaseConfig( + weight=0.9, + from_url=endpoint_config["endpoints"][1], + client_kwargs={ + "username": username, + "password": password, + "decode_responses": True, + }, + health_check_url=extract_cluster_fqdn(endpoint_config["endpoints"][1]), + ) + db_configs.append(db_config1) + + config = MultiDbConfig( + client_class=client_class, + databases_config=db_configs, + command_retry=command_retry, + min_num_failures=min_num_failures, + health_checks=health_checks, + health_check_probes=3, + health_check_interval=health_check_interval, + event_dispatcher=event_dispatcher, + ) + + client = MultiDBClient(config) + + async def teardown(): + await client.aclose() + + if client.command_executor.active_database and isinstance( + client.command_executor.active_database.client, Redis + ): + await client.command_executor.active_database.client.connection_pool.disconnect() + + await asyncio.sleep(10) + + yield client, listener, endpoint_config + await teardown() diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py new file mode 100644 index 0000000000..c33e482050 --- /dev/null +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -0,0 +1,420 @@ +import asyncio +import json +import logging +import os + +import pytest + +from redis.asyncio import RedisCluster +from redis.asyncio.client import Pipeline, Redis +from redis.asyncio.multidb.failover import ( + DEFAULT_FAILOVER_ATTEMPTS, + DEFAULT_FAILOVER_DELAY, +) +from redis.asyncio.multidb.healthcheck import LagAwareHealthCheck +from redis.asyncio.retry import Retry +from redis.backoff import ConstantBackoff +from redis.multidb.exception import TemporaryUnavailableException +from redis.utils import dummy_fail_async +from tests.test_scenario.fault_injector_client import ActionRequest, ActionType + +logger = logging.getLogger(__name__) + + +async def trigger_network_failure_action( + fault_injector_client, config, event: asyncio.Event = None +): + action_request = ActionRequest( + action_type=ActionType.NETWORK_FAILURE, + parameters={"bdb_id": config["bdb_id"], "delay": 3, "cluster_index": 0}, + ) + + result = fault_injector_client.trigger_action(action_request) + status_result = fault_injector_client.get_action_status(result["action_id"]) + + while status_result["status"] != "success": + await asyncio.sleep(0.1) + status_result = fault_injector_client.get_action_status(result["action_id"]) + logger.info( + f"Waiting for action to complete. Status: {status_result['status']}" + ) + + if event: + event.set() + + logger.info(f"Action completed. 
Status: {status_result['status']}") + + +class TestActiveActive: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, + ], + ids=["standalone", "cluster"], + indirect=True, + ) + @pytest.mark.timeout(200) + async def test_multi_db_client_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): + client, listener, endpoint_config = r_multi_db + + # Handle unavailable databases from previous test. + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), + ) + + async with client as r_multi_db: + event = asyncio.Event() + asyncio.create_task( + trigger_network_failure_action( + fault_injector_client, endpoint_config, event + ) + ) + + await retry.call_with_retry( + lambda: r_multi_db.set("key", "value"), lambda _: dummy_fail_async() + ) + + # Execute commands before network failure + while not event.is_set(): + assert ( + await retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail_async() + ) + == "value" + ) + await asyncio.sleep(0.5) + + # Execute commands until database failover + while not listener.is_changed_flag: + assert ( + await retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail_async() + ) + == "value" + ) + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [ + { + "client_class": Redis, + "min_num_failures": 2, + "health_checks": [ + LagAwareHealthCheck( + verify_tls=False, + auth_basic=( + os.getenv("ENV0_USERNAME"), + os.getenv("ENV0_PASSWORD"), + ), + ) + ], + "health_check_interval": 20, + }, + { + "client_class": RedisCluster, + "min_num_failures": 2, + "health_checks": [ + LagAwareHealthCheck( + verify_tls=False, + auth_basic=( + os.getenv("ENV0_USERNAME"), + os.getenv("ENV0_PASSWORD"), + ), + ) + ], + "health_check_interval": 20, + }, + ], + ids=["standalone", "cluster"], + indirect=True, + ) + @pytest.mark.timeout(200) + async def test_multi_db_client_uses_lag_aware_health_check( + self, r_multi_db, fault_injector_client + ): + client, listener, endpoint_config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), + ) + + async with client as r_multi_db: + event = asyncio.Event() + asyncio.create_task( + trigger_network_failure_action( + fault_injector_client, endpoint_config, event + ) + ) + + await retry.call_with_retry( + lambda: r_multi_db.set("key", "value"), lambda _: dummy_fail_async() + ) + + # Execute commands before network failure + while not event.is_set(): + assert ( + await retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail_async() + ) + == "value" + ) + await asyncio.sleep(0.5) + + # Execute commands after network failure + while not listener.is_changed_flag: + assert ( + await retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail_async() + ) + == "value" + ) + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, + ], + ids=["standalone", "cluster"], + indirect=True, + ) + @pytest.mark.timeout(200) + async def test_context_manager_pipeline_failover_to_another_db( + self, r_multi_db, 
fault_injector_client + ): + client, listener, endpoint_config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), + ) + + async def callback(): + async with r_multi_db.pipeline() as pipe: + pipe.set("{hash}key1", "value1") + pipe.set("{hash}key2", "value2") + pipe.set("{hash}key3", "value3") + pipe.get("{hash}key1") + pipe.get("{hash}key2") + pipe.get("{hash}key3") + assert await pipe.execute() == [ + True, + True, + True, + "value1", + "value2", + "value3", + ] + + async with client as r_multi_db: + event = asyncio.Event() + asyncio.create_task( + trigger_network_failure_action( + fault_injector_client, endpoint_config, event + ) + ) + + # Execute pipeline before network failure + while not event.is_set(): + await retry.call_with_retry( + lambda: callback(), lambda _: dummy_fail_async() + ) + await asyncio.sleep(0.5) + + # Execute commands until database failover + while not listener.is_changed_flag: + await retry.call_with_retry( + lambda: callback(), lambda _: dummy_fail_async() + ) + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, + ], + ids=["standalone", "cluster"], + indirect=True, + ) + @pytest.mark.timeout(200) + async def test_chaining_pipeline_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): + client, listener, endpoint_config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), + ) + + async def callback(): + pipe = r_multi_db.pipeline() + pipe.set("{hash}key1", "value1") + pipe.set("{hash}key2", "value2") + pipe.set("{hash}key3", "value3") + pipe.get("{hash}key1") + pipe.get("{hash}key2") + pipe.get("{hash}key3") + assert await pipe.execute() == [ + True, + True, + True, + "value1", + "value2", + "value3", + ] + + async with client as r_multi_db: + event = asyncio.Event() + asyncio.create_task( + trigger_network_failure_action( + fault_injector_client, endpoint_config, event + ) + ) + + # Execute pipeline before network failure + while not event.is_set(): + await retry.call_with_retry( + lambda: callback(), lambda _: dummy_fail_async() + ) + await asyncio.sleep(0.5) + + # Execute pipeline until database failover + while not listener.is_changed_flag: + await retry.call_with_retry( + lambda: callback(), lambda _: dummy_fail_async() + ) + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, + ], + ids=["standalone", "cluster"], + indirect=True, + ) + @pytest.mark.timeout(200) + async def test_transaction_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): + client, listener, endpoint_config = r_multi_db + + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), + ) + + async def callback(pipe: Pipeline): + pipe.set("{hash}key1", "value1") + pipe.set("{hash}key2", "value2") + pipe.set("{hash}key3", "value3") + pipe.get("{hash}key1") + pipe.get("{hash}key2") + pipe.get("{hash}key3") + + async with client as r_multi_db: + event = asyncio.Event() + asyncio.create_task( + 
trigger_network_failure_action( + fault_injector_client, endpoint_config, event + ) + ) + + # Execute transaction before network failure + while not event.is_set(): + await retry.call_with_retry( + lambda: r_multi_db.transaction(callback), + lambda _: dummy_fail_async(), + ) + await asyncio.sleep(0.5) + + # Execute transaction until database failover + while not listener.is_changed_flag: + assert await retry.call_with_retry( + lambda: r_multi_db.transaction(callback), + lambda _: dummy_fail_async(), + ) == [True, True, True, "value1", "value2", "value3"] + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize("r_multi_db", [{"min_num_failures": 2}], indirect=True) + @pytest.mark.timeout(200) + async def test_pubsub_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): + client, listener, endpoint_config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), + ) + + data = json.dumps({"message": "test"}) + messages_count = 0 + + async def handler(message): + nonlocal messages_count + messages_count += 1 + + async with client as r_multi_db: + event = asyncio.Event() + asyncio.create_task( + trigger_network_failure_action( + fault_injector_client, endpoint_config, event + ) + ) + + pubsub = await r_multi_db.pubsub() + + # Assign a handler and run the listener in a background task. + await retry.call_with_retry( + lambda: pubsub.subscribe(**{"test-channel": handler}), + lambda _: dummy_fail_async(), + ) + task = asyncio.create_task(pubsub.run(poll_timeout=0.1)) + + # Execute publish before network failure + while not event.is_set(): + await retry.call_with_retry( + lambda: r_multi_db.publish("test-channel", data), + lambda _: dummy_fail_async(), + ) + await asyncio.sleep(0.5) + + # Execute publish until database failover + while not listener.is_changed_flag: + await retry.call_with_retry( + lambda: r_multi_db.publish("test-channel", data), + lambda _: dummy_fail_async(), + ) + await asyncio.sleep(0.5) + + # After the database has changed, keep generating some traffic. + for _ in range(5): + await retry.call_with_retry( + lambda: r_multi_db.publish("test-channel", data), + lambda _: dummy_fail_async(), + ) + + # A short pause to ensure that the async handler has processed all previous messages.
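+ # Messages were published before, during, and after the failover, so the handler
+ # should have recorded at least a couple of them by the time the task is cancelled.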
+ await asyncio.sleep(0.1) + task.cancel() + assert messages_count >= 2 diff --git a/tests/test_background.py b/tests/test_background.py new file mode 100644 index 0000000000..bac9c1eef6 --- /dev/null +++ b/tests/test_background.py @@ -0,0 +1,94 @@ +import asyncio +from time import sleep + +import pytest + +from redis.background import BackgroundScheduler + + +class TestBackgroundScheduler: + def test_run_once(self): + execute_counter = 0 + one = "arg1" + two = 9999 + + def callback(arg1: str, arg2: int): + nonlocal execute_counter + nonlocal one + nonlocal two + + execute_counter += 1 + + assert arg1 == one + assert arg2 == two + + scheduler = BackgroundScheduler() + scheduler.run_once(0.1, callback, one, two) + assert execute_counter == 0 + + sleep(0.15) + + assert execute_counter == 1 + + @pytest.mark.parametrize( + "interval,timeout,call_count", + [ + (0.012, 0.04, 3), + (0.035, 0.04, 1), + (0.045, 0.04, 0), + ], + ) + def test_run_recurring(self, interval, timeout, call_count): + execute_counter = 0 + one = "arg1" + two = 9999 + + def callback(arg1: str, arg2: int): + nonlocal execute_counter + nonlocal one + nonlocal two + + execute_counter += 1 + + assert arg1 == one + assert arg2 == two + + scheduler = BackgroundScheduler() + scheduler.run_recurring(interval, callback, one, two) + assert execute_counter == 0 + + sleep(timeout) + + assert execute_counter == call_count + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "interval,timeout,call_count", + [ + (0.012, 0.04, 3), + (0.035, 0.04, 1), + (0.045, 0.04, 0), + ], + ) + async def test_run_recurring_async(self, interval, timeout, call_count): + execute_counter = 0 + one = "arg1" + two = 9999 + + async def callback(arg1: str, arg2: int): + nonlocal execute_counter + nonlocal one + nonlocal two + + execute_counter += 1 + + assert arg1 == one + assert arg2 == two + + scheduler = BackgroundScheduler() + await scheduler.run_recurring_async(interval, callback, one, two) + assert execute_counter == 0 + + await asyncio.sleep(timeout) + + assert execute_counter == call_count diff --git a/tests/test_data_structure.py b/tests/test_data_structure.py new file mode 100644 index 0000000000..0911466e58 --- /dev/null +++ b/tests/test_data_structure.py @@ -0,0 +1,94 @@ +import concurrent +import random +from concurrent.futures import ThreadPoolExecutor +from time import sleep + +from redis.data_structure import WeightedList + + +class TestWeightedList: + def test_add_items(self): + wlist = WeightedList() + + wlist.add("item1", 3.0) + wlist.add("item2", 2.0) + wlist.add("item3", 4.0) + wlist.add("item4", 4.0) + + assert wlist.get_top_n(4) == [ + ("item3", 4.0), + ("item4", 4.0), + ("item1", 3.0), + ("item2", 2.0), + ] + + def test_remove_items(self): + wlist = WeightedList() + wlist.add("item1", 3.0) + wlist.add("item2", 2.0) + wlist.add("item3", 4.0) + wlist.add("item4", 4.0) + + assert wlist.remove("item2") == 2.0 + assert wlist.remove("item4") == 4.0 + + assert wlist.get_top_n(4) == [("item3", 4.0), ("item1", 3.0)] + + def test_get_by_weight_range(self): + wlist = WeightedList() + wlist.add("item1", 3.0) + wlist.add("item2", 2.0) + wlist.add("item3", 4.0) + wlist.add("item4", 4.0) + + assert wlist.get_by_weight_range(2.0, 3.0) == [("item1", 3.0), ("item2", 2.0)] + + def test_update_weights(self): + wlist = WeightedList() + wlist.add("item1", 3.0) + wlist.add("item2", 2.0) + wlist.add("item3", 4.0) + wlist.add("item4", 4.0) + + assert wlist.get_top_n(4) == [ + ("item3", 4.0), + ("item4", 4.0), + ("item1", 3.0), + ("item2", 2.0), + ] + + 
wlist.update_weight("item2", 5.0) + + assert wlist.get_top_n(4) == [ + ("item2", 5.0), + ("item3", 4.0), + ("item4", 4.0), + ("item1", 3.0), + ] + + def test_thread_safety(self) -> None: + """Test thread safety with concurrent operations""" + wl = WeightedList() + + def worker(worker_id): + for i in range(100): + # Add items + wl.add(f"item_{worker_id}_{i}", random.uniform(0, 100)) + + # Read operations + try: + length = len(wl) + if length > 0: + wl.get_top_n(min(5, length)) + wl.get_by_weight_range(20, 80) + except Exception as e: + print(f"Error in worker {worker_id}: {e}") + + sleep(0.001) # Small delay + + # Run multiple workers concurrently + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + concurrent.futures.wait(futures) + + assert len(wl) == 500 diff --git a/tests/test_event.py b/tests/test_event.py new file mode 100644 index 0000000000..0caab04e78 --- /dev/null +++ b/tests/test_event.py @@ -0,0 +1,67 @@ +from unittest.mock import Mock, AsyncMock + +from redis.event import ( + EventListenerInterface, + EventDispatcher, + AsyncEventListenerInterface, +) + + +class TestEventDispatcher: + def test_register_listeners(self): + mock_event = Mock(spec=object) + mock_event_listener = Mock(spec=EventListenerInterface) + listener_called = 0 + + def callback(event): + nonlocal listener_called + listener_called += 1 + + mock_event_listener.listen = callback + + # Register via constructor + dispatcher = EventDispatcher( + event_listeners={type(mock_event): [mock_event_listener]} + ) + dispatcher.dispatch(mock_event) + + assert listener_called == 1 + + # Register additional listener for the same event + mock_another_event_listener = Mock(spec=EventListenerInterface) + mock_another_event_listener.listen = callback + dispatcher.register_listeners( + mappings={type(mock_event): [mock_another_event_listener]} + ) + dispatcher.dispatch(mock_event) + + assert listener_called == 3 + + async def test_register_listeners_async(self): + mock_event = Mock(spec=object) + mock_event_listener = AsyncMock(spec=AsyncEventListenerInterface) + listener_called = 0 + + async def callback(event): + nonlocal listener_called + listener_called += 1 + + mock_event_listener.listen = callback + + # Register via constructor + dispatcher = EventDispatcher( + event_listeners={type(mock_event): [mock_event_listener]} + ) + await dispatcher.dispatch_async(mock_event) + + assert listener_called == 1 + + # Register additional listener for the same event + mock_another_event_listener = Mock(spec=AsyncEventListenerInterface) + mock_another_event_listener.listen = callback + dispatcher.register_listeners( + mappings={type(mock_event): [mock_another_event_listener]} + ) + await dispatcher.dispatch_async(mock_event) + + assert listener_called == 3 diff --git a/tests/test_http/__init__.py b/tests/test_http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_http/test_http_client.py b/tests/test_http/test_http_client.py new file mode 100644 index 0000000000..5dc1cf1631 --- /dev/null +++ b/tests/test_http/test_http_client.py @@ -0,0 +1,371 @@ +import json +import gzip +from io import BytesIO +from typing import Any, Dict +from urllib.error import HTTPError +from urllib.parse import urlparse, parse_qs + +import pytest + +from redis.backoff import ExponentialWithJitterBackoff +from redis.http.http_client import HttpClient, HttpError +from redis.retry import Retry + + +class FakeResponse: + def __init__( + self, *, status: int, headers: 
Dict[str, str], url: str, content: bytes + ): + self.status = status + self.headers = headers + self._url = url + self._content = content + + def read(self) -> bytes: + return self._content + + def geturl(self) -> str: + return self._url + + # Support context manager used by urlopen + def __enter__(self) -> "FakeResponse": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + +class TestHttpClient: + def test_get_returns_parsed_json_and_uses_timeout( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/items" + params = {"limit": 5, "q": "hello world"} + expected_url = f"{base_url}{path}?limit=5&q=hello+world" + payload: Dict[str, Any] = {"items": [1, 2, 3], "ok": True} + content = json.dumps(payload).encode("utf-8") + + captured_kwargs = {} + + def fake_urlopen(request, *, timeout=None, context=None): + # Capture call details for assertions + captured_kwargs["timeout"] = timeout + captured_kwargs["context"] = context + # Assert the request was constructed correctly + assert getattr(request, "method", "").upper() == "GET" + assert request.full_url == expected_url + # Return a successful response + return FakeResponse( + status=200, + headers={"Content-Type": "application/json; charset=utf-8"}, + url=expected_url, + content=content, + ) + + # Patch the urlopen used inside HttpClient + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + + # Act + result = client.get( + path, params=params, timeout=12.34 + ) # default expect_json=True + + # Assert + assert result == payload + assert pytest.approx(captured_kwargs["timeout"], rel=1e-6) == 12.34 + # HTTPS -> a context should be provided (created by ssl.create_default_context) + assert captured_kwargs["context"] is not None + + def test_get_handles_gzip_response(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "gzip-endpoint" + expected_url = f"{base_url}{path}" + payload = {"message": "compressed ok"} + raw = json.dumps(payload).encode("utf-8") + gzipped = gzip.compress(raw) + + def fake_urlopen(request, *, timeout=None, context=None): + # Return gzipped content with appropriate header + return FakeResponse( + status=200, + headers={ + "Content-Type": "application/json; charset=utf-8", + "Content-Encoding": "gzip", + }, + url=expected_url, + content=gzipped, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + + # Act + result = client.get(path) # expect_json=True by default + + # Assert + assert result == payload + + def test_get_retries_on_retryable_http_errors_and_succeeds( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Arrange: configure limited retries so we can assert attempts + retry_policy = Retry( + backoff=ExponentialWithJitterBackoff(base=0, cap=0), retries=2 + ) # 2 retries -> up to 3 attempts + base_url = "https://api.example.com/" + path = "sometimes-busy" + expected_url = f"{base_url}{path}" + payload = {"ok": True} + success_content = json.dumps(payload).encode("utf-8") + + call_count = {"n": 0} + + def make_http_error(url: str, code: int, body: bytes = b"busy"): + # Provide a file-like object for .read() when HttpClient tries to read error content + fp = BytesIO(body) + return HTTPError( + url=url, + code=code, + msg="Service Unavailable", + hdrs={"Content-Type": "text/plain"}, + fp=fp, + ) + + def flaky_urlopen(request, *, timeout=None, 
context=None): + call_count["n"] += 1 + # Fail with a retryable status (503) for the first two calls, then succeed + if call_count["n"] <= 2: + raise make_http_error(expected_url, 503) + return FakeResponse( + status=200, + headers={"Content-Type": "application/json; charset=utf-8"}, + url=expected_url, + content=success_content, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", flaky_urlopen) + + client = HttpClient(base_url=base_url, retry=retry_policy) + + # Act + result = client.get(path) + + # Assert: should have retried twice (total 3 attempts) and finally returned parsed JSON + assert result == payload + assert call_count["n"] == retry_policy.get_retries() + 1 + + def test_post_sends_json_body_and_parses_response( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/create" + expected_url = f"{base_url}{path}" + send_payload = {"a": 1, "b": "x"} + recv_payload = {"id": 10, "ok": True} + recv_content = json.dumps(recv_payload, separators=(",", ":")).encode("utf-8") + + def fake_urlopen(request, *, timeout=None, context=None): + # Verify method, URL and headers + assert getattr(request, "method", "").upper() == "POST" + assert request.full_url == expected_url + # Content-Type should be auto-set for string JSON body + assert ( + request.headers.get("Content-type") == "application/json; charset=utf-8" + ) + # Body should be already UTF-8 encoded JSON with no spaces + assert request.data == json.dumps( + send_payload, ensure_ascii=False, separators=(",", ":") + ).encode("utf-8") + return FakeResponse( + status=200, + headers={"Content-Type": "application/json; charset=utf-8"}, + url=expected_url, + content=recv_content, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + + # Act + result = client.post(path, json_body=send_payload) + + # Assert + assert result == recv_payload + + def test_post_with_raw_data_and_custom_headers( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "upload" + expected_url = f"{base_url}{path}" + raw_data = b"\x00\x01BINARY" + custom_headers = {"Content-type": "application/octet-stream", "X-extra": "1"} + recv_payload = {"status": "ok"} + + def fake_urlopen(request, *, timeout=None, context=None): + assert getattr(request, "method", "").upper() == "POST" + assert request.full_url == expected_url + # Ensure our provided headers are present + assert request.headers.get("Content-type") == "application/octet-stream" + assert request.headers.get("X-extra") == "1" + assert request.data == raw_data + return FakeResponse( + status=200, + headers={"Content-Type": "application/json"}, + url=expected_url, + content=json.dumps(recv_payload).encode("utf-8"), + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + # Act + result = client.post(path, data=raw_data, headers=custom_headers) + + # Assert + assert result == recv_payload + + def test_delete_returns_http_response_when_expect_json_false( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/resource/42" + expected_url = f"{base_url}{path}" + body = b"deleted" + + def fake_urlopen(request, *, timeout=None, context=None): + assert getattr(request, "method", "").upper() == "DELETE" + assert request.full_url == expected_url + return FakeResponse( + status=204, + headers={"Content-Type": 
"text/plain"}, + url=expected_url, + content=body, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + client = HttpClient(base_url=base_url) + + # Act + resp = client.delete(path, expect_json=False) + + # Assert + assert resp.status == 204 + assert resp.url == expected_url + assert resp.content == body + + def test_put_raises_http_error_on_non_success( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/update/1" + expected_url = f"{base_url}{path}" + + def make_http_error(url: str, code: int, body: bytes = b"not found"): + fp = BytesIO(body) + return HTTPError( + url=url, + code=code, + msg="Not Found", + hdrs={"Content-Type": "text/plain"}, + fp=fp, + ) + + def fake_urlopen(request, *, timeout=None, context=None): + raise make_http_error(expected_url, 404) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + client = HttpClient(base_url=base_url) + + # Act / Assert + with pytest.raises(HttpError) as exc: + client.put(path, json_body={"x": 1}) + assert exc.value.status == 404 + assert exc.value.url == expected_url + + def test_patch_with_params_encodes_query( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/edit" + params = {"tag": ["a", "b"], "q": "hello world"} + + captured_url = {"u": None} + + def fake_urlopen(request, *, timeout=None, context=None): + captured_url["u"] = request.full_url + return FakeResponse( + status=200, + headers={"Content-Type": "application/json"}, + url=request.full_url, + content=json.dumps({"ok": True}).encode("utf-8"), + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + client.patch(path, params=params) # We don't care about response here + + # Assert query parameters regardless of ordering + parsed = urlparse(captured_url["u"]) + qs = parse_qs(parsed.query) + assert qs["q"] == ["hello world"] + assert qs["tag"] == ["a", "b"] + + def test_request_low_level_headers_auth_and_timeout_default( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + # Arrange: use plain HTTP to verify no TLS context, and check default timeout used + base_url = "http://example.com/" + path = "ping" + captured = { + "timeout": None, + "context": "unset", + "headers": None, + "method": None, + } + + def fake_urlopen(request, *, timeout=None, context=None): + captured["timeout"] = timeout + captured["context"] = context + captured["headers"] = dict(request.headers) + captured["method"] = getattr(request, "method", "").upper() + return FakeResponse( + status=200, + headers={"Content-Type": "application/json"}, + url=request.full_url, + content=json.dumps({"pong": True}).encode("utf-8"), + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url, auth_basic=("user", "pass")) + resp = client.request("GET", path) + + # Assert + assert resp.status == 200 + assert captured["method"] == "GET" + assert captured["context"] is None # no TLS for http + assert ( + pytest.approx(captured["timeout"], rel=1e-6) == client.timeout + ) # default used + # Check some default headers and Authorization presence + headers = {k.lower(): v for k, v in captured["headers"].items()} + assert "authorization" in headers and headers["authorization"].startswith( + "Basic " + ) + assert headers.get("accept") == "application/json" + assert "gzip" in headers.get("accept-encoding", "").lower() + assert 
"user-agent" in headers diff --git a/tests/test_maint_notifications_handling.py b/tests/test_maint_notifications_handling.py index baa7d601fa..54b6e2dff7 100644 --- a/tests/test_maint_notifications_handling.py +++ b/tests/test_maint_notifications_handling.py @@ -330,6 +330,9 @@ def setsockopt(self, level, optname, value): """Simulate setting socket options.""" pass + def setblocking(self, blocking): + pass + def getpeername(self): """Simulate getting peer name.""" return self.address diff --git a/tests/test_multidb/__init__.py b/tests/test_multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py new file mode 100644 index 0000000000..f6ee6d3ec4 --- /dev/null +++ b/tests/test_multidb/conftest.py @@ -0,0 +1,131 @@ +from unittest.mock import Mock + +import pytest + +from redis import Redis, ConnectionPool +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState, CircuitBreaker +from redis.multidb.config import ( + MultiDbConfig, + DatabaseConfig, + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_AUTO_FALLBACK_INTERVAL, +) +from redis.multidb.database import Database, Databases +from redis.multidb.failover import FailoverStrategy +from redis.multidb.failure_detector import FailureDetector +from redis.multidb.healthcheck import ( + HealthCheck, + DEFAULT_HEALTH_CHECK_PROBES, + DEFAULT_HEALTH_CHECK_POLICY, +) + + +@pytest.fixture() +def mock_client() -> Redis: + return Mock(spec=Redis) + + +@pytest.fixture() +def mock_cb() -> CircuitBreaker: + return Mock(spec=CircuitBreaker) + + +@pytest.fixture() +def mock_fd() -> FailureDetector: + return Mock(spec=FailureDetector) + + +@pytest.fixture() +def mock_fs() -> FailoverStrategy: + return Mock(spec=FailoverStrategy) + + +@pytest.fixture() +def mock_hc() -> HealthCheck: + return Mock(spec=HealthCheck) + + +@pytest.fixture() +def mock_db(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + + +@pytest.fixture() +def mock_db1(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + + +@pytest.fixture() +def mock_db2(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + db.client.connection_pool = Mock(spec=ConnectionPool) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=CircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + + +@pytest.fixture() +def mock_multi_db_config(request, mock_fd, mock_fs, mock_hc, mock_ed) -> MultiDbConfig: + hc_interval = request.param.get("hc_interval", DEFAULT_HEALTH_CHECK_INTERVAL) + auto_fallback_interval = request.param.get( + "auto_fallback_interval", DEFAULT_AUTO_FALLBACK_INTERVAL + ) + health_check_policy = 
request.param.get( + "health_check_policy", DEFAULT_HEALTH_CHECK_POLICY + ) + health_check_probes = request.param.get( + "health_check_probes", DEFAULT_HEALTH_CHECK_PROBES + ) + + config = MultiDbConfig( + databases_config=[Mock(spec=DatabaseConfig)], + failure_detectors=[mock_fd], + health_check_interval=hc_interval, + health_check_delay=0.05, + health_check_policy=health_check_policy, + health_check_probes=health_check_probes, + failover_strategy=mock_fs, + auto_fallback_interval=auto_fallback_interval, + event_dispatcher=mock_ed, + ) + + return config + + +def create_weighted_list(*databases) -> Databases: + dbs = WeightedList() + + for db in databases: + dbs.add(db, db.weight) + + return dbs diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py new file mode 100644 index 0000000000..7d0f2cb700 --- /dev/null +++ b/tests/test_multidb/test_circuit.py @@ -0,0 +1,57 @@ +import pybreaker +import pytest + +from redis.multidb.circuit import ( + PBCircuitBreakerAdapter, + State as CbState, + CircuitBreaker, +) + + +@pytest.mark.onlynoncluster +class TestPBCircuitBreaker: + @pytest.mark.parametrize( + "mock_db", + [ + {"weight": 0.7, "circuit": {"state": CbState.CLOSED}}, + ], + indirect=True, + ) + def test_cb_correctly_configured(self, mock_db): + pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) + adapter = PBCircuitBreakerAdapter(cb=pb_circuit) + assert adapter.state == CbState.CLOSED + + adapter.state = CbState.OPEN + assert adapter.state == CbState.OPEN + + adapter.state = CbState.HALF_OPEN + assert adapter.state == CbState.HALF_OPEN + + adapter.state = CbState.CLOSED + assert adapter.state == CbState.CLOSED + + assert adapter.grace_period == 5 + adapter.grace_period = 10 + + assert adapter.grace_period == 10 + + adapter.database = mock_db + assert adapter.database == mock_db + + def test_cb_executes_callback_on_state_changed(self): + pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) + adapter = PBCircuitBreakerAdapter(cb=pb_circuit) + called_count = 0 + + def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): + nonlocal called_count + assert old_state == CbState.CLOSED + assert new_state == CbState.HALF_OPEN + assert isinstance(cb, PBCircuitBreakerAdapter) + called_count += 1 + + adapter.on_state_changed(callback) + adapter.state = CbState.HALF_OPEN + + assert called_count == 1 diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py new file mode 100644 index 0000000000..5ea2193895 --- /dev/null +++ b/tests/test_multidb/test_client.py @@ -0,0 +1,561 @@ +from time import sleep +from unittest.mock import patch, Mock + +import pybreaker +import pytest + +from redis.event import EventDispatcher, OnCommandsFailEvent +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter +from redis.multidb.database import SyncDatabase +from redis.multidb.client import MultiDBClient +from redis.multidb.exception import NoValidDatabaseException +from redis.multidb.failover import WeightBasedFailoverStrategy +from redis.multidb.failure_detector import FailureDetector +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from tests.test_multidb.conftest import create_weighted_list + + +@pytest.mark.onlynoncluster +class TestMultiDbClient: + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": 
CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_against_correct_db_on_successful_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set("key", "value") == "OK1" + assert mock_hc.check_health.call_count == 9 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.CLOSED + + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_execute_command_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + + mock_hc.check_health.side_effect = [ + False, + True, + True, + True, + True, + True, + True, + ] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set("key", "value") == "OK1" + assert mock_hc.check_health.call_count == 7 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[EchoHealthCheck()], + ), + ): + mock_db.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "healthcheck", + "OK", + "error", + ] + mock_db1.client.execute_command.side_effect = [ + "healthcheck", + "OK1", + "error", + "error", + "healthcheck", + "OK1", + ] + mock_db2.client.execute_command.side_effect = 
[ + "healthcheck", + "healthcheck", + "OK2", + "error", + "error", + ] + mock_multi_db_config.health_check_interval = 0.2 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() + + client = MultiDBClient(mock_multi_db_config) + assert client.set("key", "value") == "OK1" + sleep(0.3) + assert client.set("key", "value") == "OK2" + sleep(0.2) + assert client.set("key", "value") == "OK" + sleep(0.2) + assert client.set("key", "value") == "OK1" + + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_auto_fallback_to_highest_weight_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + mock_hc.check_health.side_effect = [ + True, + True, + True, + False, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + ] + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[mock_hc], + ), + ): + mock_db.client.execute_command.return_value = "OK" + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.auto_fallback_interval = 0.5 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() + + client = MultiDBClient(mock_multi_db_config) + assert client.set("key", "value") == "OK1" + sleep(0.18) + assert client.set("key", "value") == "OK2" + sleep(0.5) + assert client.set("key", "value") == "OK1" + + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.5, "circuit": {"state": CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_execute_command_throws_exception_on_failed_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises( + NoValidDatabaseException, + match="Initial connection failed - no active database found", + ): + client.set("key", "value") + + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_database_throws_exception_on_same_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", 
return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(ValueError, match="Given database already exists"): + client.add_database(mock_db) + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_database_makes_new_database_active( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert client.set("key", "value") == "OK2" + assert mock_hc.check_health.call_count == 6 + + client.add_database(mock_db1) + assert mock_hc.check_health.call_count == 9 + + assert client.set("key", "value") == "OK1" + + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_remove_highest_weighted_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert client.set("key", "value") == "OK1" + assert mock_hc.check_health.call_count == 9 + + client.remove_database(mock_db1) + + assert client.set("key", "value") == "OK2" + + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_update_database_weight_to_be_highest( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" + + mock_hc.check_health.return_value = 
True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert client.set("key", "value") == "OK1" + assert mock_hc.check_health.call_count == 9 + + client.update_database_weight(mock_db2, 0.8) + assert mock_db2.weight == 0.8 + + assert client.set("key", "value") == "OK2" + + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_new_failure_detector( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_multi_db_config.event_dispatcher = EventDispatcher() + mock_fd = mock_multi_db_config.failure_detectors[0] + + # Event fired if command against mock_db1 would fail + command_fail_event = OnCommandsFailEvent( + commands=("SET", "key", "value"), + exception=Exception(), + ) + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set("key", "value") == "OK1" + assert mock_hc.check_health.call_count == 9 + + # Simulate failing command events that lead to a failure detection + for i in range(5): + mock_multi_db_config.event_dispatcher.dispatch(command_fail_event) + + assert mock_fd.register_failure.call_count == 5 + + another_fd = Mock(spec=FailureDetector) + client.add_failure_detector(another_fd) + + # Simulate failing command events that lead to a failure detection + for i in range(5): + mock_multi_db_config.event_dispatcher.dispatch(command_fail_event) + + assert mock_fd.register_failure.call_count == 10 + assert another_fd.register_failure.call_count == 5 + + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_add_new_health_check( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set("key", "value") == "OK1" + assert mock_hc.check_health.call_count == 9 + + another_hc = Mock(spec=HealthCheck) + another_hc.check_health.return_value = True + + client.add_health_check(another_hc) + client._check_db_health(mock_db1) + + assert mock_hc.check_health.call_count == 12 + assert another_hc.check_health.call_count == 3 + + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": 
CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_set_active_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db.client.execute_command.return_value = "OK" + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert client.set("key", "value") == "OK1" + assert mock_hc.check_health.call_count == 9 + + client.set_active_database(mock_db) + assert client.set("key", "value") == "OK" + + with pytest.raises( + ValueError, match="Given database is not a member of database list" + ): + client.set_active_database(Mock(spec=SyncDatabase)) + + mock_hc.check_health.return_value = False + + with pytest.raises( + NoValidDatabaseException, + match="Cannot set active database, database is unhealthy", + ): + client.set_active_database(mock_db1) diff --git a/tests/test_multidb/test_command_executor.py b/tests/test_multidb/test_command_executor.py new file mode 100644 index 0000000000..43e5f47344 --- /dev/null +++ b/tests/test_multidb/test_command_executor.py @@ -0,0 +1,173 @@ +from time import sleep + +import pytest + +from redis.exceptions import ConnectionError +from redis.backoff import NoBackoff +from redis.event import EventDispatcher +from redis.multidb.circuit import State as CBState +from redis.multidb.command_executor import DefaultCommandExecutor +from redis.multidb.failure_detector import CommandFailureDetector +from redis.retry import Retry +from tests.test_multidb.conftest import create_weighted_list + + +@pytest.mark.onlynoncluster +class TestDefaultCommandExecutor: + @pytest.mark.parametrize( + "mock_db,mock_db1,mock_db2", + [ + ( + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_on_active_database( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0), + ) + + executor.active_database = mock_db1 + assert executor.execute_command("SET", "key", "value") == "OK1" + + executor.active_database = mock_db2 + assert executor.execute_command("SET", "key", "value") == "OK2" + assert mock_ed.register_listeners.call_count == 1 + assert mock_fd.register_command_execution.call_count == 2 + + @pytest.mark.parametrize( + "mock_db,mock_db1,mock_db2", + [ + ( + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_automatically_select_active_database( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): 
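+ # The failover strategy is consulted lazily: once on the first command, when no
+ # active database has been selected yet, and again after mock_db1's circuit opens.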
+ mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" + mock_fs.database.side_effect = [mock_db1, mock_db2] + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0), + ) + + assert executor.execute_command("SET", "key", "value") == "OK1" + mock_db1.circuit.state = CBState.OPEN + + assert executor.execute_command("SET", "key", "value") == "OK2" + assert mock_ed.register_listeners.call_count == 1 + assert mock_fs.database.call_count == 2 + assert mock_fd.register_command_execution.call_count == 2 + + @pytest.mark.parametrize( + "mock_db,mock_db1,mock_db2", + [ + ( + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_fallback_to_another_db_after_fallback_interval( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command.return_value = "OK1" + mock_db2.client.execute_command.return_value = "OK2" + mock_fs.database.side_effect = [mock_db1, mock_db2, mock_db1] + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), 0), + ) + + assert executor.execute_command("SET", "key", "value") == "OK1" + mock_db1.weight = 0.1 + sleep(0.15) + + assert executor.execute_command("SET", "key", "value") == "OK2" + mock_db1.weight = 0.7 + sleep(0.15) + + assert executor.execute_command("SET", "key", "value") == "OK1" + assert mock_ed.register_listeners.call_count == 1 + assert mock_fs.database.call_count == 3 + assert mock_fd.register_command_execution.call_count == 3 + + @pytest.mark.parametrize( + "mock_db,mock_db1,mock_db2", + [ + ( + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_command_fallback_to_another_db_after_failure_detection( + self, mock_db, mock_db1, mock_db2, mock_fs + ): + mock_db1.client.execute_command.side_effect = [ + "OK1", + ConnectionError, + ConnectionError, + ConnectionError, + "OK1", + ] + mock_db2.client.execute_command.side_effect = [ + "OK2", + ConnectionError, + ConnectionError, + ConnectionError, + ] + mock_fs.database.side_effect = [mock_db1, mock_db2, mock_db1] + threshold = 3 + fd = CommandFailureDetector(threshold, 0.0, 1) + ed = EventDispatcher() + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=ed, + auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), threshold), + ) + fd.set_command_executor(command_executor=executor) + + assert executor.execute_command("SET", "key", "value") == "OK1" + assert executor.execute_command("SET", "key", "value") == "OK2" + assert executor.execute_command("SET", "key", "value") == "OK1" + assert mock_fs.database.call_count == 3 diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py new file mode 100644 index 
0000000000..ea81f71ac9 --- /dev/null +++ b/tests/test_multidb/test_config.py @@ -0,0 +1,160 @@ +from unittest.mock import Mock + +import pytest + +from redis.connection import ConnectionPool +from redis.multidb.circuit import ( + PBCircuitBreakerAdapter, + CircuitBreaker, + DEFAULT_GRACE_PERIOD, +) +from redis.multidb.config import ( + MultiDbConfig, + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_AUTO_FALLBACK_INTERVAL, + DatabaseConfig, +) +from redis.multidb.database import Database +from redis.multidb.failure_detector import CommandFailureDetector, FailureDetector +from redis.multidb.healthcheck import EchoHealthCheck, HealthCheck +from redis.multidb.failover import WeightBasedFailoverStrategy, FailoverStrategy +from redis.retry import Retry + + +@pytest.mark.onlynoncluster +class TestMultiDbConfig: + def test_default_config(self): + db_configs = [ + DatabaseConfig( + client_kwargs={"host": "host1", "port": "port1"}, weight=1.0 + ), + DatabaseConfig( + client_kwargs={"host": "host2", "port": "port2"}, weight=0.9 + ), + DatabaseConfig( + client_kwargs={"host": "host3", "port": "port3"}, weight=0.8 + ), + ] + + config = MultiDbConfig(databases_config=db_configs) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.circuit.grace_period == DEFAULT_GRACE_PERIOD + assert db.client.get_retry() is not config.command_retry + i += 1 + + assert len(config.default_failure_detectors()) == 1 + assert isinstance(config.default_failure_detectors()[0], CommandFailureDetector) + assert len(config.default_health_checks()) == 1 + assert isinstance(config.default_health_checks()[0], EchoHealthCheck) + assert config.health_check_interval == DEFAULT_HEALTH_CHECK_INTERVAL + assert isinstance( + config.default_failover_strategy(), WeightBasedFailoverStrategy + ) + assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL + assert isinstance(config.command_retry, Retry) + + def test_overridden_config(self): + grace_period = 2 + mock_connection_pools = [ + Mock(spec=ConnectionPool), + Mock(spec=ConnectionPool), + Mock(spec=ConnectionPool), + ] + mock_connection_pools[0].connection_kwargs = {} + mock_connection_pools[1].connection_kwargs = {} + mock_connection_pools[2].connection_kwargs = {} + mock_cb1 = Mock(spec=CircuitBreaker) + mock_cb1.grace_period = grace_period + mock_cb2 = Mock(spec=CircuitBreaker) + mock_cb2.grace_period = grace_period + mock_cb3 = Mock(spec=CircuitBreaker) + mock_cb3.grace_period = grace_period + mock_failure_detectors = [ + Mock(spec=FailureDetector), + Mock(spec=FailureDetector), + ] + mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] + health_check_interval = 10 + mock_failover_strategy = Mock(spec=FailoverStrategy) + auto_fallback_interval = 10 + db_configs = [ + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[0]}, + weight=1.0, + circuit=mock_cb1, + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[1]}, + weight=0.9, + circuit=mock_cb2, + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[2]}, + weight=0.8, + circuit=mock_cb3, + ), + ] + + config = MultiDbConfig( + databases_config=db_configs, + failure_detectors=mock_failure_detectors, + health_checks=mock_health_checks, + health_check_interval=health_check_interval, + failover_strategy=mock_failover_strategy, + 
auto_fallback_interval=auto_fallback_interval, + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.client.connection_pool == mock_connection_pools[i] + assert db.circuit.grace_period == grace_period + i += 1 + + assert len(config.failure_detectors) == 2 + assert config.failure_detectors[0] == mock_failure_detectors[0] + assert config.failure_detectors[1] == mock_failure_detectors[1] + assert len(config.health_checks) == 2 + assert config.health_checks[0] == mock_health_checks[0] + assert config.health_checks[1] == mock_health_checks[1] + assert config.health_check_interval == health_check_interval + assert config.failover_strategy == mock_failover_strategy + assert config.auto_fallback_interval == auto_fallback_interval + + +class TestDatabaseConfig: + def test_default_config(self): + config = DatabaseConfig( + client_kwargs={"host": "host1", "port": "port1"}, weight=1.0 + ) + + assert config.client_kwargs == {"host": "host1", "port": "port1"} + assert config.weight == 1.0 + assert isinstance(config.default_circuit_breaker(), PBCircuitBreakerAdapter) + + def test_overridden_config(self): + mock_connection_pool = Mock(spec=ConnectionPool) + mock_circuit = Mock(spec=CircuitBreaker) + + config = DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pool}, + weight=1.0, + circuit=mock_circuit, + ) + + assert config.client_kwargs == {"connection_pool": mock_connection_pool} + assert config.weight == 1.0 + assert config.circuit == mock_circuit diff --git a/tests/test_multidb/test_failover.py b/tests/test_multidb/test_failover.py new file mode 100644 index 0000000000..60b231ab40 --- /dev/null +++ b/tests/test_multidb/test_failover.py @@ -0,0 +1,161 @@ +from time import sleep + +import pytest + +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState +from redis.multidb.exception import ( + NoValidDatabaseException, + TemporaryUnavailableException, +) +from redis.multidb.failover import ( + WeightBasedFailoverStrategy, + DefaultFailoverStrategyExecutor, +) + + +@pytest.mark.onlynoncluster +class TestWeightBasedFailoverStrategy: + @pytest.mark.parametrize( + "mock_db,mock_db1,mock_db2", + [ + ( + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ( + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, + ), + ], + ids=["all closed - highest weight", "highest weight - open"], + indirect=True, + ) + def test_get_valid_database(self, mock_db, mock_db1, mock_db2): + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + + failover_strategy = WeightBasedFailoverStrategy() + failover_strategy.set_databases(databases) + + assert failover_strategy.database() == mock_db1 + + @pytest.mark.parametrize( + "mock_db,mock_db1,mock_db2", + [ + ( + {"weight": 0.2, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, + {"weight": 0.5, "circuit": {"state": CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): + failover_strategy 
= WeightBasedFailoverStrategy() + + with pytest.raises( + NoValidDatabaseException, + match="No valid database available for communication", + ): + assert failover_strategy.database() + + +@pytest.mark.onlynoncluster +class TestDefaultStrategyExecutor: + @pytest.mark.parametrize( + "mock_db", + [ + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + ], + indirect=True, + ) + def test_execute_returns_valid_database_with_failover_attempts( + self, mock_db, mock_fs + ): + failover_attempts = 3 + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + mock_db, + ] + executor = DefaultFailoverStrategyExecutor( + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 + ) + + for i in range(failover_attempts + 1): + try: + database = executor.execute() + assert database == mock_db + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ) + sleep(0.11) + pass + + assert mock_fs.database.call_count == 4 + + def test_execute_throws_exception_on_attempts_exceed(self, mock_fs): + failover_attempts = 3 + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + ] + executor = DefaultFailoverStrategyExecutor( + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 + ) + + with pytest.raises(NoValidDatabaseException): + for i in range(failover_attempts + 1): + try: + executor.execute() + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ) + sleep(0.11) + pass + + assert mock_fs.database.call_count == 4 + + def test_execute_throws_exception_on_attempts_does_not_exceed_delay(self, mock_fs): + failover_attempts = 3 + mock_fs.database.side_effect = [ + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + NoValidDatabaseException, + ] + executor = DefaultFailoverStrategyExecutor( + mock_fs, failover_attempts=failover_attempts, failover_delay=0.1 + ) + + with pytest.raises( + TemporaryUnavailableException, + match=( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." + ), + ): + for i in range(failover_attempts + 1): + try: + executor.execute() + except TemporaryUnavailableException as e: + assert e.args[0] == ( + "No database connections currently available. " + "This is a temporary condition - please retry the operation." 
+                    )
+                    if i == failover_attempts:
+                        raise e
+
+        assert mock_fs.database.call_count == 4
diff --git a/tests/test_multidb/test_failure_detector.py b/tests/test_multidb/test_failure_detector.py
new file mode 100644
index 0000000000..b64ff601d2
--- /dev/null
+++ b/tests/test_multidb/test_failure_detector.py
@@ -0,0 +1,117 @@
+from time import sleep
+from unittest.mock import Mock
+
+import pytest
+
+from redis.multidb.command_executor import SyncCommandExecutor
+from redis.multidb.database import Database
+from redis.multidb.failure_detector import CommandFailureDetector
+from redis.multidb.circuit import State as CBState
+from redis.exceptions import ConnectionError
+
+
+@pytest.mark.onlynoncluster
+class TestCommandFailureDetector:
+    @pytest.mark.parametrize(
+        "min_num_failures,failure_rate_threshold,circuit_state",
+        [
+            (2, 0.4, CBState.OPEN),
+            (2, 0, CBState.OPEN),
+            (0, 0.4, CBState.OPEN),
+            (3, 0.4, CBState.CLOSED),
+            (2, 0.41, CBState.CLOSED),
+        ],
+        ids=[
+            "exceeds min num failures AND failure rate",
+            "exceeds min num failures AND failure rate == 0",
+            "min num failures == 0 AND exceeds failure rate",
+            "does not exceed min num failures",
+            "does not exceed failure rate",
+        ],
+    )
+    def test_failure_detector_correctly_reacts_to_failures(
+        self, min_num_failures, failure_rate_threshold, circuit_state
+    ):
+        fd = CommandFailureDetector(min_num_failures, failure_rate_threshold)
+        mock_db = Mock(spec=Database)
+        mock_db.circuit.state = CBState.CLOSED
+        mock_ce = Mock(spec=SyncCommandExecutor)
+        mock_ce.active_database = mock_db
+        fd.set_command_executor(mock_ce)
+
+        fd.register_command_execution(("GET", "key"))
+        fd.register_command_execution(("GET", "key"))
+        fd.register_failure(Exception(), ("GET", "key"))
+
+        fd.register_command_execution(("GET", "key"))
+        fd.register_command_execution(("GET", "key"))
+        fd.register_command_execution(("GET", "key"))
+        fd.register_failure(Exception(), ("GET", "key"))
+
+        assert mock_db.circuit.state == circuit_state
+
+    @pytest.mark.parametrize(
+        "min_num_failures,failure_rate_threshold",
+        [
+            (3, 0.0),
+            (3, 0.6),
+        ],
+        ids=[
+            "does not exceed min num failures during interval",
+            "does not exceed min num failures AND failure rate during interval",
+        ],
+    )
+    def test_failure_detector_do_not_open_circuit_on_interval_exceed(
+        self, min_num_failures, failure_rate_threshold
+    ):
+        fd = CommandFailureDetector(min_num_failures, failure_rate_threshold, 0.3)
+        mock_db = Mock(spec=Database)
+        mock_db.circuit.state = CBState.CLOSED
+        mock_ce = Mock(spec=SyncCommandExecutor)
+        mock_ce.active_database = mock_db
+        fd.set_command_executor(mock_ce)
+        assert mock_db.circuit.state == CBState.CLOSED
+
+        fd.register_command_execution(("GET", "key"))
+        fd.register_failure(Exception(), ("GET", "key"))
+        sleep(0.16)
+        fd.register_command_execution(("GET", "key"))
+        fd.register_command_execution(("GET", "key"))
+        fd.register_command_execution(("GET", "key"))
+        fd.register_failure(Exception(), ("GET", "key"))
+        sleep(0.16)
+        fd.register_command_execution(("GET", "key"))
+        fd.register_failure(Exception(), ("GET", "key"))
+
+        assert mock_db.circuit.state == CBState.CLOSED
+
+        # 2 more failures, as the last one already refreshed the timer
+        fd.register_command_execution(("GET", "key"))
+        fd.register_failure(Exception(), ("GET", "key"))
+        fd.register_command_execution(("GET", "key"))
+        fd.register_failure(Exception(), ("GET", "key"))
+
+        assert mock_db.circuit.state == CBState.OPEN
+
+    def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self):
+        fd =
CommandFailureDetector(5, 1, error_types=[ConnectionError]) + mock_db = Mock(spec=Database) + mock_db.circuit.state = CBState.CLOSED + mock_ce = Mock(spec=SyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + fd.register_failure(Exception(), ("SET", "key1", "value1")) + fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + fd.register_failure(Exception(), ("SET", "key1", "value1")) + fd.register_failure(Exception(), ("SET", "key1", "value1")) + + assert mock_db.circuit.state == CBState.CLOSED + + fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + fd.register_failure(ConnectionError(), ("SET", "key1", "value1")) + + assert mock_db.circuit.state == CBState.OPEN diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py new file mode 100644 index 0000000000..fb1f1e4148 --- /dev/null +++ b/tests/test_multidb/test_healthcheck.py @@ -0,0 +1,385 @@ +from unittest.mock import MagicMock, Mock + +import pytest + +from redis.multidb.database import Database +from redis.http.http_client import HttpError +from redis.multidb.healthcheck import ( + EchoHealthCheck, + LagAwareHealthCheck, + HealthCheck, + HealthyAllPolicy, + UnhealthyDatabaseException, + HealthyMajorityPolicy, + HealthyAnyPolicy, +) +from redis.multidb.circuit import State as CBState + + +@pytest.mark.onlynoncluster +class TestHealthyAllPolicy: + def test_policy_returns_true_for_all_successful_probes(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.return_value = True + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + assert policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 3 + + def test_policy_returns_false_on_first_failed_probe(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [True, True, False] + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + assert not policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 0 + + def test_policy_raise_unhealthy_database_exception(self): + mock_hc1 = Mock(spec=HealthCheck) + mock_hc2 = Mock(spec=HealthCheck) + mock_hc1.check_health.side_effect = [True, True, ConnectionError] + mock_hc2.check_health.return_value = True + mock_db = Mock(spec=Database) + + policy = HealthyAllPolicy(3, 0.01) + with pytest.raises(UnhealthyDatabaseException, match="Unhealthy database"): + policy.execute([mock_hc1, mock_hc2], mock_db) + assert mock_hc1.check_health.call_count == 3 + assert mock_hc2.check_health.call_count == 0 + + +@pytest.mark.onlynoncluster +class TestHealthyMajorityPolicy: + @pytest.mark.parametrize( + "probes,hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result", + [ + (3, [True, False, False], [True, True, True], 3, 0, False), + (3, [True, True, True], [True, False, False], 3, 3, False), + (3, [True, False, True], [True, True, True], 3, 3, True), + (3, [True, True, True], [True, False, True], 3, 3, True), + (3, [True, True, False], [True, False, True], 3, 3, True), + (4, 
[True, True, False, False], [True, True, True, True], 4, 0, False),
+            (4, [True, True, True, True], [True, True, False, False], 4, 4, False),
+            (4, [False, True, True, True], [True, True, True, True], 4, 4, True),
+            (4, [True, True, True, True], [True, False, True, True], 4, 4, True),
+            (4, [False, True, True, True], [True, True, False, True], 4, 4, True),
+        ],
+        ids=[
+            "HC1 - no majority - odd",
+            "HC2 - no majority - odd",
+            "HC1 - majority - odd",
+            "HC2 - majority - odd",
+            "HC1 + HC2 - majority - odd",
+            "HC1 - no majority - even",
+            "HC2 - no majority - even",
+            "HC1 - majority - even",
+            "HC2 - majority - even",
+            "HC1 + HC2 - majority - even",
+        ],
+    )
+    def test_policy_returns_true_for_majority_successful_probes(
+        self,
+        probes,
+        hc1_side_effect,
+        hc2_side_effect,
+        hc1_call_count,
+        hc2_call_count,
+        expected_result,
+    ):
+        mock_hc1 = Mock(spec=HealthCheck)
+        mock_hc2 = Mock(spec=HealthCheck)
+        mock_hc1.check_health.side_effect = hc1_side_effect
+        mock_hc2.check_health.side_effect = hc2_side_effect
+        mock_db = Mock(spec=Database)
+
+        policy = HealthyMajorityPolicy(probes, 0.01)
+        assert policy.execute([mock_hc1, mock_hc2], mock_db) == expected_result
+        assert mock_hc1.check_health.call_count == hc1_call_count
+        assert mock_hc2.check_health.call_count == hc2_call_count
+
+    @pytest.mark.parametrize(
+        "probes,hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count",
+        [
+            (3, [True, ConnectionError, ConnectionError], [True, True, True], 3, 0),
+            (3, [True, True, True], [True, ConnectionError, ConnectionError], 3, 3),
+            (
+                4,
+                [True, ConnectionError, ConnectionError, True],
+                [True, True, True, True],
+                3,
+                0,
+            ),
+            (
+                4,
+                [True, True, True, True],
+                [True, ConnectionError, ConnectionError, False],
+                4,
+                3,
+            ),
+        ],
+        ids=[
+            "HC1 - majority - odd",
+            "HC2 - majority - odd",
+            "HC1 - majority - even",
+            "HC2 - majority - even",
+        ],
+    )
+    def test_policy_raise_unhealthy_database_exception_on_majority_probes_exceptions(
+        self, probes, hc1_side_effect, hc2_side_effect, hc1_call_count, hc2_call_count
+    ):
+        mock_hc1 = Mock(spec=HealthCheck)
+        mock_hc2 = Mock(spec=HealthCheck)
+        mock_hc1.check_health.side_effect = hc1_side_effect
+        mock_hc2.check_health.side_effect = hc2_side_effect
+        mock_db = Mock(spec=Database)
+
+        policy = HealthyMajorityPolicy(probes, 0.01)
+        with pytest.raises(UnhealthyDatabaseException, match="Unhealthy database"):
+            policy.execute([mock_hc1, mock_hc2], mock_db)
+        assert mock_hc1.check_health.call_count == hc1_call_count
+        assert mock_hc2.check_health.call_count == hc2_call_count
+
+
+@pytest.mark.onlynoncluster
+class TestHealthyAnyPolicy:
+    @pytest.mark.parametrize(
+        "hc1_side_effect,hc2_side_effect,hc1_call_count,hc2_call_count,expected_result",
+        [
+            ([False, False, False], [True, True, True], 3, 0, False),
+            ([False, False, True], [False, False, False], 3, 3, False),
+            ([False, True, True], [False, False, True], 2, 3, True),
+            ([True, True, True], [False, True, False], 1, 2, True),
+        ],
+        ids=[
+            "HC1 - no successful",
+            "HC2 - no successful",
+            "HC1 - successful",
+            "HC2 - successful",
+        ],
+    )
+    def test_policy_returns_true_for_any_successful_probe(
+        self,
+        hc1_side_effect,
+        hc2_side_effect,
+        hc1_call_count,
+        hc2_call_count,
+        expected_result,
+    ):
+        mock_hc1 = Mock(spec=HealthCheck)
+        mock_hc2 = Mock(spec=HealthCheck)
+        mock_hc1.check_health.side_effect = hc1_side_effect
+        mock_hc2.check_health.side_effect = hc2_side_effect
+        mock_db = Mock(spec=Database)
+
+        policy = HealthyAnyPolicy(3, 0.01)
+        assert policy.execute([mock_hc1, mock_hc2], mock_db) == expected_result
+        assert mock_hc1.check_health.call_count == hc1_call_count
+        assert mock_hc2.check_health.call_count == hc2_call_count
+
+    def test_policy_raise_unhealthy_database_exception_if_exception_occurs_on_failed_health_check(
+        self,
+    ):
+        mock_hc1 = Mock(spec=HealthCheck)
+        mock_hc2 = Mock(spec=HealthCheck)
+        mock_hc1.check_health.side_effect = [False, False, ConnectionError]
+        mock_hc2.check_health.side_effect = [True, True, True]
+        mock_db = Mock(spec=Database)
+
+        policy = HealthyAnyPolicy(3, 0.01)
+        with pytest.raises(UnhealthyDatabaseException, match="Unhealthy database"):
+            policy.execute([mock_hc1, mock_hc2], mock_db)
+        assert mock_hc1.check_health.call_count == 3
+        assert mock_hc2.check_health.call_count == 0
+
+
+@pytest.mark.onlynoncluster
+class TestEchoHealthCheck:
+    def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb):
+        """
+        Ensures the database is reported healthy when the echo response
+        matches the expected value.
+        """
+        mock_client.execute_command.return_value = "healthcheck"
+        hc = EchoHealthCheck()
+        db = Database(mock_client, mock_cb, 0.9)
+
+        assert hc.check_health(db)
+        assert mock_client.execute_command.call_count == 1
+
+    def test_database_is_unhealthy_on_incorrect_echo_response(
+        self, mock_client, mock_cb
+    ):
+        """
+        Ensures the database is reported unhealthy when the echo response
+        does not match the expected value.
+        """
+        mock_client.execute_command.return_value = "wrong"
+        hc = EchoHealthCheck()
+        db = Database(mock_client, mock_cb, 0.9)
+
+        assert not hc.check_health(db)
+        assert mock_client.execute_command.call_count == 1
+
+    def test_database_close_circuit_on_successful_healthcheck(
+        self, mock_client, mock_cb
+    ):
+        mock_client.execute_command.return_value = "healthcheck"
+        mock_cb.state = CBState.HALF_OPEN
+        hc = EchoHealthCheck()
+        db = Database(mock_client, mock_cb, 0.9)
+
+        assert hc.check_health(db)
+        assert mock_client.execute_command.call_count == 1
+
+
+@pytest.mark.onlynoncluster
+class TestLagAwareHealthCheck:
+    def test_database_is_healthy_when_bdb_matches_by_dns_name(
+        self, mock_client, mock_cb
+    ):
+        """
+        Ensures the health check succeeds when /v1/bdbs contains an endpoint whose
+        dns_name matches the database host and the availability endpoint succeeds.
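+        The Redis Enterprise REST API is exercised through a mocked HttpClient,
+        so no real endpoint is contacted.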
+ """ + host = "db1.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + # Mock HttpClient used inside LagAwareHealthCheck + mock_http = MagicMock() + mock_http.get.side_effect = [ + # First call: list of bdbs + [ + { + "uid": "bdb-1", + "endpoints": [ + {"dns_name": host, "addr": ["10.0.0.1", "10.0.0.2"]}, + ], + } + ], + # Second call: availability check (no JSON expected) + None, + ] + + hc = LagAwareHealthCheck(rest_api_port=1234, lag_aware_tolerance=150) + # Inject our mocked http client + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + assert hc.check_health(db) is True + # Base URL must be set correctly + assert hc._http_client.base_url == "https://healthcheck.example.com:1234" + # Calls: first to list bdbs, then to availability + assert mock_http.get.call_count == 2 + first_call = mock_http.get.call_args_list[0] + second_call = mock_http.get.call_args_list[1] + assert first_call.args[0] == "/v1/bdbs" + assert ( + second_call.args[0] + == "/v1/bdbs/bdb-1/availability?extend_check=lag&availability_lag_tolerance_ms=150" + ) + assert second_call.kwargs.get("expect_json") is False + + def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb): + """ + Ensures health check succeeds when endpoint addr list contains the database host. + """ + host_ip = "203.0.113.5" + mock_client.get_connection_kwargs.return_value = {"host": host_ip} + + mock_http = MagicMock() + mock_http.get.side_effect = [ + [ + { + "uid": "bdb-42", + "endpoints": [ + {"dns_name": "not-matching.example.com", "addr": [host_ip]}, + ], + } + ], + None, + ] + + hc = LagAwareHealthCheck() + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + assert hc.check_health(db) is True + assert mock_http.get.call_count == 2 + assert ( + mock_http.get.call_args_list[1].args[0] + == "/v1/bdbs/bdb-42/availability?extend_check=lag&availability_lag_tolerance_ms=5000" + ) + + def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): + """ + Ensures health check raises ValueError when there's no bdb matching the database host. + """ + host = "db2.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + mock_http = MagicMock() + # Return bdbs that do not match host by dns_name nor addr + mock_http.get.return_value = [ + { + "uid": "a", + "endpoints": [{"dns_name": "other.example.com", "addr": ["10.0.0.9"]}], + }, + { + "uid": "b", + "endpoints": [ + {"dns_name": "another.example.com", "addr": ["10.0.0.10"]} + ], + }, + ] + + hc = LagAwareHealthCheck() + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + with pytest.raises(ValueError, match="Could not find a matching bdb"): + hc.check_health(db) + + # Only the listing call should have happened + mock_http.get.assert_called_once_with("/v1/bdbs") + + def test_propagates_http_error_from_availability(self, mock_client, mock_cb): + """ + Ensures that any HTTP error raised by the availability endpoint is propagated. 
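+        The bdb listing call succeeds and matches by dns_name; the follow-up
+        availability call raises HttpError with status 503.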
+        """
+        host = "db3.example.com"
+        mock_client.get_connection_kwargs.return_value = {"host": host}
+
+        mock_http = MagicMock()
+        # First: list bdbs -> match by dns_name
+        mock_http.get.side_effect = [
+            [{"uid": "bdb-err", "endpoints": [{"dns_name": host, "addr": []}]}],
+            # Second: availability -> raise HttpError
+            HttpError(
+                url=f"https://{host}:9443/v1/bdbs/bdb-err/availability",
+                status=503,
+                message="busy",
+            ),
+        ]
+
+        hc = LagAwareHealthCheck()
+        hc._http_client = mock_http
+
+        db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com")
+
+        with pytest.raises(HttpError, match="busy") as e:
+            hc.check_health(db)
+        assert e.value.status == 503
+
+        # Ensure both calls were attempted
+        assert mock_http.get.call_count == 2
diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py
new file mode 100644
index 0000000000..c3a494dd95
--- /dev/null
+++ b/tests/test_multidb/test_pipeline.py
@@ -0,0 +1,382 @@
+from time import sleep
+from unittest.mock import patch, Mock
+
+import pybreaker
+import pytest
+
+from redis.client import Pipeline
+from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter
+from redis.multidb.client import MultiDBClient
+from redis.multidb.failover import (
+    WeightBasedFailoverStrategy,
+)
+from redis.multidb.healthcheck import EchoHealthCheck
+from tests.test_multidb.conftest import create_weighted_list
+
+
+def mock_pipe() -> Pipeline:
+    mock_pipe = Mock(spec=Pipeline)
+    mock_pipe.__enter__ = Mock(return_value=mock_pipe)
+    mock_pipe.__exit__ = Mock(return_value=None)
+    return mock_pipe
+
+
+@pytest.mark.onlynoncluster
+class TestPipeline:
+    @pytest.mark.parametrize(
+        "mock_multi_db_config,mock_db, mock_db1, mock_db2",
+        [
+            (
+                {},
+                {"weight": 0.2, "circuit": {"state": CBState.CLOSED}},
+                {"weight": 0.7, "circuit": {"state": CBState.CLOSED}},
+                {"weight": 0.5, "circuit": {"state": CBState.CLOSED}},
+            ),
+        ],
+        indirect=True,
+    )
+    def test_executes_pipeline_against_correct_db(
+        self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc
+    ):
+        databases = create_weighted_list(mock_db, mock_db1, mock_db2)
+
+        with (
+            patch.object(mock_multi_db_config, "databases", return_value=databases),
+            patch.object(
+                mock_multi_db_config, "default_health_checks", return_value=[mock_hc]
+            ),
+        ):
+            pipe = mock_pipe()
+            pipe.execute.return_value = ["OK1", "value1"]
+            mock_db1.client.pipeline.return_value = pipe
+
+            mock_hc.check_health.return_value = True
+
+            client = MultiDBClient(mock_multi_db_config)
+            assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
+
+            pipe = client.pipeline()
+            pipe.set("key1", "value1")
+            pipe.get("key1")
+
+            assert pipe.execute() == ["OK1", "value1"]
+            assert mock_hc.check_health.call_count == 9
+
+    @pytest.mark.parametrize(
+        "mock_multi_db_config,mock_db, mock_db1, mock_db2",
+        [
+            (
+                {},
+                {"weight": 0.2, "circuit": {"state": CBState.CLOSED}},
+                {"weight": 0.5, "circuit": {"state": CBState.CLOSED}},
+                {"weight": 0.7, "circuit": {"state": CBState.OPEN}},
+            ),
+        ],
+        indirect=True,
+    )
+    def test_execute_pipeline_against_correct_db_and_closed_circuit(
+        self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc
+    ):
+        databases = create_weighted_list(mock_db, mock_db1, mock_db2)
+
+        with (
+            patch.object(mock_multi_db_config, "databases", return_value=databases),
+            patch.object(
+                mock_multi_db_config, "default_health_checks", return_value=[mock_hc]
+            ),
+        ):
+            pipe = mock_pipe()
+            pipe.execute.return_value = ["OK1", "value1"]
+
mock_db1.client.pipeline.return_value = pipe + + mock_hc.check_health.side_effect = [ + False, + True, + True, + True, + True, + True, + True, + ] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with client.pipeline() as pipe: + pipe.set("key1", "value1") + pipe.get("key1") + + assert pipe.execute() == ["OK1", "value1"] + assert mock_hc.check_health.call_count == 7 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_pipeline_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[EchoHealthCheck()], + ), + ): + mock_db.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "healthcheck", + "error", + ] + mock_db1.client.execute_command.side_effect = [ + "healthcheck", + "error", + "error", + "healthcheck", + ] + mock_db2.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "error", + "error", + ] + + pipe = mock_pipe() + pipe.execute.return_value = ["OK", "value"] + mock_db.client.pipeline.return_value = pipe + + pipe1 = mock_pipe() + pipe1.execute.return_value = ["OK1", "value"] + mock_db1.client.pipeline.return_value = pipe1 + + pipe2 = mock_pipe() + pipe2.execute.return_value = ["OK2", "value"] + mock_db2.client.pipeline.return_value = pipe2 + + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy() + + client = MultiDBClient(mock_multi_db_config) + + with client.pipeline() as pipe: + pipe.set("key1", "value") + pipe.get("key1") + + assert pipe.execute() == ["OK1", "value"] + + sleep(0.15) + + with client.pipeline() as pipe: + pipe.set("key1", "value") + pipe.get("key1") + + assert pipe.execute() == ["OK2", "value"] + + sleep(0.1) + + with client.pipeline() as pipe: + pipe.set("key1", "value") + pipe.get("key1") + + assert pipe.execute() == ["OK", "value"] + + sleep(0.1) + + with client.pipeline() as pipe: + pipe.set("key1", "value") + pipe.get("key1") + + assert pipe.execute() == ["OK1", "value"] + + +class TestTransaction: + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_executes_transaction_against_correct_db( + self, 
mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.transaction.return_value = ["OK1", "value1"] + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + def callback(pipe: Pipeline): + pipe.set("key1", "value1") + pipe.get("key1") + + assert client.transaction(callback) == ["OK1", "value1"] + assert mock_hc.check_health.call_count == 9 + + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.OPEN}}, + ), + ], + indirect=True, + ) + def test_execute_transaction_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, "default_health_checks", return_value=[mock_hc] + ), + ): + mock_db1.client.transaction.return_value = ["OK1", "value1"] + + mock_hc.check_health.side_effect = [ + False, + True, + True, + True, + True, + True, + True, + ] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + def callback(pipe: Pipeline): + pipe.set("key1", "value1") + pipe.get("key1") + + assert client.transaction(callback) == ["OK1", "value1"] + assert mock_hc.check_health.call_count == 7 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.parametrize( + "mock_multi_db_config,mock_db, mock_db1, mock_db2", + [ + ( + {"health_check_probes": 1}, + {"weight": 0.2, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.7, "circuit": {"state": CBState.CLOSED}}, + {"weight": 0.5, "circuit": {"state": CBState.CLOSED}}, + ), + ], + indirect=True, + ) + def test_execute_transaction_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with ( + patch.object(mock_multi_db_config, "databases", return_value=databases), + patch.object( + mock_multi_db_config, + "default_health_checks", + return_value=[EchoHealthCheck()], + ), + ): + mock_db.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "healthcheck", + "error", + ] + mock_db1.client.execute_command.side_effect = [ + "healthcheck", + "error", + "error", + "healthcheck", + ] + mock_db2.client.execute_command.side_effect = [ + "healthcheck", + "healthcheck", + "error", + "error", + ] + + 
mock_db.client.transaction.return_value = ["OK", "value"]
+            mock_db1.client.transaction.return_value = ["OK1", "value"]
+            mock_db2.client.transaction.return_value = ["OK2", "value"]
+
+            mock_multi_db_config.health_check_interval = 0.1
+            mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy()
+
+            client = MultiDBClient(mock_multi_db_config)
+
+            def callback(pipe: Pipeline):
+                pipe.set("key1", "value1")
+                pipe.get("key1")
+
+            assert client.transaction(callback) == ["OK1", "value"]
+            sleep(0.15)
+            assert client.transaction(callback) == ["OK2", "value"]
+            sleep(0.1)
+            assert client.transaction(callback) == ["OK", "value"]
+            sleep(0.1)
+            assert client.transaction(callback) == ["OK1", "value"]
diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py
index 531d09baa4..5f568aa84e 100644
--- a/tests/test_scenario/conftest.py
+++ b/tests/test_scenario/conftest.py
@@ -1,11 +1,24 @@
 import json
 import logging
 import os
+import re
 from typing import Optional
 from urllib.parse import urlparse
+
 import pytest
-from redis.backoff import ExponentialWithJitterBackoff, NoBackoff
+from redis.backoff import NoBackoff, ExponentialBackoff
+from redis.event import EventDispatcher, EventListenerInterface
+from redis.multidb.client import MultiDBClient
+from redis.multidb.config import (
+    DatabaseConfig,
+    MultiDbConfig,
+    DEFAULT_HEALTH_CHECK_INTERVAL,
+)
+from redis.multidb.event import ActiveDatabaseChanged
+from redis.multidb.failure_detector import DEFAULT_MIN_NUM_FAILURES
+from redis.multidb.healthcheck import DEFAULT_HEALTH_CHECK_DELAY
+from redis.backoff import ExponentialWithJitterBackoff
 from redis.client import Redis
 from redis.maint_notifications import EndpointType, MaintNotificationsConfig
 from redis.retry import Retry
@@ -17,6 +30,14 @@
 DEFAULT_ENDPOINT_NAME = "m-standard"
 
 
+class CheckActiveDatabaseChangedListener(EventListenerInterface):
+    def __init__(self):
+        self.is_changed_flag = False
+
+    def listen(self, event: ActiveDatabaseChanged):
+        self.is_changed_flag = True
+
+
 @pytest.fixture()
 def endpoint_name(request):
     return request.config.getoption("--endpoint-name") or os.getenv(
@@ -24,8 +45,7 @@ def endpoint_name(request):
     )
 
 
-@pytest.fixture()
-def endpoints_config(endpoint_name: str):
+def get_endpoints_config(endpoint_name: str):
     endpoints_config = os.getenv("REDIS_ENDPOINTS_CONFIG_PATH", None)
 
     if not (endpoints_config and os.path.exists(endpoints_config)):
@@ -42,12 +62,108 @@ def endpoints_config(endpoint_name: str):
     ) from e
 
 
+@pytest.fixture()
+def endpoints_config(endpoint_name: str):
+    return get_endpoints_config(endpoint_name)
+
+
 @pytest.fixture()
 def fault_injector_client():
     url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324")
     return FaultInjectorClient(url)
 
 
+@pytest.fixture()
+def r_multi_db(
+    request,
+) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]:
+    client_class = request.param.get("client_class", Redis)
+
+    if client_class == Redis:
+        endpoint_config = get_endpoints_config("re-active-active")
+    else:
+        endpoint_config = get_endpoints_config("re-active-active-oss-cluster")
+
+    username = endpoint_config.get("username", None)
+    password = endpoint_config.get("password", None)
+    min_num_failures = request.param.get("min_num_failures", DEFAULT_MIN_NUM_FAILURES)
+    command_retry = request.param.get(
+        "command_retry", Retry(ExponentialBackoff(cap=0.1, base=0.01), retries=10)
+    )
+
+    # The retry configuration differs for health checks, as the initial health check
+    # may require more time if the infrastructure wasn't
restored from the previous test. + health_check_interval = request.param.get( + "health_check_interval", DEFAULT_HEALTH_CHECK_INTERVAL + ) + health_check_delay = request.param.get( + "health_check_delay", DEFAULT_HEALTH_CHECK_DELAY + ) + event_dispatcher = EventDispatcher() + listener = CheckActiveDatabaseChangedListener() + event_dispatcher.register_listeners( + { + ActiveDatabaseChanged: [listener], + } + ) + db_configs = [] + + db_config = DatabaseConfig( + weight=1.0, + from_url=endpoint_config["endpoints"][0], + client_kwargs={ + "username": username, + "password": password, + "decode_responses": True, + }, + health_check_url=extract_cluster_fqdn(endpoint_config["endpoints"][0]), + ) + db_configs.append(db_config) + + db_config1 = DatabaseConfig( + weight=0.9, + from_url=endpoint_config["endpoints"][1], + client_kwargs={ + "username": username, + "password": password, + "decode_responses": True, + }, + health_check_url=extract_cluster_fqdn(endpoint_config["endpoints"][1]), + ) + db_configs.append(db_config1) + + config = MultiDbConfig( + client_class=client_class, + databases_config=db_configs, + command_retry=command_retry, + min_num_failures=min_num_failures, + health_check_probes=3, + health_check_interval=health_check_interval, + event_dispatcher=event_dispatcher, + health_check_delay=health_check_delay, + ) + + return MultiDBClient(config), listener, endpoint_config + + +def extract_cluster_fqdn(url): + """ + Extract Cluster FQDN from Redis URL + """ + # Parse the URL + parsed = urlparse(url) + + # Extract hostname and port + hostname = parsed.hostname + + # Remove the 'redis-XXXX.' prefix using regex + # This pattern matches 'redis-' followed by digits and a dot + cleaned_hostname = re.sub(r"^redis-\d+\.", "", hostname) + + # Reconstruct the URL + return f"https://{cleaned_hostname}" + + @pytest.fixture() def client_maint_notifications(endpoints_config): return _get_client_maint_notifications(endpoints_config) diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py new file mode 100644 index 0000000000..59524ab5c1 --- /dev/null +++ b/tests/test_scenario/test_active_active.py @@ -0,0 +1,460 @@ +import json +import logging +import os +import threading +from time import sleep + +import pytest + +from redis import Redis, RedisCluster +from redis.backoff import ConstantBackoff +from redis.client import Pipeline +from redis.multidb.exception import TemporaryUnavailableException +from redis.multidb.failover import DEFAULT_FAILOVER_ATTEMPTS, DEFAULT_FAILOVER_DELAY +from redis.multidb.healthcheck import LagAwareHealthCheck +from redis.retry import Retry +from redis.utils import dummy_fail +from tests.test_scenario.fault_injector_client import ActionRequest, ActionType + +logger = logging.getLogger(__name__) + + +def trigger_network_failure_action( + fault_injector_client, config, event: threading.Event = None +): + action_request = ActionRequest( + action_type=ActionType.NETWORK_FAILURE, + parameters={"bdb_id": config["bdb_id"], "delay": 3, "cluster_index": 0}, + ) + + result = fault_injector_client.trigger_action(action_request) + status_result = fault_injector_client.get_action_status(result["action_id"]) + + while status_result["status"] != "success": + sleep(0.1) + status_result = fault_injector_client.get_action_status(result["action_id"]) + logger.info( + f"Waiting for action to complete. Status: {status_result['status']}" + ) + + if event: + event.set() + + logger.info(f"Action completed. 
Status: {status_result['status']}") + + +class TestActiveActive: + def teardown_method(self, method): + # Timeout so the cluster could recover from network failure. + sleep(10) + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, + ], + ids=["standalone", "cluster"], + indirect=True, + ) + @pytest.mark.timeout(100) + def test_multi_db_client_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): + r_multi_db, listener, config = r_multi_db + + # Handle unavailable databases from previous test. + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), + ) + + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client, config, event), + ) + + # Client initialized on the first command. + retry.call_with_retry( + lambda: r_multi_db.set("key", "value"), lambda _: dummy_fail() + ) + thread.start() + + # Execute commands before network failure + while not event.is_set(): + assert ( + retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail() + ) + == "value" + ) + sleep(0.5) + + # Execute commands until database failover + while not listener.is_changed_flag: + assert ( + retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail() + ) + == "value" + ) + sleep(0.5) + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "min_num_failures": 2, "health_check_interval": 20}, + { + "client_class": RedisCluster, + "min_num_failures": 2, + "health_check_interval": 20, + }, + ], + ids=["standalone", "cluster"], + indirect=True, + ) + @pytest.mark.timeout(100) + def test_multi_db_client_uses_lag_aware_health_check( + self, r_multi_db, fault_injector_client + ): + r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), + ) + + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client, config, event), + ) + + env0_username = os.getenv("ENV0_USERNAME") + env0_password = os.getenv("ENV0_PASSWORD") + + # Adding additional health check to the client. + r_multi_db.add_health_check( + LagAwareHealthCheck( + verify_tls=False, + auth_basic=(env0_username, env0_password), + lag_aware_tolerance=10000, + ) + ) + + # Client initialized on the first command. 
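+        # Only TemporaryUnavailableException is retried here; it signals that the
+        # failover executor has no healthy database yet, so the call can be retried.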
+ retry.call_with_retry( + lambda: r_multi_db.set("key", "value"), lambda _: dummy_fail() + ) + thread.start() + + # Execute commands before network failure + while not event.is_set(): + assert ( + retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail() + ) + == "value" + ) + sleep(0.5) + + # Execute commands after network failure + while not listener.is_changed_flag: + assert ( + retry.call_with_retry( + lambda: r_multi_db.get("key"), lambda _: dummy_fail() + ) + == "value" + ) + sleep(0.5) + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, + ], + ids=["standalone", "cluster"], + indirect=True, + ) + @pytest.mark.timeout(100) + def test_context_manager_pipeline_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): + r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), + ) + + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client, config, event), + ) + + def callback(): + with r_multi_db.pipeline() as pipe: + pipe.set("{hash}key1", "value1") + pipe.set("{hash}key2", "value2") + pipe.set("{hash}key3", "value3") + pipe.get("{hash}key1") + pipe.get("{hash}key2") + pipe.get("{hash}key3") + assert pipe.execute() == [ + True, + True, + True, + "value1", + "value2", + "value3", + ] + + # Client initialized on first pipe execution. + retry.call_with_retry(lambda: callback(), lambda _: dummy_fail()) + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + retry.call_with_retry(lambda: callback(), lambda _: dummy_fail()) + sleep(0.5) + + # Execute pipeline until database failover + for _ in range(5): + retry.call_with_retry(lambda: callback(), lambda _: dummy_fail()) + sleep(0.5) + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, + ], + ids=["standalone", "cluster"], + indirect=True, + ) + @pytest.mark.timeout(100) + def test_chaining_pipeline_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): + r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), + ) + + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client, config, event), + ) + + def callback(): + pipe = r_multi_db.pipeline() + pipe.set("{hash}key1", "value1") + pipe.set("{hash}key2", "value2") + pipe.set("{hash}key3", "value3") + pipe.get("{hash}key1") + pipe.get("{hash}key2") + pipe.get("{hash}key3") + assert pipe.execute() == [True, True, True, "value1", "value2", "value3"] + + # Client initialized on first pipe execution. 
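+        # Each retry attempt re-runs the whole callback, so the pipeline is rebuilt
+        # and executed against whichever database is currently active.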
+ retry.call_with_retry(lambda: callback(), lambda _: dummy_fail()) + + thread.start() + + # Execute pipeline before network failure + while not event.is_set(): + retry.call_with_retry(lambda: callback(), lambda _: dummy_fail()) + sleep(0.5) + + # Execute pipeline until database failover + for _ in range(5): + retry.call_with_retry(lambda: callback(), lambda _: dummy_fail()) + sleep(0.5) + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, + ], + ids=["standalone", "cluster"], + indirect=True, + ) + @pytest.mark.timeout(100) + def test_transaction_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): + r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), + ) + + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client, config, event), + ) + + def callback(pipe: Pipeline): + pipe.set("{hash}key1", "value1") + pipe.set("{hash}key2", "value2") + pipe.set("{hash}key3", "value3") + pipe.get("{hash}key1") + pipe.get("{hash}key2") + pipe.get("{hash}key3") + + # Client initialized on first transaction execution. + retry.call_with_retry( + lambda: r_multi_db.transaction(callback), lambda _: dummy_fail() + ) + thread.start() + + # Execute transaction before network failure + while not event.is_set(): + retry.call_with_retry( + lambda: r_multi_db.transaction(callback), lambda _: dummy_fail() + ) + sleep(0.5) + + # Execute transaction until database failover + while not listener.is_changed_flag: + retry.call_with_retry( + lambda: r_multi_db.transaction(callback), lambda _: dummy_fail() + ) + sleep(0.5) + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, + ], + ids=["standalone", "cluster"], + indirect=True, + ) + @pytest.mark.timeout(100) + def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), + ) + + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client, config, event), + ) + data = json.dumps({"message": "test"}) + messages_count = 0 + + def handler(message): + nonlocal messages_count + messages_count += 1 + + pubsub = r_multi_db.pubsub() + + # Assign a handler and run in a separate thread. 
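+        # run_in_thread polls for messages in the background and dispatches each one
+        # to the registered handler; messages_count tracks the deliveries.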
+ retry.call_with_retry( + lambda: pubsub.subscribe(**{"test-channel": handler}), + lambda _: dummy_fail(), + ) + pubsub_thread = pubsub.run_in_thread(sleep_time=0.1, daemon=True) + thread.start() + + # Execute publish before network failure + while not event.is_set(): + retry.call_with_retry( + lambda: r_multi_db.publish("test-channel", data), lambda _: dummy_fail() + ) + sleep(0.5) + + # Execute publish until database failover + while not listener.is_changed_flag: + retry.call_with_retry( + lambda: r_multi_db.publish("test-channel", data), lambda _: dummy_fail() + ) + sleep(0.5) + + pubsub_thread.stop() + assert messages_count > 2 + + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "min_num_failures": 2}, + {"client_class": RedisCluster, "min_num_failures": 2}, + ], + ids=["standalone", "cluster"], + indirect=True, + ) + @pytest.mark.timeout(100) + def test_sharded_pubsub_failover_to_another_db( + self, r_multi_db, fault_injector_client + ): + r_multi_db, listener, config = r_multi_db + retry = Retry( + supported_errors=(TemporaryUnavailableException,), + retries=DEFAULT_FAILOVER_ATTEMPTS, + backoff=ConstantBackoff(backoff=DEFAULT_FAILOVER_DELAY), + ) + + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client, config, event), + ) + data = json.dumps({"message": "test"}) + messages_count = 0 + + def handler(message): + nonlocal messages_count + messages_count += 1 + + pubsub = r_multi_db.pubsub() + + # Assign a handler and run in a separate thread. + retry.call_with_retry( + lambda: pubsub.ssubscribe(**{"test-channel": handler}), + lambda _: dummy_fail(), + ) + pubsub_thread = pubsub.run_in_thread( + sleep_time=0.1, daemon=True, sharded_pubsub=True + ) + thread.start() + + # Execute publish before network failure + while not event.is_set(): + retry.call_with_retry( + lambda: r_multi_db.spublish("test-channel", data), + lambda _: dummy_fail(), + ) + sleep(0.5) + + # Execute publish until database failover + while not listener.is_changed_flag: + retry.call_with_retry( + lambda: r_multi_db.spublish("test-channel", data), + lambda _: dummy_fail(), + ) + sleep(0.5) + + pubsub_thread.stop() + assert messages_count > 2