diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py
index 2ea572c93e..805941bbaf 100644
--- a/tests/test_scenario/conftest.py
+++ b/tests/test_scenario/conftest.py
@@ -55,12 +55,14 @@ def client_maint_events(endpoints_config):
 def _get_client_maint_events(
     endpoints_config,
+    protocol: int = 3,
     enable_maintenance_events: bool = True,
     endpoint_type: Optional[EndpointType] = None,
     enable_relax_timeout: bool = True,
     enable_proactive_reconnect: bool = True,
     disable_retries: bool = False,
     socket_timeout: Optional[float] = None,
+    host_config: Optional[str] = None,
 ):
     """Create Redis client with maintenance events enabled."""
@@ -74,11 +76,9 @@ def _get_client_maint_events(
         raise ValueError("No endpoints found in configuration")
 
     parsed = urlparse(endpoints[0])
-    host = parsed.hostname
+    host = parsed.hostname if host_config is None else host_config
     port = parsed.port
 
-    tls_enabled = True if parsed.scheme == "rediss" else False
-
     if not host:
         raise ValueError(f"Could not parse host from endpoint URL: {endpoints[0]}")
 
@@ -99,6 +99,9 @@ def _get_client_maint_events(
     else:
         retry = Retry(backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3)
 
+    tls_enabled = True if parsed.scheme == "rediss" else False
+    logging.info(f"TLS enabled: {tls_enabled}")
+
     client = Redis(
         host=host,
         port=port,
@@ -108,7 +111,7 @@ def _get_client_maint_events(
         ssl=tls_enabled,
         ssl_cert_reqs="none",
         ssl_check_hostname=False,
-        protocol=3,  # RESP3 required for push notifications
+        protocol=protocol,  # RESP3 required for push notifications
         maintenance_events_config=maintenance_config,
         retry=retry,
     )
diff --git a/tests/test_scenario/fault_injector_client.py b/tests/test_scenario/fault_injector_client.py
index fffda9ac01..8bce3a19e7 100644
--- a/tests/test_scenario/fault_injector_client.py
+++ b/tests/test_scenario/fault_injector_client.py
@@ -2,6 +2,7 @@
 import logging
 import time
 import urllib.request
+import urllib.error
 from typing import Dict, Any, Optional, Union
 from enum import Enum
 
@@ -94,7 +95,7 @@ def get_action_status(self, action_id: str) -> Dict[str, Any]:
         return self._make_request("GET", f"/action/{action_id}")
 
     def execute_rladmin_command(
-        self, command: str, bdb_id: str = None
+        self, command: str, bdb_id: Optional[str] = None
     ) -> Dict[str, Any]:
         """Execute rladmin command directly as string"""
         url = f"{self.base_url}/rladmin"
@@ -146,4 +147,4 @@ def get_operation_result(
                 logging.warning(f"Error checking operation status: {e}")
                 time.sleep(check_interval)
         else:
-            raise TimeoutError(f"Timeout waiting for operation {action_id}")
+            pytest.fail(f"Timeout waiting for operation {action_id}")
diff --git a/tests/test_scenario/hitless_upgrade_helpers.py b/tests/test_scenario/hitless_upgrade_helpers.py
index 3997d557cd..0abc0e067c 100644
--- a/tests/test_scenario/hitless_upgrade_helpers.py
+++ b/tests/test_scenario/hitless_upgrade_helpers.py
@@ -17,6 +17,7 @@ class ClientValidations:
     def wait_push_notification(
         redis_client: Redis,
         timeout: int = 120,
+        fail_on_timeout: bool = True,
         connection: Optional[Connection] = None,
     ):
         """Wait for a push notification to be received."""
@@ -35,11 +36,15 @@ def wait_push_notification(
                     logging.debug(
                         f"Push notification has been received. Response: {push_response}"
                     )
+                    if test_conn.should_reconnect():
+                        logging.debug("Connection is marked for reconnect")
                     return
                 except Exception as e:
                     logging.error(f"Error reading push notification: {e}")
                     break
                 time.sleep(check_interval)
+            if fail_on_timeout:
+                pytest.fail("Timeout waiting for push notification")
         finally:
             # Release the connection back to the pool
             try:
@@ -215,6 +220,40 @@ def find_endpoint_for_bind(
 
         raise ValueError(f"No endpoint ID for {endpoint_name} found in cluster status")
 
+    @staticmethod
+    def execute_failover(
+        fault_injector: FaultInjectorClient,
+        endpoint_config: Dict[str, Any],
+        timeout: int = 60,
+    ) -> Dict[str, Any]:
+        """Execute failover command and wait for completion."""
+
+        try:
+            bdb_id = endpoint_config.get("bdb_id")
+            failover_action = ActionRequest(
+                action_type=ActionType.FAILOVER,
+                parameters={
+                    "bdb_id": bdb_id,
+                },
+            )
+            trigger_action_result = fault_injector.trigger_action(failover_action)
+            action_id = trigger_action_result.get("action_id")
+            if not action_id:
+                raise ValueError(
+                    f"Failed to trigger failover action for bdb_id {bdb_id}: {trigger_action_result}"
+                )
+
+            action_status_check_response = fault_injector.get_operation_result(
+                action_id, timeout=timeout
+            )
+            logging.info(
+                f"Failover action completed: {action_status_check_response}"
+            )
+            return action_status_check_response
+
+        except Exception as e:
+            pytest.fail(f"Failed to execute failover: {e}")
+
     @staticmethod
     def execute_rladmin_migrate(
         fault_injector: FaultInjectorClient,
diff --git a/tests/test_scenario/test_hitless_upgrade.py b/tests/test_scenario/test_hitless_upgrade.py
index c902d7d37f..f23be4d4b2 100644
--- a/tests/test_scenario/test_hitless_upgrade.py
+++ b/tests/test_scenario/test_hitless_upgrade.py
@@ -1,16 +1,22 @@
 """Tests for Redis Enterprise moving push notifications with real cluster operations."""
 
+from concurrent.futures import ThreadPoolExecutor
 import logging
 from queue import Queue
 from threading import Thread
 import threading
 import time
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 import pytest
 
 from redis import Redis
-from redis.maintenance_events import EndpointType, MaintenanceState
+from redis.connection import ConnectionInterface
+from redis.maintenance_events import (
+    EndpointType,
+    MaintenanceEventsConfig,
+    MaintenanceState,
+)
 from tests.test_scenario.conftest import (
     CLIENT_TIMEOUT,
     RELAX_TIMEOUT,
@@ -25,14 +31,15 @@
 )
 
 logging.basicConfig(
-    level=logging.DEBUG,
+    level=logging.INFO,
     format="%(asctime)s %(levelname)s %(message)s",
-    filemode="w",
-    filename="./test_hitless_upgrade.log",
 )
 
-BIND_TIMEOUT = 60
-MIGRATE_TIMEOUT = 120
+BIND_TIMEOUT = 30
+MIGRATE_TIMEOUT = 60
+FAILOVER_TIMEOUT = 15
+
+DEFAULT_BIND_TTL = 15
 
 
 class TestPushNotifications:
@@ -50,6 +57,7 @@ def setup_and_cleanup(
         endpoint_name: str,
     ):
         # Initialize cleanup flags first to ensure they exist even if setup fails
+        self._failover_executed = False
         self._migration_executed = False
         self._bind_executed = False
         self.target_node = None
@@ -93,6 +101,13 @@ def setup_and_cleanup(
                 logging.error(f"Failed to close client: {e}")
 
         # Only attempt cleanup if we have the necessary attributes and they were executed
+        if self._failover_executed:
+            try:
+                self._execute_failover(fault_injector_client, endpoints_config)
+                logging.info("Failover cleanup completed")
+            except Exception as e:
+                logging.error(f"Failed to revert failover: {e}")
+
         if self._migration_executed:
             try:
                 if self.target_node and self.empty_node:
@@ -118,6 +133,18 @@ def setup_and_cleanup(
 
         logging.info("Cleanup finished")
 
+    def _execute_failover(
+        self,
+        fault_injector_client: FaultInjectorClient,
+        endpoints_config: Dict[str, Any],
+    ):
+        failover_result = ClusterOperations.execute_failover(
+            fault_injector_client, endpoints_config
+        )
+        self._failover_executed = True
+
+        logging.debug(f"Failover result: {failover_result}")
+
     def _execute_migration(
         self,
         fault_injector_client: FaultInjectorClient,
@@ -176,7 +203,7 @@ def _execute_migrate_bind_flow(
             endpoint_id=endpoint_id,
         )
 
-    def _get_all_connections_in_pool(self, client: Redis):
+    def _get_all_connections_in_pool(self, client: Redis) -> List[ConnectionInterface]:
         connections = []
         if hasattr(client.connection_pool, "_available_connections"):
             for conn in client.connection_pool._available_connections:
@@ -226,6 +253,10 @@ def _validate_moving_state(
                 or (
                     configured_endpoint_type != EndpointType.NONE
                     and conn.host != conn.orig_host_address
+                    and (
+                        configured_endpoint_type
+                        == MaintenanceEventsConfig().get_endpoint_type(conn.host, conn)
+                    )
                 )
             )
             if (
@@ -296,6 +327,51 @@ def _validate_default_notif_disabled_state(
                 matching_conns_count += 1
 
         assert matching_conns_count == expected_matching_conns_count
 
+    @pytest.mark.timeout(300)  # 5 minutes timeout for this test
+    def test_receive_failing_over_and_failed_over_push_notification(
+        self,
+        client_maint_events: Redis,
+        fault_injector_client: FaultInjectorClient,
+        endpoints_config: Dict[str, Any],
+    ):
+        """
+        Test that FAILING_OVER and FAILED_OVER push notifications are received
+        when a cluster failover is executed.
+        """
+        logging.info("Creating one connection in the pool.")
+        conn = client_maint_events.connection_pool.get_connection()
+
+        logging.info("Executing failover command...")
+        failover_thread = Thread(
+            target=self._execute_failover,
+            name="failover_thread",
+            args=(fault_injector_client, endpoints_config),
+        )
+        failover_thread.start()
+
+        logging.info("Waiting for FAILING_OVER push notifications...")
+        ClientValidations.wait_push_notification(
+            client_maint_events, timeout=FAILOVER_TIMEOUT, connection=conn
+        )
+
+        logging.info("Validating connection maintenance state...")
+        assert conn.maintenance_state == MaintenanceState.MAINTENANCE
+        assert conn._sock.gettimeout() == RELAX_TIMEOUT
+
+        logging.info("Waiting for FAILED_OVER push notifications...")
+        ClientValidations.wait_push_notification(
+            client_maint_events, timeout=FAILOVER_TIMEOUT, connection=conn
+        )
+
+        logging.info("Validating connection default state is restored...")
+        assert conn.maintenance_state == MaintenanceState.NONE
+        assert conn._sock.gettimeout() == CLIENT_TIMEOUT
+
+        logging.info("Releasing connection back to the pool...")
+        client_maint_events.connection_pool.release(conn)
+
+        failover_thread.join()
+
     @pytest.mark.timeout(300)  # 5 minutes timeout for this test
     def test_receive_migrating_and_moving_push_notification(
         self,
@@ -494,7 +570,7 @@ def test_timeout_handling_during_migrating_and_moving(
             EndpointType.NONE,
         ],
     )
-    def test_new_connection_handling_during_migrating_and_moving(
+    def test_connection_handling_during_moving(
         self,
         endpoint_type: EndpointType,
         fault_injector_client: FaultInjectorClient,
@@ -529,32 +605,12 @@ def test_new_connection_handling_during_migrating_and_moving(
 
         self._validate_maintenance_state(client, expected_matching_conns_count=1)
 
-        # validate that new connections will also receive the moving event
-        logging.info(
-            "Creating second connection in the pool"
-            " and expect it to receive the migrating as well."
-        )
-
-        second_connection = client.connection_pool.get_connection()
-        ClientValidations.wait_push_notification(
-            client, timeout=MIGRATE_TIMEOUT, connection=second_connection
-        )
-
-        logging.info(
-            "Validating connection states after MIGRATING for both connections ..."
-        )
-        self._validate_maintenance_state(client, expected_matching_conns_count=2)
-
-        logging.info("Waiting for MIGRATED push notifications on both connections ...")
+        logging.info("Waiting for MIGRATED push notification ...")
         ClientValidations.wait_push_notification(
             client, timeout=MIGRATE_TIMEOUT, connection=first_conn
         )
-        ClientValidations.wait_push_notification(
-            client, timeout=MIGRATE_TIMEOUT, connection=second_connection
-        )
 
         client.connection_pool.release(first_conn)
-        client.connection_pool.release(second_connection)
 
         migrate_thread.join()
 
@@ -608,6 +664,259 @@ def test_new_connection_handling_during_migrating_and_moving(
 
         self._validate_default_state(client, expected_matching_conns_count=3)
 
         bind_thread.join()
 
+    @pytest.mark.timeout(300)  # 5 minutes timeout
+    def test_old_connection_shutdown_during_moving(
+        self,
+        fault_injector_client: FaultInjectorClient,
+        endpoints_config: Dict[str, Any],
+    ):
+        # it is better to use ip for this test - enables validation that
+        # the connection is disconnected from the original address
+        # and connected to the new address
+        endpoint_type = EndpointType.EXTERNAL_IP
+        logging.info("Testing old connection shutdown during MOVING")
+        client = _get_client_maint_events(
+            endpoints_config=endpoints_config, endpoint_type=endpoint_type
+        )
+
+        logging.info("Starting migration ...")
+        migrate_thread = Thread(
+            target=self._execute_migration,
+            name="migrate_thread",
+            args=(
+                fault_injector_client,
+                endpoints_config,
+                self.target_node,
+                self.empty_node,
+            ),
+        )
+        migrate_thread.start()
+
+        logging.info("Waiting for MIGRATING push notifications...")
+        ClientValidations.wait_push_notification(client, timeout=MIGRATE_TIMEOUT)
+        self._validate_maintenance_state(client, expected_matching_conns_count=1)
+
+        logging.info("Waiting for MIGRATED push notification ...")
+        ClientValidations.wait_push_notification(client, timeout=MIGRATE_TIMEOUT)
+        self._validate_default_state(client, expected_matching_conns_count=1)
+        migrate_thread.join()
+
+        moving_event = threading.Event()
+
+        def execute_commands(moving_event: threading.Event, errors: Queue):
+            while not moving_event.is_set():
+                try:
+                    client.set("key", "value")
+                    client.get("key")
+                except Exception as e:
+                    errors.put(
+                        f"Command failed in thread {threading.current_thread().name}: {e}"
+                    )
+
+        logging.info("Starting rebind...")
+        bind_thread = Thread(
+            target=self._execute_bind,
+            name="bind_thread",
+            args=(fault_injector_client, endpoints_config, self.endpoint_id),
+        )
+        bind_thread.start()
+
+        errors = Queue()
+        threads_count = 10
+        futures = []
+
+        logging.info(f"Starting {threads_count} command execution threads...")
+        # Start the worker pool and submit N identical worker tasks
+        with ThreadPoolExecutor(
+            max_workers=threads_count, thread_name_prefix="command_execution_thread"
+        ) as executor:
+            futures = [
+                executor.submit(execute_commands, moving_event, errors)
+                for _ in range(threads_count)
+            ]
+
+            logging.info("Waiting for MOVING push notification ...")
+            # this will consume the notification in one of the connections
+            # and will handle the states of the rest
+            ClientValidations.wait_push_notification(client, timeout=BIND_TIMEOUT)
+            # set the event to stop the command execution threads
+            moving_event.set()
+
+            # Wait for all workers to finish and propagate any exceptions
+            for f in futures:
+                f.result()
+
+        # validate that all connections are either disconnected
+        # or connected to the new address
+        connections = self._get_all_connections_in_pool(client)
+        for conn in connections:
+            if conn._sock is not None:
+                assert conn.get_resolved_ip() == conn.host
+                assert conn.maintenance_state == MaintenanceState.MOVING
+                assert conn._sock.gettimeout() == RELAX_TIMEOUT
+                assert conn.host != conn.orig_host_address
+                assert not conn.should_reconnect()
+            else:
+                assert conn.maintenance_state == MaintenanceState.MOVING
+                assert conn.socket_timeout == RELAX_TIMEOUT
+                assert conn.host != conn.orig_host_address
+                assert not conn.should_reconnect()
+
+        # validate no errors were raised in the command execution threads
+        assert errors.empty(), f"Errors occurred in threads: {errors.queue}"
+
+        logging.info("Waiting for moving ttl to expire")
+        time.sleep(DEFAULT_BIND_TTL)
+        bind_thread.join()
+
+    @pytest.mark.timeout(300)  # 5 minutes timeout
+    def test_new_connections_receive_moving(
+        self,
+        client_maint_events: Redis,
+        fault_injector_client: FaultInjectorClient,
+        endpoints_config: Dict[str, Any],
+    ):
+        logging.info("Creating one connection in the pool.")
+        first_conn = client_maint_events.connection_pool.get_connection()
+
+        logging.info("Executing rladmin migrate command...")
+        migrate_thread = Thread(
+            target=self._execute_migration,
+            name="migrate_thread",
+            args=(
+                fault_injector_client,
+                endpoints_config,
+                self.target_node,
+                self.empty_node,
+            ),
+        )
+        migrate_thread.start()
+
+        logging.info("Waiting for MIGRATING push notifications...")
+        # this will consume the notification in the provided connection
+        ClientValidations.wait_push_notification(
+            client_maint_events, timeout=MIGRATE_TIMEOUT, connection=first_conn
+        )
+
+        self._validate_maintenance_state(
+            client_maint_events, expected_matching_conns_count=1
+        )
+
+        logging.info("Waiting for MIGRATED push notification ...")
+        ClientValidations.wait_push_notification(
+            client_maint_events, timeout=MIGRATE_TIMEOUT, connection=first_conn
+        )
+
+        migrate_thread.join()
+
+        logging.info("Executing rladmin bind endpoint command...")
+
+        bind_thread = Thread(
+            target=self._execute_bind,
+            name="bind_thread",
+            args=(fault_injector_client, endpoints_config, self.endpoint_id),
+        )
+        bind_thread.start()
+
+        logging.info("Waiting for MOVING push notifications on random connection ...")
+        ClientValidations.wait_push_notification(
+            client_maint_events, timeout=BIND_TIMEOUT, connection=first_conn
+        )
+
+        old_address = first_conn._sock.getpeername()[0]
+        logging.info(f"The node address before bind: {old_address}")
+        logging.info(
+            "Creating new client to connect to the same node - new connections to this node should receive the moving event..."
+        )
+
+        endpoint_type = EndpointType.EXTERNAL_IP
+        # create new client with new pool that should also receive the moving event
+        new_client = _get_client_maint_events(
+            endpoints_config=endpoints_config,
+            endpoint_type=endpoint_type,
+            host_config=old_address,
+        )
+
+        # the moving notification will be consumed as
+        # part of the client connection setup, so we don't need
+        # to wait for it explicitly with wait_push_notification
+        logging.info(
+            "Creating one connection in the new pool that should receive the moving event."
+        )
+        new_client_conn = new_client.connection_pool.get_connection()
+
+        logging.info("Validating connection states during MOVING ...")
+        self._validate_moving_state(
+            new_client,
+            endpoint_type,
+            expected_matching_connected_conns_count=1,
+            expected_matching_disconnected_conns_count=0,
+        )
+
+        logging.info("Waiting for moving thread to be completed ...")
+        bind_thread.join()
+
+        new_client.connection_pool.release(new_client_conn)
+        new_client.close()
+
+        client_maint_events.connection_pool.release(first_conn)
+
+    @pytest.mark.timeout(300)  # 5 minutes timeout
+    def test_new_connections_receive_migrating(
+        self,
+        client_maint_events: Redis,
+        fault_injector_client: FaultInjectorClient,
+        endpoints_config: Dict[str, Any],
+    ):
+        logging.info("Creating one connection in the pool.")
+        first_conn = client_maint_events.connection_pool.get_connection()
+
+        logging.info("Executing rladmin migrate command...")
+        migrate_thread = Thread(
+            target=self._execute_migration,
+            name="migrate_thread",
+            args=(
+                fault_injector_client,
+                endpoints_config,
+                self.target_node,
+                self.empty_node,
+            ),
+        )
+        migrate_thread.start()
+
+        logging.info("Waiting for MIGRATING push notifications...")
+        # this will consume the notification in the provided connection
+        ClientValidations.wait_push_notification(
+            client_maint_events, timeout=MIGRATE_TIMEOUT, connection=first_conn
+        )
+
+        self._validate_maintenance_state(
+            client_maint_events, expected_matching_conns_count=1
+        )
+
+        # validate that new connections will also receive the migrating event
+        # it should be received as part of the client connection setup flow
+        logging.info(
+            "Creating second connection that should receive the migrating event as well."
+        )
+        second_connection = client_maint_events.connection_pool.get_connection()
+        self._validate_maintenance_state(
+            client_maint_events, expected_matching_conns_count=2
+        )
+
+        logging.info("Waiting for MIGRATED push notifications on both connections ...")
+        ClientValidations.wait_push_notification(
+            client_maint_events, timeout=MIGRATE_TIMEOUT, connection=first_conn
+        )
+        ClientValidations.wait_push_notification(
+            client_maint_events, timeout=MIGRATE_TIMEOUT, connection=second_connection
+        )
+
+        migrate_thread.join()
+
+        client_maint_events.connection_pool.release(first_conn)
+        client_maint_events.connection_pool.release(second_connection)
+
     @pytest.mark.timeout(300)
     def test_disabled_handling_during_migrating_and_moving(
         self,
@@ -637,24 +946,24 @@ def test_disabled_handling_during_migrating_and_moving(
         migrate_thread.start()
 
         logging.info("Waiting for MIGRATING push notifications...")
-        # this will consume the notification in the provided connection
+        # this will consume the notification in the provided connection if it arrives
         ClientValidations.wait_push_notification(
-            client, timeout=5, connection=first_conn
+            client, timeout=5, fail_on_timeout=False, connection=first_conn
         )
 
         self._validate_default_notif_disabled_state(
             client, expected_matching_conns_count=1
         )
 
-        # validate that new connections will also receive the moving event
+        # validate that new connections will not receive the migrating event
         logging.info(
             "Creating second connection in the pool"
-            " and expect it to receive the migrating as well."
+            " and expect it not to receive the migrating event."
         )
 
         second_connection = client.connection_pool.get_connection()
         ClientValidations.wait_push_notification(
-            client, timeout=5, connection=second_connection
+            client, timeout=5, fail_on_timeout=False, connection=second_connection
         )
 
         logging.info(
@@ -666,10 +975,10 @@ def test_disabled_handling_during_migrating_and_moving(
 
         logging.info("Waiting for MIGRATED push notifications on both connections ...")
         ClientValidations.wait_push_notification(
-            client, timeout=5, connection=first_conn
+            client, timeout=5, fail_on_timeout=False, connection=first_conn
         )
         ClientValidations.wait_push_notification(
-            client, timeout=5, connection=second_connection
+            client, timeout=5, fail_on_timeout=False, connection=second_connection
         )
 
         client.connection_pool.release(first_conn)
@@ -687,12 +996,16 @@ def test_disabled_handling_during_migrating_and_moving(
         bind_thread.start()
 
         logging.info("Waiting for MOVING push notifications on random connection ...")
-        # this will consume the notification in one of the connections
+        # this will consume the notification if it arrives in one of the connections
         # and will handle the states of the rest
         # the consumed connection will be disconnected during
         # releasing it back to the pool and as a result we will have
         # 3 disconnected connections in the pool
-        ClientValidations.wait_push_notification(client, timeout=10)
+        ClientValidations.wait_push_notification(
+            client,
+            timeout=10,
+            fail_on_timeout=False,
+        )
 
         # validate that new connections will also receive the moving event
         connections = []
@@ -707,7 +1020,7 @@ def test_disabled_handling_during_migrating_and_moving(
         )
 
         logging.info("Waiting for moving ttl to expire")
-        time.sleep(30)
+        time.sleep(DEFAULT_BIND_TTL)
 
         logging.info("Validating connection states after MOVING has expired ...")
         self._validate_default_notif_disabled_state(