diff --git a/redis/client.py b/redis/client.py index b8d2e8af5d..51b699721e 100755 --- a/redis/client.py +++ b/redis/client.py @@ -58,7 +58,6 @@ from redis.lock import Lock from redis.maint_notifications import ( MaintNotificationsConfig, - MaintNotificationsPoolHandler, ) from redis.retry import Retry from redis.utils import ( @@ -278,6 +277,17 @@ def __init__( single_connection_client: if `True`, connection pool is not used. In that case `Redis` instance use is not thread safe. + decode_responses: + if `True`, the response will be decoded to utf-8. + Argument is ignored when connection_pool is provided. + maint_notifications_config: + configuration of the pool to support maintenance notifications - see + `redis.maint_notifications.MaintNotificationsConfig` for details. + Only supported with RESP3. + If not provided and protocol is RESP3, the maintenance notifications + will be enabled by default (logic is included in the connection pool + initialization). + Argument is ignored when connection_pool is provided. 
""" if event_dispatcher is None: self._event_dispatcher = EventDispatcher() @@ -354,6 +364,22 @@ def __init__( "cache_config": cache_config, } ) + maint_notifications_enabled = ( + maint_notifications_config and maint_notifications_config.enabled + ) + if maint_notifications_enabled and protocol not in [ + 3, + "3", + ]: + raise RedisError( + "Maintenance notifications handlers on connection are only supported with RESP version 3" + ) + if maint_notifications_config: + kwargs.update( + { + "maint_notifications_config": maint_notifications_config, + } + ) connection_pool = ConnectionPool(**kwargs) self._event_dispatcher.dispatch( AfterPooledConnectionsInstantiationEvent( @@ -377,23 +403,6 @@ def __init__( ]: raise RedisError("Client caching is only supported with RESP version 3") - if maint_notifications_config and self.connection_pool.get_protocol() not in [ - 3, - "3", - ]: - raise RedisError( - "Push handlers on connection are only supported with RESP version 3" - ) - if maint_notifications_config and maint_notifications_config.enabled: - self.maint_notifications_pool_handler = MaintNotificationsPoolHandler( - self.connection_pool, maint_notifications_config - ) - self.connection_pool.set_maint_notifications_pool_handler( - self.maint_notifications_pool_handler - ) - else: - self.maint_notifications_pool_handler = None - self.single_connection_lock = threading.RLock() self.connection = None self._single_connection_client = single_connection_client @@ -591,15 +600,9 @@ def monitor(self): return Monitor(self.connection_pool) def client(self): - maint_notifications_config = ( - None - if self.maint_notifications_pool_handler is None - else self.maint_notifications_pool_handler.config - ) return self.__class__( connection_pool=self.connection_pool, single_connection_client=True, - maint_notifications_config=maint_notifications_config, ) def __enter__(self): diff --git a/redis/cluster.py b/redis/cluster.py index 1d4a3e0d0c..c238c171be 100644 --- a/redis/cluster.py +++ 
b/redis/cluster.py @@ -50,6 +50,7 @@ WatchError, ) from redis.lock import Lock +from redis.maint_notifications import MaintNotificationsConfig from redis.retry import Retry from redis.utils import ( deprecated_args, @@ -1663,6 +1664,11 @@ def create_redis_node(self, host, port, **kwargs): backoff=NoBackoff(), retries=0, supported_errors=(ConnectionError,) ) + protocol = kwargs.get("protocol", None) + if protocol in [3, "3"]: + kwargs.update( + {"maint_notifications_config": MaintNotificationsConfig(enabled=False)} + ) if self.from_url: # Create a redis node with a costumed connection pool kwargs.update({"host": host}) diff --git a/redis/connection.py b/redis/connection.py index 837fccd40e..31ce30fd8d 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -5,7 +5,7 @@ import threading import time import weakref -from abc import abstractmethod +from abc import ABC, abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue from typing import ( @@ -178,10 +178,6 @@ def deregister_connect_callback(self, callback): def set_parser(self, parser_class): pass - @abstractmethod - def set_maint_notifications_pool_handler(self, maint_notifications_pool_handler): - pass - @abstractmethod def get_protocol(self): pass @@ -245,309 +241,214 @@ def set_re_auth_token(self, token: TokenInterface): def re_auth(self): pass - @property - @abstractmethod - def maintenance_state(self) -> MaintenanceState: + +class MaintNotificationsAbstractConnection: + """ + Abstract class for handling maintenance notifications logic. + This class is expected to be used as base class together with ConnectionInterface. + + This class is intended to be used with multiple inheritance! + + All logic related to maintenance notifications is encapsulated in this class. 
+ """ + + def __init__( + self, + maint_notifications_config: Optional[MaintNotificationsConfig], + maint_notifications_pool_handler: Optional[ + MaintNotificationsPoolHandler + ] = None, + maintenance_state: "MaintenanceState" = MaintenanceState.NONE, + maintenance_notification_hash: Optional[int] = None, + orig_host_address: Optional[str] = None, + orig_socket_timeout: Optional[float] = None, + orig_socket_connect_timeout: Optional[float] = None, + parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None, + ): """ - Returns the current maintenance state of the connection. + Initialize the maintenance notifications for the connection. + + Args: + maint_notifications_config (MaintNotificationsConfig): The configuration for maintenance notifications. + maint_notifications_pool_handler (Optional[MaintNotificationsPoolHandler]): The pool handler for maintenance notifications. + maintenance_state (MaintenanceState): The current maintenance state of the connection. + maintenance_notification_hash (Optional[int]): The current maintenance notification hash of the connection. + orig_host_address (Optional[str]): The original host address of the connection. + orig_socket_timeout (Optional[float]): The original socket timeout of the connection. + orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection. + parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser to use for maintenance notifications. + If not provided, the parser from the connection is used. + This is useful when the parser is created after this object. 
""" + self.maint_notifications_config = maint_notifications_config + self.maintenance_state = maintenance_state + self.maintenance_notification_hash = maintenance_notification_hash + self._configure_maintenance_notifications( + maint_notifications_pool_handler, + orig_host_address, + orig_socket_timeout, + orig_socket_connect_timeout, + parser, + ) + self._should_reconnect = False + + @abstractmethod + def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]: pass - @maintenance_state.setter @abstractmethod - def maintenance_state(self, state: "MaintenanceState"): - """ - Sets the current maintenance state of the connection. - """ + def _get_socket(self) -> Optional[socket.socket]: pass @abstractmethod - def getpeername(self): + def get_protocol(self) -> Union[int, str]: """ - Returns the peer name of the connection. + Returns: + The RESP protocol version, or ``None`` if the protocol is not specified, + in which case the server default will be used. """ pass + @property @abstractmethod - def mark_for_reconnect(self): - """ - Mark the connection to be reconnected on the next command. - This is useful when a connection is moved to a different node. - """ + def host(self) -> str: pass + @host.setter @abstractmethod - def should_reconnect(self): - """ - Returns True if the connection should be reconnected. - """ + def host(self, value: str): pass + @property @abstractmethod - def get_resolved_ip(self): - """ - Get resolved ip address for the connection. - """ + def socket_timeout(self) -> Optional[Union[float, int]]: pass + @socket_timeout.setter @abstractmethod - def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None): - """ - Update the timeout for the current socket. 
- """ + def socket_timeout(self, value: Optional[Union[float, int]]): pass + @property @abstractmethod - def set_tmp_settings( - self, - tmp_host_address: Optional[str] = None, - tmp_relaxed_timeout: Optional[float] = None, - ): - """ - Updates temporary host address and timeout settings for the connection. - """ + def socket_connect_timeout(self) -> Optional[Union[float, int]]: pass + @socket_connect_timeout.setter @abstractmethod - def reset_tmp_settings( + def socket_connect_timeout(self, value: Optional[Union[float, int]]): + pass + + @abstractmethod + def send_command(self, *args, **kwargs): + pass + + @abstractmethod + def read_response( self, - reset_host_address: bool = False, - reset_relaxed_timeout: bool = False, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, ): - """ - Resets temporary host address and timeout settings for the connection. - """ pass + @abstractmethod + def disconnect(self, *args): + pass -class AbstractConnection(ConnectionInterface): - "Manages communication to and from a Redis server" - - def __init__( + def _configure_maintenance_notifications( self, - db: int = 0, - password: Optional[str] = None, - socket_timeout: Optional[float] = None, - socket_connect_timeout: Optional[float] = None, - retry_on_timeout: bool = False, - retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL, - encoding: str = "utf-8", - encoding_errors: str = "strict", - decode_responses: bool = False, - parser_class=DefaultParser, - socket_read_size: int = 65536, - health_check_interval: int = 0, - client_name: Optional[str] = None, - lib_name: Optional[str] = "redis-py", - lib_version: Optional[str] = get_lib_version(), - username: Optional[str] = None, - retry: Union[Any, None] = None, - redis_connect_func: Optional[Callable[[], None]] = None, - credential_provider: Optional[CredentialProvider] = None, - protocol: Optional[int] = 2, - command_packer: Optional[Callable[[], None]] = None, - event_dispatcher: 
Optional[EventDispatcher] = None, maint_notifications_pool_handler: Optional[ MaintNotificationsPoolHandler ] = None, - maint_notifications_config: Optional[MaintNotificationsConfig] = None, - maintenance_state: "MaintenanceState" = MaintenanceState.NONE, - maintenance_notification_hash: Optional[int] = None, - orig_host_address: Optional[str] = None, - orig_socket_timeout: Optional[float] = None, - orig_socket_connect_timeout: Optional[float] = None, + orig_host_address=None, + orig_socket_timeout=None, + orig_socket_connect_timeout=None, + parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None, ): """ - Initialize a new Connection. - To specify a retry policy for specific errors, first set - `retry_on_error` to a list of the error/s to retry on, then set - `retry` to a valid `Retry` object. - To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. + Enable maintenance notifications by setting up + handlers and storing original connection parameters. + + Should be used ONLY with parsers that support push notifications. """ - if (username or password) and credential_provider is not None: - raise DataError( - "'username' and 'password' cannot be passed along with 'credential_" - "provider'. Please provide only one of the following arguments: \n" - "1. 'password' and (optional) 'username'\n" - "2. 'credential_provider'" + if ( + not self.maint_notifications_config + or not self.maint_notifications_config.enabled + ): + self._maint_notifications_pool_handler = None + self._maint_notifications_connection_handler = None + return + + if not parser: + raise RedisError( + "To configure maintenance notifications, a parser must be provided!" 
) - if event_dispatcher is None: - self._event_dispatcher = EventDispatcher() - else: - self._event_dispatcher = event_dispatcher - self.pid = os.getpid() - self.db = db - self.client_name = client_name - self.lib_name = lib_name - self.lib_version = lib_version - self.credential_provider = credential_provider - self.password = password - self.username = username - self.socket_timeout = socket_timeout - if socket_connect_timeout is None: - socket_connect_timeout = socket_timeout - self.socket_connect_timeout = socket_connect_timeout - self.retry_on_timeout = retry_on_timeout - if retry_on_error is SENTINEL: - retry_on_errors_list = [] - else: - retry_on_errors_list = list(retry_on_error) - if retry_on_timeout: - # Add TimeoutError to the errors list to retry on - retry_on_errors_list.append(TimeoutError) - self.retry_on_error = retry_on_errors_list - if retry or self.retry_on_error: - if retry is None: - self.retry = Retry(NoBackoff(), 1) - else: - # deep-copy the Retry object as it is mutable - self.retry = copy.deepcopy(retry) - if self.retry_on_error: - # Update the retry's supported errors with the specified errors - self.retry.update_supported_errors(self.retry_on_error) - else: - self.retry = Retry(NoBackoff(), 0) - self.health_check_interval = health_check_interval - self.next_health_check = 0 - self.redis_connect_func = redis_connect_func - self.encoder = Encoder(encoding, encoding_errors, decode_responses) - self.handshake_metadata = None - self._sock = None - self._socket_read_size = socket_read_size - self._connect_callbacks = [] - self._buffer_cutoff = 6000 - self._re_auth_token: Optional[TokenInterface] = None - try: - p = int(protocol) - except TypeError: - p = DEFAULT_RESP_VERSION - except ValueError: - raise ConnectionError("protocol must be an integer") - finally: - if p < 2 or p > 3: - raise ConnectionError("protocol must be either 2 or 3") - # p = DEFAULT_RESP_VERSION - self.protocol = p - if self.protocol == 3 and parser_class == DefaultParser: 
- parser_class = _RESP3Parser - self.set_parser(parser_class) - self.maint_notifications_config = maint_notifications_config + if not isinstance(parser, _HiredisParser) and not isinstance( + parser, _RESP3Parser + ): + raise RedisError( + "Maintenance notifications are only supported with hiredis and RESP3 parsers!" + ) - # Set up maintenance notifications if enabled - self._configure_maintenance_notifications( - maint_notifications_pool_handler, - orig_host_address, - orig_socket_timeout, - orig_socket_connect_timeout, + if maint_notifications_pool_handler: + # Extract a reference to a new pool handler that copies all properties + # of the original one and has a different connection reference + # This is needed because when we attach the handler to the parser + # we need to make sure that the handler has a reference to the + # connection that the parser is attached to. + self._maint_notifications_pool_handler = ( + maint_notifications_pool_handler.get_handler_for_connection() + ) + self._maint_notifications_pool_handler.set_connection(self) + else: + self._maint_notifications_pool_handler = None + + self._maint_notifications_connection_handler = ( + MaintNotificationsConnectionHandler(self, self.maint_notifications_config) ) - self._should_reconnect = False - self.maintenance_state = maintenance_state - self.maintenance_notification_hash = maintenance_notification_hash + # Set up pool handler if available + if self._maint_notifications_pool_handler: + parser.set_node_moving_push_handler( + self._maint_notifications_pool_handler.handle_notification + ) - self._command_packer = self._construct_command_packer(command_packer) + # Set up connection handler + parser.set_maintenance_push_handler( + self._maint_notifications_connection_handler.handle_notification + ) - def __repr__(self): - repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) - return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" + # Store original connection 
parameters + self.orig_host_address = orig_host_address if orig_host_address else self.host + self.orig_socket_timeout = ( + orig_socket_timeout if orig_socket_timeout else self.socket_timeout + ) + self.orig_socket_connect_timeout = ( + orig_socket_connect_timeout + if orig_socket_connect_timeout + else self.socket_connect_timeout + ) - @abstractmethod - def repr_pieces(self): - pass + def set_maint_notifications_pool_handler_for_connection( + self, maint_notifications_pool_handler: MaintNotificationsPoolHandler + ): + # Deep copy the pool handler to avoid sharing the same pool handler + # between multiple connections, because otherwise each connection will override + # the connection reference and the pool handler will only hold a reference + # to the last connection that was set. + maint_notifications_pool_handler_copy = ( + maint_notifications_pool_handler.get_handler_for_connection() + ) - def __del__(self): - try: - self.disconnect() - except Exception: - pass + maint_notifications_pool_handler_copy.set_connection(self) + self._get_parser().set_node_moving_push_handler( + maint_notifications_pool_handler_copy.handle_notification + ) - def _construct_command_packer(self, packer): - if packer is not None: - return packer - elif HIREDIS_AVAILABLE: - return HiredisRespSerializer() - else: - return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode) - - def register_connect_callback(self, callback): - """ - Register a callback to be called when the connection is established either - initially or reconnected. This allows listeners to issue commands that - are ephemeral to the connection, for example pub/sub subscription or - key tracking. The callback must be a _method_ and will be kept as - a weak reference. - """ - wm = weakref.WeakMethod(callback) - if wm not in self._connect_callbacks: - self._connect_callbacks.append(wm) - - def deregister_connect_callback(self, callback): - """ - De-register a previously registered callback. 
It will no-longer receive - notifications on connection events. Calling this is not required when the - listener goes away, since the callbacks are kept as weak methods. - """ - try: - self._connect_callbacks.remove(weakref.WeakMethod(callback)) - except ValueError: - pass - - def set_parser(self, parser_class): - """ - Creates a new instance of parser_class with socket size: - _socket_read_size and assigns it to the parser for the connection - :param parser_class: The required parser class - """ - self._parser = parser_class(socket_read_size=self._socket_read_size) - - def _configure_maintenance_notifications( - self, - maint_notifications_pool_handler=None, - orig_host_address=None, - orig_socket_timeout=None, - orig_socket_connect_timeout=None, - ): - """Enable maintenance notifications by setting up handlers and storing original connection parameters.""" - if ( - not self.maint_notifications_config - or not self.maint_notifications_config.enabled - ): - self._maint_notifications_connection_handler = None - return - - # Set up pool handler if available - if maint_notifications_pool_handler: - self._parser.set_node_moving_push_handler( - maint_notifications_pool_handler.handle_notification - ) - - # Set up connection handler - self._maint_notifications_connection_handler = ( - MaintNotificationsConnectionHandler(self, self.maint_notifications_config) - ) - self._parser.set_maintenance_push_handler( - self._maint_notifications_connection_handler.handle_notification - ) - - # Store original connection parameters - self.orig_host_address = orig_host_address if orig_host_address else self.host - self.orig_socket_timeout = ( - orig_socket_timeout if orig_socket_timeout else self.socket_timeout - ) - self.orig_socket_connect_timeout = ( - orig_socket_connect_timeout - if orig_socket_connect_timeout - else self.socket_connect_timeout - ) - - def set_maint_notifications_pool_handler( - self, maint_notifications_pool_handler: MaintNotificationsPoolHandler - ): - 
maint_notifications_pool_handler.set_connection(self) - self._parser.set_node_moving_push_handler( - maint_notifications_pool_handler.handle_notification - ) + self._maint_notifications_pool_handler = maint_notifications_pool_handler_copy # Update maintenance notification connection handler if it doesn't exist if not self._maint_notifications_connection_handler: @@ -556,7 +457,7 @@ def set_maint_notifications_pool_handler( self, maint_notifications_pool_handler.config ) ) - self._parser.set_maintenance_push_handler( + self._get_parser().set_maintenance_push_handler( self._maint_notifications_connection_handler.handle_notification ) else: @@ -564,130 +465,7 @@ def set_maint_notifications_pool_handler( maint_notifications_pool_handler.config ) - def connect(self): - "Connects to the Redis server if not already connected" - self.connect_check_health(check_health=True) - - def connect_check_health( - self, check_health: bool = True, retry_socket_connect: bool = True - ): - if self._sock: - return - try: - if retry_socket_connect: - sock = self.retry.call_with_retry( - lambda: self._connect(), lambda error: self.disconnect(error) - ) - else: - sock = self._connect() - except socket.timeout: - raise TimeoutError("Timeout connecting to server") - except OSError as e: - raise ConnectionError(self._error_message(e)) - - self._sock = sock - try: - if self.redis_connect_func is None: - # Use the default on_connect function - self.on_connect_check_health(check_health=check_health) - else: - # Use the passed function redis_connect_func - self.redis_connect_func(self) - except RedisError: - # clean up after any error in on_connect - self.disconnect() - raise - - # run any user callbacks. 
right now the only internal callback - # is for pubsub channel/pattern resubscription - # first, remove any dead weakrefs - self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()] - for ref in self._connect_callbacks: - callback = ref() - if callback: - callback(self) - - @abstractmethod - def _connect(self): - pass - - @abstractmethod - def _host_error(self): - pass - - def _error_message(self, exception): - return format_error_message(self._host_error(), exception) - - def on_connect(self): - self.on_connect_check_health(check_health=True) - - def on_connect_check_health(self, check_health: bool = True): - "Initialize the connection, authenticate and select a database" - self._parser.on_connect(self) - parser = self._parser - - auth_args = None - # if credential provider or username and/or password are set, authenticate - if self.credential_provider or (self.username or self.password): - cred_provider = ( - self.credential_provider - or UsernamePasswordCredentialProvider(self.username, self.password) - ) - auth_args = cred_provider.get_credentials() - - # if resp version is specified and we have auth args, - # we need to send them via HELLO - if auth_args and self.protocol not in [2, "2"]: - if isinstance(self._parser, _RESP2Parser): - self.set_parser(_RESP3Parser) - # update cluster exception classes - self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES - self._parser.on_connect(self) - if len(auth_args) == 1: - auth_args = ["default", auth_args[0]] - # avoid checking health here -- PING will fail if we try - # to check the health prior to the AUTH - self.send_command( - "HELLO", self.protocol, "AUTH", *auth_args, check_health=False - ) - self.handshake_metadata = self.read_response() - # if response.get(b"proto") != self.protocol and response.get( - # "proto" - # ) != self.protocol: - # raise ConnectionError("Invalid RESP version") - elif auth_args: - # avoid checking health here -- PING will fail if we try - # to check the health 
prior to the AUTH - self.send_command("AUTH", *auth_args, check_health=False) - - try: - auth_response = self.read_response() - except AuthenticationWrongNumberOfArgsError: - # a username and password were specified but the Redis - # server seems to be < 6.0.0 which expects a single password - # arg. retry auth with just the password. - # https://github.com/andymccurdy/redis-py/issues/1274 - self.send_command("AUTH", auth_args[-1], check_health=False) - auth_response = self.read_response() - - if str_if_bytes(auth_response) != "OK": - raise AuthenticationError("Invalid Username or Password") - - # if resp version is specified, switch to it - elif self.protocol not in [2, "2"]: - if isinstance(self._parser, _RESP2Parser): - self.set_parser(_RESP3Parser) - # update cluster exception classes - self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES - self._parser.on_connect(self) - self.send_command("HELLO", self.protocol, check_health=check_health) - self.handshake_metadata = self.read_response() - if ( - self.handshake_metadata.get(b"proto") != self.protocol - and self.handshake_metadata.get("proto") != self.protocol - ): - raise ConnectionError("Invalid RESP version") - + def activate_maint_notifications_handling_if_enabled(self, check_health=True): # Send maintenance notifications handshake if RESP3 is active # and maintenance notifications are enabled # and we have a host to determine the endpoint type from @@ -695,16 +473,29 @@ def on_connect_check_health(self, check_health: bool = True): # we just log a warning if the handshake fails # When the mode is enabled=True, we raise an exception in case of failure if ( - self.protocol not in [2, "2"] + self.get_protocol() not in [2, "2"] and self.maint_notifications_config and self.maint_notifications_config.enabled and self._maint_notifications_connection_handler and hasattr(self, "host") ): - try: - endpoint_type = self.maint_notifications_config.get_endpoint_type( - self.host, self + 
self._enable_maintenance_notifications( + maint_notifications_config=self.maint_notifications_config, + check_health=check_health, + ) + + def _enable_maintenance_notifications( + self, maint_notifications_config: MaintNotificationsConfig, check_health=True + ): + try: + host = getattr(self, "host", None) + if host is None: + raise ValueError( + "Cannot enable maintenance notifications for connection" + " object that doesn't have a host attribute." ) + else: + endpoint_type = maint_notifications_config.get_endpoint_type(host, self) self.send_command( "CLIENT", "MAINT_NOTIFICATIONS", @@ -714,289 +505,54 @@ def on_connect_check_health(self, check_health: bool = True): check_health=check_health, ) response = self.read_response() - if str_if_bytes(response) != "OK": + if not response or str_if_bytes(response) != "OK": raise ResponseError( "The server doesn't support maintenance notifications" ) - except Exception as e: - if ( - isinstance(e, ResponseError) - and self.maint_notifications_config.enabled == "auto" - ): - # Log warning but don't fail the connection - import logging + except Exception as e: + if ( + isinstance(e, ResponseError) + and maint_notifications_config.enabled == "auto" + ): + # Log warning but don't fail the connection + import logging - logger = logging.getLogger(__name__) - logger.warning(f"Failed to enable maintenance notifications: {e}") - else: - raise + logger = logging.getLogger(__name__) + logger.warning(f"Failed to enable maintenance notifications: {e}") + else: + raise - # if a client_name is given, set it - if self.client_name: - self.send_command( - "CLIENT", - "SETNAME", - self.client_name, - check_health=check_health, - ) - if str_if_bytes(self.read_response()) != "OK": - raise ConnectionError("Error setting client name") + def get_resolved_ip(self) -> Optional[str]: + """ + Extract the resolved IP address from an + established connection or resolve it from the host. 
+ + First tries to get the actual IP from the socket (most accurate), + then falls back to DNS resolution if needed. + + Args: + connection: The connection object to extract the IP from + + Returns: + str: The resolved IP address, or None if it cannot be determined + """ + # Method 1: Try to get the actual IP from the established socket connection + # This is most accurate as it shows the exact IP being used try: - # set the library name and version - if self.lib_name: - self.send_command( - "CLIENT", - "SETINFO", - "LIB-NAME", - self.lib_name, - check_health=check_health, - ) - self.read_response() - except ResponseError: + conn_socket = self._get_socket() + if conn_socket is not None: + peer_addr = conn_socket.getpeername() + if peer_addr and len(peer_addr) >= 1: + # For TCP sockets, peer_addr is typically (host, port) tuple + # Return just the host part + return peer_addr[0] + except (AttributeError, OSError): + # Socket might not be connected or getpeername() might fail pass - try: - if self.lib_version: - self.send_command( - "CLIENT", - "SETINFO", - "LIB-VER", - self.lib_version, - check_health=check_health, - ) - self.read_response() - except ResponseError: - pass - - # if a database is specified, switch to it - if self.db: - self.send_command("SELECT", self.db, check_health=check_health) - if str_if_bytes(self.read_response()) != "OK": - raise ConnectionError("Invalid Database") - - def disconnect(self, *args): - "Disconnects from the Redis server" - self._parser.on_disconnect() - - conn_sock = self._sock - self._sock = None - # reset the reconnect flag - self._should_reconnect = False - if conn_sock is None: - return - - if os.getpid() == self.pid: - try: - conn_sock.shutdown(socket.SHUT_RDWR) - except (OSError, TypeError): - pass - - try: - conn_sock.close() - except OSError: - pass - - def _send_ping(self): - """Send PING, expect PONG in return""" - self.send_command("PING", check_health=False) - if str_if_bytes(self.read_response()) != "PONG": - raise 
ConnectionError("Bad response from PING health check") - - def _ping_failed(self, error): - """Function to call when PING fails""" - self.disconnect() - - def check_health(self): - """Check the health of the connection with a PING/PONG""" - if self.health_check_interval and time.monotonic() > self.next_health_check: - self.retry.call_with_retry(self._send_ping, self._ping_failed) - - def send_packed_command(self, command, check_health=True): - """Send an already packed command to the Redis server""" - if not self._sock: - self.connect_check_health(check_health=False) - # guard against health check recursion - if check_health: - self.check_health() - try: - if isinstance(command, str): - command = [command] - for item in command: - self._sock.sendall(item) - except socket.timeout: - self.disconnect() - raise TimeoutError("Timeout writing to socket") - except OSError as e: - self.disconnect() - if len(e.args) == 1: - errno, errmsg = "UNKNOWN", e.args[0] - else: - errno = e.args[0] - errmsg = e.args[1] - raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.") - except BaseException: - # BaseExceptions can be raised when a socket send operation is not - # finished, e.g. due to a timeout. Ideally, a caller could then re-try - # to send un-sent data. However, the send_packed_command() API - # does not support it so there is no point in keeping the connection open. 
- self.disconnect() - raise - - def send_command(self, *args, **kwargs): - """Pack and send a command to the Redis server""" - self.send_packed_command( - self._command_packer.pack(*args), - check_health=kwargs.get("check_health", True), - ) - - def can_read(self, timeout=0): - """Poll the socket to see if there's data that can be read.""" - sock = self._sock - if not sock: - self.connect() - - host_error = self._host_error() - - try: - return self._parser.can_read(timeout) - - except OSError as e: - self.disconnect() - raise ConnectionError(f"Error while reading from {host_error}: {e.args}") - - def read_response( - self, - disable_decoding=False, - *, - disconnect_on_error=True, - push_request=False, - ): - """Read the response from a previously sent command""" - - host_error = self._host_error() - - try: - if self.protocol in ["3", 3]: - response = self._parser.read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: - response = self._parser.read_response(disable_decoding=disable_decoding) - except socket.timeout: - if disconnect_on_error: - self.disconnect() - raise TimeoutError(f"Timeout reading from {host_error}") - except OSError as e: - if disconnect_on_error: - self.disconnect() - raise ConnectionError(f"Error while reading from {host_error} : {e.args}") - except BaseException: - # Also by default close in case of BaseException. A lot of code - # relies on this behaviour when doing Command/Response pairs. - # See #1128. 
- if disconnect_on_error: - self.disconnect() - raise - - if self.health_check_interval: - self.next_health_check = time.monotonic() + self.health_check_interval - - if isinstance(response, ResponseError): - try: - raise response - finally: - del response # avoid creating ref cycles - return response - - def pack_command(self, *args): - """Pack a series of arguments into the Redis protocol""" - return self._command_packer.pack(*args) - - def pack_commands(self, commands): - """Pack multiple commands into the Redis protocol""" - output = [] - pieces = [] - buffer_length = 0 - buffer_cutoff = self._buffer_cutoff - - for cmd in commands: - for chunk in self._command_packer.pack(*cmd): - chunklen = len(chunk) - if ( - buffer_length > buffer_cutoff - or chunklen > buffer_cutoff - or isinstance(chunk, memoryview) - ): - if pieces: - output.append(SYM_EMPTY.join(pieces)) - buffer_length = 0 - pieces = [] - - if chunklen > buffer_cutoff or isinstance(chunk, memoryview): - output.append(chunk) - else: - pieces.append(chunk) - buffer_length += chunklen - - if pieces: - output.append(SYM_EMPTY.join(pieces)) - return output - - def get_protocol(self) -> Union[int, str]: - return self.protocol - - @property - def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: - return self._handshake_metadata - - @handshake_metadata.setter - def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]): - self._handshake_metadata = value - - def set_re_auth_token(self, token: TokenInterface): - self._re_auth_token = token - - def re_auth(self): - if self._re_auth_token is not None: - self.send_command( - "AUTH", - self._re_auth_token.try_get("oid"), - self._re_auth_token.get_value(), - ) - self.read_response() - self._re_auth_token = None - - def get_resolved_ip(self) -> Optional[str]: - """ - Extract the resolved IP address from an - established connection or resolve it from the host. 
- - First tries to get the actual IP from the socket (most accurate), - then falls back to DNS resolution if needed. - - Args: - connection: The connection object to extract the IP from - - Returns: - str: The resolved IP address, or None if it cannot be determined - """ - - # Method 1: Try to get the actual IP from the established socket connection - # This is most accurate as it shows the exact IP being used - try: - if self._sock is not None: - peer_addr = self._sock.getpeername() - if peer_addr and len(peer_addr) >= 1: - # For TCP sockets, peer_addr is typically (host, port) tuple - # Return just the host part - return peer_addr[0] - except (AttributeError, OSError): - # Socket might not be connected or getpeername() might fail - pass - - # Method 2: Fallback to DNS resolution of the host - # This is less accurate but works when socket is not available + # Method 2: Fallback to DNS resolution of the host + # This is less accurate but works when socket is not available try: host = getattr(self, "host", "localhost") port = getattr(self, "port", 6379) @@ -1010,7 +566,7 @@ def get_resolved_ip(self) -> Optional[str]: # Return the IP from the first result # addr_info[0] is (family, socktype, proto, canonname, sockaddr) # sockaddr[0] is the IP address - return addr_info[0][4][0] + return str(addr_info[0][4][0]) except (AttributeError, OSError, socket.gaierror): # DNS resolution might fail pass @@ -1026,9 +582,13 @@ def maintenance_state(self, state: "MaintenanceState"): self._maintenance_state = state def getpeername(self): - if not self._sock: - return None - return self._sock.getpeername()[0] + """ + Returns the peer name of the connection. 
+ """ + conn_socket = self._get_socket() + if conn_socket: + return conn_socket.getpeername()[0] + return None def mark_for_reconnect(self): self._should_reconnect = True @@ -1036,15 +596,23 @@ def mark_for_reconnect(self): def should_reconnect(self): return self._should_reconnect + def reset_should_reconnect(self): + self._should_reconnect = False + def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None): - if self._sock: + conn_socket = self._get_socket() + if conn_socket: timeout = relaxed_timeout if relaxed_timeout != -1 else self.socket_timeout - self._sock.settimeout(timeout) - self.update_parser_buffer_timeout(timeout) + conn_socket.settimeout(timeout) + self.update_parser_timeout(timeout) - def update_parser_buffer_timeout(self, timeout: Optional[float] = None): - if self._parser and self._parser._buffer: - self._parser._buffer.socket_timeout = timeout + def update_parser_timeout(self, timeout: Optional[float] = None): + parser = self._get_parser() + if parser and parser._buffer: + if isinstance(parser, _RESP3Parser) and timeout: + parser._buffer.socket_timeout = timeout + elif isinstance(parser, _HiredisParser): + parser._socket_timeout = timeout def set_tmp_settings( self, @@ -1054,8 +622,8 @@ def set_tmp_settings( """ The value of SENTINEL is used to indicate that the property should not be updated. 
""" - if tmp_host_address is not SENTINEL: - self.host = tmp_host_address + if tmp_host_address and tmp_host_address != SENTINEL: + self.host = str(tmp_host_address) if tmp_relaxed_timeout != -1: self.socket_timeout = tmp_relaxed_timeout self.socket_connect_timeout = tmp_relaxed_timeout @@ -1072,20 +640,593 @@ def reset_tmp_settings( self.socket_connect_timeout = self.orig_socket_connect_timeout -class Connection(AbstractConnection): - "Manages TCP communication to and from a Redis server" +class AbstractConnection(MaintNotificationsAbstractConnection, ConnectionInterface): + "Manages communication to and from a Redis server" def __init__( self, - host="localhost", - port=6379, - socket_keepalive=False, - socket_keepalive_options=None, - socket_type=0, - **kwargs, - ): - self.host = host - self.port = int(port) + db: int = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + retry_on_timeout: bool = False, + retry_on_error: Union[Iterable[Type[Exception]], object] = SENTINEL, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class=DefaultParser, + socket_read_size: int = 65536, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), + username: Optional[str] = None, + retry: Union[Any, None] = None, + redis_connect_func: Optional[Callable[[], None]] = None, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + command_packer: Optional[Callable[[], None]] = None, + event_dispatcher: Optional[EventDispatcher] = None, + maint_notifications_config: Optional[MaintNotificationsConfig] = None, + maint_notifications_pool_handler: Optional[ + MaintNotificationsPoolHandler + ] = None, + maintenance_state: "MaintenanceState" = MaintenanceState.NONE, + maintenance_notification_hash: Optional[int] = None, 
+ orig_host_address: Optional[str] = None, + orig_socket_timeout: Optional[float] = None, + orig_socket_connect_timeout: Optional[float] = None, + ): + """ + Initialize a new Connection. + To specify a retry policy for specific errors, first set + `retry_on_error` to a list of the error/s to retry on, then set + `retry` to a valid `Retry` object. + To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. + """ + if (username or password) and credential_provider is not None: + raise DataError( + "'username' and 'password' cannot be passed along with 'credential_" + "provider'. Please provide only one of the following arguments: \n" + "1. 'password' and (optional) 'username'\n" + "2. 'credential_provider'" + ) + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher + self.pid = os.getpid() + self.db = db + self.client_name = client_name + self.lib_name = lib_name + self.lib_version = lib_version + self.credential_provider = credential_provider + self.password = password + self.username = username + self._socket_timeout = socket_timeout + if socket_connect_timeout is None: + socket_connect_timeout = socket_timeout + self._socket_connect_timeout = socket_connect_timeout + self.retry_on_timeout = retry_on_timeout + if retry_on_error is SENTINEL: + retry_on_errors_list = [] + else: + retry_on_errors_list = list(retry_on_error) + if retry_on_timeout: + # Add TimeoutError to the errors list to retry on + retry_on_errors_list.append(TimeoutError) + self.retry_on_error = retry_on_errors_list + if retry or self.retry_on_error: + if retry is None: + self.retry = Retry(NoBackoff(), 1) + else: + # deep-copy the Retry object as it is mutable + self.retry = copy.deepcopy(retry) + if self.retry_on_error: + # Update the retry's supported errors with the specified errors + self.retry.update_supported_errors(self.retry_on_error) + else: + self.retry = Retry(NoBackoff(), 0) + 
self.health_check_interval = health_check_interval + self.next_health_check = 0 + self.redis_connect_func = redis_connect_func + self.encoder = Encoder(encoding, encoding_errors, decode_responses) + self.handshake_metadata = None + self._sock = None + self._socket_read_size = socket_read_size + self._connect_callbacks = [] + self._buffer_cutoff = 6000 + self._re_auth_token: Optional[TokenInterface] = None + try: + p = int(protocol) + except TypeError: + p = DEFAULT_RESP_VERSION + except ValueError: + raise ConnectionError("protocol must be an integer") + finally: + if p < 2 or p > 3: + raise ConnectionError("protocol must be either 2 or 3") + # p = DEFAULT_RESP_VERSION + self.protocol = p + if self.protocol == 3 and parser_class == _RESP2Parser: + # If the protocol is 3 but the parser is RESP2, change it to RESP3 + # This is needed because the parser might be set before the protocol + # or might be provided as a kwarg to the constructor + # We need to react on discrepancy only for RESP2 and RESP3 + # as hiredis supports both + parser_class = _RESP3Parser + self.set_parser(parser_class) + + self._command_packer = self._construct_command_packer(command_packer) + + # Set up maintenance notifications + MaintNotificationsAbstractConnection.__init__( + self, + maint_notifications_config, + maint_notifications_pool_handler, + maintenance_state, + maintenance_notification_hash, + orig_host_address, + orig_socket_timeout, + orig_socket_connect_timeout, + self._parser, + ) + + def __repr__(self): + repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) + return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" + + @abstractmethod + def repr_pieces(self): + pass + + def __del__(self): + try: + self.disconnect() + except Exception: + pass + + def _construct_command_packer(self, packer): + if packer is not None: + return packer + elif HIREDIS_AVAILABLE: + return HiredisRespSerializer() + else: + return PythonRespSerializer(self._buffer_cutoff, 
self.encoder.encode) + + def register_connect_callback(self, callback): + """ + Register a callback to be called when the connection is established either + initially or reconnected. This allows listeners to issue commands that + are ephemeral to the connection, for example pub/sub subscription or + key tracking. The callback must be a _method_ and will be kept as + a weak reference. + """ + wm = weakref.WeakMethod(callback) + if wm not in self._connect_callbacks: + self._connect_callbacks.append(wm) + + def deregister_connect_callback(self, callback): + """ + De-register a previously registered callback. It will no-longer receive + notifications on connection events. Calling this is not required when the + listener goes away, since the callbacks are kept as weak methods. + """ + try: + self._connect_callbacks.remove(weakref.WeakMethod(callback)) + except ValueError: + pass + + def set_parser(self, parser_class): + """ + Creates a new instance of parser_class with socket size: + _socket_read_size and assigns it to the parser for the connection + :param parser_class: The required parser class + """ + self._parser = parser_class(socket_read_size=self._socket_read_size) + + def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser, _RESP2Parser]: + return self._parser + + def connect(self): + "Connects to the Redis server if not already connected" + self.connect_check_health(check_health=True) + + def connect_check_health( + self, check_health: bool = True, retry_socket_connect: bool = True + ): + if self._sock: + return + try: + if retry_socket_connect: + sock = self.retry.call_with_retry( + lambda: self._connect(), lambda error: self.disconnect(error) + ) + else: + sock = self._connect() + except socket.timeout: + raise TimeoutError("Timeout connecting to server") + except OSError as e: + raise ConnectionError(self._error_message(e)) + + self._sock = sock + try: + if self.redis_connect_func is None: + # Use the default on_connect function + 
self.on_connect_check_health(check_health=check_health) + else: + # Use the passed function redis_connect_func + self.redis_connect_func(self) + except RedisError: + # clean up after any error in on_connect + self.disconnect() + raise + + # run any user callbacks. right now the only internal callback + # is for pubsub channel/pattern resubscription + # first, remove any dead weakrefs + self._connect_callbacks = [ref for ref in self._connect_callbacks if ref()] + for ref in self._connect_callbacks: + callback = ref() + if callback: + callback(self) + + @abstractmethod + def _connect(self): + pass + + @abstractmethod + def _host_error(self): + pass + + def _error_message(self, exception): + return format_error_message(self._host_error(), exception) + + def on_connect(self): + self.on_connect_check_health(check_health=True) + + def on_connect_check_health(self, check_health: bool = True): + "Initialize the connection, authenticate and select a database" + self._parser.on_connect(self) + parser = self._parser + + auth_args = None + # if credential provider or username and/or password are set, authenticate + if self.credential_provider or (self.username or self.password): + cred_provider = ( + self.credential_provider + or UsernamePasswordCredentialProvider(self.username, self.password) + ) + auth_args = cred_provider.get_credentials() + + # if resp version is specified and we have auth args, + # we need to send them via HELLO + if auth_args and self.protocol not in [2, "2"]: + if isinstance(self._parser, _RESP2Parser): + self.set_parser(_RESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + if len(auth_args) == 1: + auth_args = ["default", auth_args[0]] + # avoid checking health here -- PING will fail if we try + # to check the health prior to the AUTH + self.send_command( + "HELLO", self.protocol, "AUTH", *auth_args, check_health=False + ) + self.handshake_metadata = 
self.read_response() + # if response.get(b"proto") != self.protocol and response.get( + # "proto" + # ) != self.protocol: + # raise ConnectionError("Invalid RESP version") + elif auth_args: + # avoid checking health here -- PING will fail if we try + # to check the health prior to the AUTH + self.send_command("AUTH", *auth_args, check_health=False) + + try: + auth_response = self.read_response() + except AuthenticationWrongNumberOfArgsError: + # a username and password were specified but the Redis + # server seems to be < 6.0.0 which expects a single password + # arg. retry auth with just the password. + # https://github.com/andymccurdy/redis-py/issues/1274 + self.send_command("AUTH", auth_args[-1], check_health=False) + auth_response = self.read_response() + + if str_if_bytes(auth_response) != "OK": + raise AuthenticationError("Invalid Username or Password") + + # if resp version is specified, switch to it + elif self.protocol not in [2, "2"]: + if isinstance(self._parser, _RESP2Parser): + self.set_parser(_RESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + self.send_command("HELLO", self.protocol, check_health=check_health) + self.handshake_metadata = self.read_response() + if ( + self.handshake_metadata.get(b"proto") != self.protocol + and self.handshake_metadata.get("proto") != self.protocol + ): + raise ConnectionError("Invalid RESP version") + + # Activate maintenance notifications for this connection + # if enabled in the configuration + # This is a no-op if maintenance notifications are not enabled + self.activate_maint_notifications_handling_if_enabled(check_health=check_health) + + # if a client_name is given, set it + if self.client_name: + self.send_command( + "CLIENT", + "SETNAME", + self.client_name, + check_health=check_health, + ) + if str_if_bytes(self.read_response()) != "OK": + raise ConnectionError("Error setting client name") + + try: + # set the 
library name and version + if self.lib_name: + self.send_command( + "CLIENT", + "SETINFO", + "LIB-NAME", + self.lib_name, + check_health=check_health, + ) + self.read_response() + except ResponseError: + pass + + try: + if self.lib_version: + self.send_command( + "CLIENT", + "SETINFO", + "LIB-VER", + self.lib_version, + check_health=check_health, + ) + self.read_response() + except ResponseError: + pass + + # if a database is specified, switch to it + if self.db: + self.send_command("SELECT", self.db, check_health=check_health) + if str_if_bytes(self.read_response()) != "OK": + raise ConnectionError("Invalid Database") + + def disconnect(self, *args): + "Disconnects from the Redis server" + self._parser.on_disconnect() + + conn_sock = self._sock + self._sock = None + # reset the reconnect flag + self.reset_should_reconnect() + if conn_sock is None: + return + + if os.getpid() == self.pid: + try: + conn_sock.shutdown(socket.SHUT_RDWR) + except (OSError, TypeError): + pass + + try: + conn_sock.close() + except OSError: + pass + + def _send_ping(self): + """Send PING, expect PONG in return""" + self.send_command("PING", check_health=False) + if str_if_bytes(self.read_response()) != "PONG": + raise ConnectionError("Bad response from PING health check") + + def _ping_failed(self, error): + """Function to call when PING fails""" + self.disconnect() + + def check_health(self): + """Check the health of the connection with a PING/PONG""" + if self.health_check_interval and time.monotonic() > self.next_health_check: + self.retry.call_with_retry(self._send_ping, self._ping_failed) + + def send_packed_command(self, command, check_health=True): + """Send an already packed command to the Redis server""" + if not self._sock: + self.connect_check_health(check_health=False) + # guard against health check recursion + if check_health: + self.check_health() + try: + if isinstance(command, str): + command = [command] + for item in command: + self._sock.sendall(item) + except 
socket.timeout: + self.disconnect() + raise TimeoutError("Timeout writing to socket") + except OSError as e: + self.disconnect() + if len(e.args) == 1: + errno, errmsg = "UNKNOWN", e.args[0] + else: + errno = e.args[0] + errmsg = e.args[1] + raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.") + except BaseException: + # BaseExceptions can be raised when a socket send operation is not + # finished, e.g. due to a timeout. Ideally, a caller could then re-try + # to send un-sent data. However, the send_packed_command() API + # does not support it so there is no point in keeping the connection open. + self.disconnect() + raise + + def send_command(self, *args, **kwargs): + """Pack and send a command to the Redis server""" + self.send_packed_command( + self._command_packer.pack(*args), + check_health=kwargs.get("check_health", True), + ) + + def can_read(self, timeout=0): + """Poll the socket to see if there's data that can be read.""" + sock = self._sock + if not sock: + self.connect() + + host_error = self._host_error() + + try: + return self._parser.can_read(timeout) + + except OSError as e: + self.disconnect() + raise ConnectionError(f"Error while reading from {host_error}: {e.args}") + + def read_response( + self, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, + ): + """Read the response from a previously sent command""" + + host_error = self._host_error() + + try: + if self.protocol in ["3", 3]: + response = self._parser.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + response = self._parser.read_response(disable_decoding=disable_decoding) + except socket.timeout: + if disconnect_on_error: + self.disconnect() + raise TimeoutError(f"Timeout reading from {host_error}") + except OSError as e: + if disconnect_on_error: + self.disconnect() + raise ConnectionError(f"Error while reading from {host_error} : {e.args}") + except BaseException: + # Also by default close in 
case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. + if disconnect_on_error: + self.disconnect() + raise + + if self.health_check_interval: + self.next_health_check = time.monotonic() + self.health_check_interval + + if isinstance(response, ResponseError): + try: + raise response + finally: + del response # avoid creating ref cycles + return response + + def pack_command(self, *args): + """Pack a series of arguments into the Redis protocol""" + return self._command_packer.pack(*args) + + def pack_commands(self, commands): + """Pack multiple commands into the Redis protocol""" + output = [] + pieces = [] + buffer_length = 0 + buffer_cutoff = self._buffer_cutoff + + for cmd in commands: + for chunk in self._command_packer.pack(*cmd): + chunklen = len(chunk) + if ( + buffer_length > buffer_cutoff + or chunklen > buffer_cutoff + or isinstance(chunk, memoryview) + ): + if pieces: + output.append(SYM_EMPTY.join(pieces)) + buffer_length = 0 + pieces = [] + + if chunklen > buffer_cutoff or isinstance(chunk, memoryview): + output.append(chunk) + else: + pieces.append(chunk) + buffer_length += chunklen + + if pieces: + output.append(SYM_EMPTY.join(pieces)) + return output + + def get_protocol(self) -> Union[int, str]: + return self.protocol + + @property + def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: + return self._handshake_metadata + + @handshake_metadata.setter + def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]): + self._handshake_metadata = value + + def set_re_auth_token(self, token: TokenInterface): + self._re_auth_token = token + + def re_auth(self): + if self._re_auth_token is not None: + self.send_command( + "AUTH", + self._re_auth_token.try_get("oid"), + self._re_auth_token.get_value(), + ) + self.read_response() + self._re_auth_token = None + + def _get_socket(self) -> Optional[socket.socket]: + return self._sock + + @property + def 
socket_timeout(self) -> Optional[Union[float, int]]: + return self._socket_timeout + + @socket_timeout.setter + def socket_timeout(self, value: Optional[Union[float, int]]): + self._socket_timeout = value + + @property + def socket_connect_timeout(self) -> Optional[Union[float, int]]: + return self._socket_connect_timeout + + @socket_connect_timeout.setter + def socket_connect_timeout(self, value: Optional[Union[float, int]]): + self._socket_connect_timeout = value + + +class Connection(AbstractConnection): + "Manages TCP communication to and from a Redis server" + + def __init__( + self, + host="localhost", + port=6379, + socket_keepalive=False, + socket_keepalive_options=None, + socket_type=0, + **kwargs, + ): + self._host = host + self.port = int(port) self.socket_keepalive = socket_keepalive self.socket_keepalive_options = socket_keepalive_options or {} self.socket_type = socket_type @@ -1146,8 +1287,16 @@ def _connect(self): def _host_error(self): return f"{self.host}:{self.port}" + @property + def host(self) -> str: + return self._host + + @host.setter + def host(self, value: str): + self._host = value -class CacheProxyConnection(ConnectionInterface): + +class CacheProxyConnection(MaintNotificationsAbstractConnection, ConnectionInterface): DUMMY_CACHE_VALUE = b"foo" MIN_ALLOWED_VERSION = "7.4.0" DEFAULT_SERVER_NAME = "redis" @@ -1171,6 +1320,19 @@ def __init__( self._current_options = None self.register_connect_callback(self._enable_tracking_callback) + if isinstance(self._conn, MaintNotificationsAbstractConnection): + MaintNotificationsAbstractConnection.__init__( + self, + self._conn.maint_notifications_config, + self._conn._maint_notifications_pool_handler, + self._conn.maintenance_state, + self._conn.maintenance_notification_hash, + self._conn.host, + self._conn.socket_timeout, + self._conn.socket_connect_timeout, + self._conn._get_parser(), + ) + def repr_pieces(self): return self._conn.repr_pieces() @@ -1183,6 +1345,17 @@ def 
deregister_connect_callback(self, callback): def set_parser(self, parser_class): self._conn.set_parser(parser_class) + def set_maint_notifications_pool_handler_for_connection( + self, maint_notifications_pool_handler + ): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.set_maint_notifications_pool_handler_for_connection( + maint_notifications_pool_handler + ) + + def get_protocol(self): + return self._conn.get_protocol() + def connect(self): self._conn.connect() @@ -1328,6 +1501,134 @@ def pack_commands(self, commands): def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: return self._conn.handshake_metadata + def set_re_auth_token(self, token: TokenInterface): + self._conn.set_re_auth_token(token) + + def re_auth(self): + self._conn.re_auth() + + @property + def host(self) -> str: + return self._conn.host + + @host.setter + def host(self, value: str): + self._conn.host = value + + @property + def socket_timeout(self) -> Optional[Union[float, int]]: + return self._conn.socket_timeout + + @socket_timeout.setter + def socket_timeout(self, value: Optional[Union[float, int]]): + self._conn.socket_timeout = value + + @property + def socket_connect_timeout(self) -> Optional[Union[float, int]]: + return self._conn.socket_connect_timeout + + @socket_connect_timeout.setter + def socket_connect_timeout(self, value: Optional[Union[float, int]]): + self._conn.socket_connect_timeout = value + + def _get_socket(self) -> Optional[socket.socket]: + if isinstance(self._conn, MaintNotificationsAbstractConnection): + return self._conn._get_socket() + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + @property + def maintenance_state(self) -> MaintenanceState: + if isinstance(self._conn, MaintNotificationsAbstractConnection): + return self._conn.maintenance_state + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection 
type" + ) + + @maintenance_state.setter + def maintenance_state(self, state: MaintenanceState): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.maintenance_state = state + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def getpeername(self): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + return self._conn.getpeername() + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def mark_for_reconnect(self): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.mark_for_reconnect() + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def should_reconnect(self): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + return self._conn.should_reconnect() + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def reset_should_reconnect(self): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.reset_should_reconnect() + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def get_resolved_ip(self): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + return self._conn.get_resolved_ip() + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def update_current_socket_timeout(self, relaxed_timeout: Optional[float] = None): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.update_current_socket_timeout(relaxed_timeout) + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def set_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relaxed_timeout: 
Optional[float] = None, + ): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.set_tmp_settings(tmp_host_address, tmp_relaxed_timeout) + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + + def reset_tmp_settings( + self, + reset_host_address: bool = False, + reset_relaxed_timeout: bool = False, + ): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.reset_tmp_settings(reset_host_address, reset_relaxed_timeout) + else: + raise NotImplementedError( + "Maintenance notifications are not supported by this connection type" + ) + def _connect(self): self._conn._connect() @@ -1351,15 +1652,6 @@ def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]] else: self._cache.delete_by_redis_keys(data[1]) - def get_protocol(self): - return self._conn.get_protocol() - - def set_re_auth_token(self, token: TokenInterface): - self._conn.set_re_auth_token(token) - - def re_auth(self): - self._conn.re_auth() - class SSLConnection(Connection): """Manages SSL connections to and from the Redis server(s). @@ -1448,240 +1740,629 @@ def __init__( self.ssl_ciphers = ssl_ciphers super().__init__(**kwargs) - def _connect(self): - """ - Wrap the socket with SSL support, handling potential errors. - """ - sock = super()._connect() - try: - return self._wrap_socket_with_ssl(sock) - except (OSError, RedisError): - sock.close() - raise + def _connect(self): + """ + Wrap the socket with SSL support, handling potential errors. + """ + sock = super()._connect() + try: + return self._wrap_socket_with_ssl(sock) + except (OSError, RedisError): + sock.close() + raise + + def _wrap_socket_with_ssl(self, sock): + """ + Wraps the socket with SSL support. + + Args: + sock: The plain socket to wrap with SSL. + + Returns: + An SSL wrapped socket. 
+ """ + context = ssl.create_default_context() + context.check_hostname = self.check_hostname + context.verify_mode = self.cert_reqs + if self.ssl_include_verify_flags: + for flag in self.ssl_include_verify_flags: + context.verify_flags |= flag + if self.ssl_exclude_verify_flags: + for flag in self.ssl_exclude_verify_flags: + context.verify_flags &= ~flag + if self.certfile or self.keyfile: + context.load_cert_chain( + certfile=self.certfile, + keyfile=self.keyfile, + password=self.certificate_password, + ) + if ( + self.ca_certs is not None + or self.ca_path is not None + or self.ca_data is not None + ): + context.load_verify_locations( + cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data + ) + if self.ssl_min_version is not None: + context.minimum_version = self.ssl_min_version + if self.ssl_ciphers: + context.set_ciphers(self.ssl_ciphers) + if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False: + raise RedisError("cryptography is not installed.") + + if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp: + raise RedisError( + "Either an OCSP staple or pure OCSP connection must be validated " + "- not both." 
+ ) + + sslsock = context.wrap_socket(sock, server_hostname=self.host) + + # validation for the stapled case + if self.ssl_validate_ocsp_stapled: + import OpenSSL + + from .ocsp import ocsp_staple_verifier + + # if a context is provided use it - otherwise, a basic context + if self.ssl_ocsp_context is None: + staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) + staple_ctx.use_certificate_file(self.certfile) + staple_ctx.use_privatekey_file(self.keyfile) + else: + staple_ctx = self.ssl_ocsp_context + + staple_ctx.set_ocsp_client_callback( + ocsp_staple_verifier, self.ssl_ocsp_expected_cert + ) + + # need another socket + con = OpenSSL.SSL.Connection(staple_ctx, socket.socket()) + con.request_ocsp() + con.connect((self.host, self.port)) + con.do_handshake() + con.shutdown() + return sslsock + + # pure ocsp validation + if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE: + from .ocsp import OCSPVerifier + + o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs) + if o.is_valid(): + return sslsock + else: + raise ConnectionError("ocsp validation error") + return sslsock + + +class UnixDomainSocketConnection(AbstractConnection): + "Manages UDS communication to and from a Redis server" + + def __init__(self, path="", socket_timeout=None, **kwargs): + super().__init__(**kwargs) + self.path = path + self.socket_timeout = socket_timeout + + def repr_pieces(self): + pieces = [("path", self.path), ("db", self.db)] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + def _connect(self): + "Create a Unix domain socket connection" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(self.socket_connect_timeout) + try: + sock.connect(self.path) + except OSError: + # Prevent ResourceWarnings for unclosed sockets. 
+ try: + sock.shutdown(socket.SHUT_RDWR) # ensure a clean close + except OSError: + pass + sock.close() + raise + sock.settimeout(self.socket_timeout) + return sock + + def _host_error(self): + return self.path + + +FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") + + +def to_bool(value): + if value is None or value == "": + return None + if isinstance(value, str) and value.upper() in FALSE_STRINGS: + return False + return bool(value) + + +def parse_ssl_verify_flags(value): + # flags are passed in as a string representation of a list, + # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN + verify_flags_str = value.replace("[", "").replace("]", "") + + verify_flags = [] + for flag in verify_flags_str.split(","): + flag = flag.strip() + if not hasattr(VerifyFlags, flag): + raise ValueError(f"Invalid ssl verify flag: {flag}") + verify_flags.append(getattr(VerifyFlags, flag)) + return verify_flags + + +URL_QUERY_ARGUMENT_PARSERS = { + "db": int, + "socket_timeout": float, + "socket_connect_timeout": float, + "socket_keepalive": to_bool, + "retry_on_timeout": to_bool, + "retry_on_error": list, + "max_connections": int, + "health_check_interval": int, + "ssl_check_hostname": to_bool, + "ssl_include_verify_flags": parse_ssl_verify_flags, + "ssl_exclude_verify_flags": parse_ssl_verify_flags, + "timeout": float, +} + + +def parse_url(url): + if not ( + url.startswith("redis://") + or url.startswith("rediss://") + or url.startswith("unix://") + ): + raise ValueError( + "Redis URL must specify one of the following " + "schemes (redis://, rediss://, unix://)" + ) + + url = urlparse(url) + kwargs = {} + + for name, value in parse_qs(url.query).items(): + if value and len(value) > 0: + value = unquote(value[0]) + parser = URL_QUERY_ARGUMENT_PARSERS.get(name) + if parser: + try: + kwargs[name] = parser(value) + except (TypeError, ValueError): + raise ValueError(f"Invalid value for '{name}' in connection URL.") + else: + kwargs[name] = value + + if url.username: + 
kwargs["username"] = unquote(url.username) + if url.password: + kwargs["password"] = unquote(url.password) + + # We only support redis://, rediss:// and unix:// schemes. + if url.scheme == "unix": + if url.path: + kwargs["path"] = unquote(url.path) + kwargs["connection_class"] = UnixDomainSocketConnection + + else: # implied: url.scheme in ("redis", "rediss"): + if url.hostname: + kwargs["host"] = unquote(url.hostname) + if url.port: + kwargs["port"] = int(url.port) + + # If there's a path argument, use it as the db argument if a + # querystring value wasn't specified + if url.path and "db" not in kwargs: + try: + kwargs["db"] = int(unquote(url.path).replace("/", "")) + except (AttributeError, ValueError): + pass + + if url.scheme == "rediss": + kwargs["connection_class"] = SSLConnection + + return kwargs + + +_CP = TypeVar("_CP", bound="ConnectionPool") + + +class ConnectionPoolInterface(ABC): + @abstractmethod + def get_protocol(self): + pass + + @abstractmethod + def reset(self): + pass + + @abstractmethod + @deprecated_args( + args_to_warn=["*"], + reason="Use get_connection() without args instead", + version="5.3.0", + ) + def get_connection( + self, command_name: Optional[str], *keys, **options + ) -> ConnectionInterface: + pass + + @abstractmethod + def get_encoder(self): + pass - def _wrap_socket_with_ssl(self, sock): - """ - Wraps the socket with SSL support. + @abstractmethod + def release(self, connection: ConnectionInterface): + pass - Args: - sock: The plain socket to wrap with SSL. + @abstractmethod + def disconnect(self, inuse_connections: bool = True): + pass - Returns: - An SSL wrapped socket. 
- """ - context = ssl.create_default_context() - context.check_hostname = self.check_hostname - context.verify_mode = self.cert_reqs - if self.ssl_include_verify_flags: - for flag in self.ssl_include_verify_flags: - context.verify_flags |= flag - if self.ssl_exclude_verify_flags: - for flag in self.ssl_exclude_verify_flags: - context.verify_flags &= ~flag - if self.certfile or self.keyfile: - context.load_cert_chain( - certfile=self.certfile, - keyfile=self.keyfile, - password=self.certificate_password, - ) - if ( - self.ca_certs is not None - or self.ca_path is not None - or self.ca_data is not None - ): - context.load_verify_locations( - cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data - ) - if self.ssl_min_version is not None: - context.minimum_version = self.ssl_min_version - if self.ssl_ciphers: - context.set_ciphers(self.ssl_ciphers) - if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False: - raise RedisError("cryptography is not installed.") + @abstractmethod + def close(self): + pass - if self.ssl_validate_ocsp_stapled and self.ssl_validate_ocsp: - raise RedisError( - "Either an OCSP staple or pure OCSP connection must be validated " - "- not both." - ) + @abstractmethod + def set_retry(self, retry: Retry): + pass - sslsock = context.wrap_socket(sock, server_hostname=self.host) + @abstractmethod + def re_auth_callback(self, token: TokenInterface): + pass - # validation for the stapled case - if self.ssl_validate_ocsp_stapled: - import OpenSSL - from .ocsp import ocsp_staple_verifier +class MaintNotificationsAbstractConnectionPool: + """ + Abstract class for handling maintenance notifications logic. + This class is mixed into the ConnectionPool classes. 
- # if a context is provided use it - otherwise, a basic context - if self.ssl_ocsp_context is None: - staple_ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) - staple_ctx.use_certificate_file(self.certfile) - staple_ctx.use_privatekey_file(self.keyfile) - else: - staple_ctx = self.ssl_ocsp_context + This class is not intended to be used directly! - staple_ctx.set_ocsp_client_callback( - ocsp_staple_verifier, self.ssl_ocsp_expected_cert + All logic related to maintenance notifications and + connection pool handling is encapsulated in this class. + """ + + def __init__( + self, + maint_notifications_config: Optional[MaintNotificationsConfig] = None, + **kwargs, + ): + # Initialize maintenance notifications + is_protocol_supported = kwargs.get("protocol") in [3, "3"] + if maint_notifications_config is None and is_protocol_supported: + maint_notifications_config = MaintNotificationsConfig() + + if maint_notifications_config and maint_notifications_config.enabled: + if not is_protocol_supported: + raise RedisError( + "Maintenance notifications handlers on connection are only supported with RESP version 3" + ) + + self._maint_notifications_pool_handler = MaintNotificationsPoolHandler( + self, maint_notifications_config ) - # need another socket - con = OpenSSL.SSL.Connection(staple_ctx, socket.socket()) - con.request_ocsp() - con.connect((self.host, self.port)) - con.do_handshake() - con.shutdown() - return sslsock + self._update_connection_kwargs_for_maint_notifications( + self._maint_notifications_pool_handler + ) + else: + self._maint_notifications_pool_handler = None - # pure ocsp validation - if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE: - from .ocsp import OCSPVerifier + @property + @abstractmethod + def connection_kwargs(self) -> Dict[str, Any]: + pass - o = OCSPVerifier(sslsock, self.host, self.port, self.ca_certs) - if o.is_valid(): - return sslsock - else: - raise ConnectionError("ocsp validation error") - return sslsock + 
@connection_kwargs.setter + @abstractmethod + def connection_kwargs(self, value: Dict[str, Any]): + pass + @abstractmethod + def _get_pool_lock(self) -> threading.RLock: + pass -class UnixDomainSocketConnection(AbstractConnection): - "Manages UDS communication to and from a Redis server" + @abstractmethod + def _get_free_connections(self) -> Iterable["MaintNotificationsAbstractConnection"]: + pass - def __init__(self, path="", socket_timeout=None, **kwargs): - super().__init__(**kwargs) - self.path = path - self.socket_timeout = socket_timeout + @abstractmethod + def _get_in_use_connections( + self, + ) -> Iterable["MaintNotificationsAbstractConnection"]: + pass - def repr_pieces(self): - pieces = [("path", self.path), ("db", self.db)] - if self.client_name: - pieces.append(("client_name", self.client_name)) - return pieces + def maint_notifications_enabled(self): + """ + Returns: + True if the maintenance notifications are enabled, False otherwise. + The maintenance notifications config is stored in the pool handler. + If the pool handler is not set, the maintenance notifications are not enabled. + """ + maint_notifications_config = ( + self._maint_notifications_pool_handler.config + if self._maint_notifications_pool_handler + else None + ) - def _connect(self): - "Create a Unix domain socket connection" - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.settimeout(self.socket_connect_timeout) - try: - sock.connect(self.path) - except OSError: - # Prevent ResourceWarnings for unclosed sockets. - try: - sock.shutdown(socket.SHUT_RDWR) # ensure a clean close - except OSError: - pass - sock.close() - raise - sock.settimeout(self.socket_timeout) - return sock + return maint_notifications_config and maint_notifications_config.enabled - def _host_error(self): - return self.path + def update_maint_notifications_config( + self, maint_notifications_config: MaintNotificationsConfig + ): + """ + Updates the maintenance notifications configuration. 
+ This method should be called only if the pool was created
+ without enabling the maintenance notifications and
+ at a later point in time maintenance notifications
+ are requested to be enabled.
+ """
+ if (
+ self.maint_notifications_enabled()
+ and not maint_notifications_config.enabled
+ ):
+ raise ValueError(
+ "Cannot disable maintenance notifications after enabling them"
+ )
+ # first update pool settings
+ if not self._maint_notifications_pool_handler:
+ self._maint_notifications_pool_handler = MaintNotificationsPoolHandler(
+ self, maint_notifications_config
+ )
+ else:
+ self._maint_notifications_pool_handler.config = maint_notifications_config
+ # then update connection kwargs and existing connections
+ self._update_connection_kwargs_for_maint_notifications(
+ self._maint_notifications_pool_handler
+ )
+ self._update_maint_notifications_configs_for_connections(
+ self._maint_notifications_pool_handler
+ )

-FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")

+ def _update_connection_kwargs_for_maint_notifications(
+ self, maint_notifications_pool_handler: MaintNotificationsPoolHandler
+ ):
+ """
+ Update the connection kwargs for all future connections.
+ """
+ if not self.maint_notifications_enabled():
+ return
+ self.connection_kwargs.update(
+ {
+ "maint_notifications_pool_handler": maint_notifications_pool_handler,
+ "maint_notifications_config": maint_notifications_pool_handler.config,
+ }
+ )

-def to_bool(value):
-    if value is None or value == "":
-        return None
-    if isinstance(value, str) and value.upper() in FALSE_STRINGS:
-        return False
-    return bool(value)

+ # Store original connection parameters for maintenance notifications.
+ if self.connection_kwargs.get("orig_host_address", None) is None: + # If orig_host_address is None it means we haven't + # configured the original values yet + self.connection_kwargs.update( + { + "orig_host_address": self.connection_kwargs.get("host"), + "orig_socket_timeout": self.connection_kwargs.get( + "socket_timeout", None + ), + "orig_socket_connect_timeout": self.connection_kwargs.get( + "socket_connect_timeout", None + ), + } + ) + def _update_maint_notifications_configs_for_connections( + self, maint_notifications_pool_handler: MaintNotificationsPoolHandler + ): + """Update the maintenance notifications config for all connections in the pool.""" + with self._get_pool_lock(): + for conn in self._get_free_connections(): + conn.set_maint_notifications_pool_handler_for_connection( + maint_notifications_pool_handler + ) + conn.maint_notifications_config = ( + maint_notifications_pool_handler.config + ) + conn.disconnect() + for conn in self._get_in_use_connections(): + conn.set_maint_notifications_pool_handler_for_connection( + maint_notifications_pool_handler + ) + conn.maint_notifications_config = ( + maint_notifications_pool_handler.config + ) + conn.mark_for_reconnect() -def parse_ssl_verify_flags(value): - # flags are passed in as a string representation of a list, - # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN - verify_flags_str = value.replace("[", "").replace("]", "") + def _should_update_connection( + self, + conn: "MaintNotificationsAbstractConnection", + matching_pattern: Literal[ + "connected_address", "configured_address", "notification_hash" + ] = "connected_address", + matching_address: Optional[str] = None, + matching_notification_hash: Optional[int] = None, + ) -> bool: + """ + Check if the connection should be updated based on the matching criteria. 
+ """ + if matching_pattern == "connected_address": + if matching_address and conn.getpeername() != matching_address: + return False + elif matching_pattern == "configured_address": + if matching_address and conn.host != matching_address: + return False + elif matching_pattern == "notification_hash": + if ( + matching_notification_hash + and conn.maintenance_notification_hash != matching_notification_hash + ): + return False + return True + + def update_connection_settings( + self, + conn: "MaintNotificationsAbstractConnection", + state: Optional["MaintenanceState"] = None, + maintenance_notification_hash: Optional[int] = None, + host_address: Optional[str] = None, + relaxed_timeout: Optional[float] = None, + update_notification_hash: bool = False, + reset_host_address: bool = False, + reset_relaxed_timeout: bool = False, + ): + """ + Update the settings for a single connection. + """ + if state: + conn.maintenance_state = state - verify_flags = [] - for flag in verify_flags_str.split(","): - flag = flag.strip() - if not hasattr(VerifyFlags, flag): - raise ValueError(f"Invalid ssl verify flag: {flag}") - verify_flags.append(getattr(VerifyFlags, flag)) - return verify_flags + if update_notification_hash: + # update the notification hash only if requested + conn.maintenance_notification_hash = maintenance_notification_hash + if host_address is not None: + conn.set_tmp_settings(tmp_host_address=host_address) -URL_QUERY_ARGUMENT_PARSERS = { - "db": int, - "socket_timeout": float, - "socket_connect_timeout": float, - "socket_keepalive": to_bool, - "retry_on_timeout": to_bool, - "retry_on_error": list, - "max_connections": int, - "health_check_interval": int, - "ssl_check_hostname": to_bool, - "ssl_include_verify_flags": parse_ssl_verify_flags, - "ssl_exclude_verify_flags": parse_ssl_verify_flags, - "timeout": float, -} + if relaxed_timeout is not None: + conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout) + + if reset_relaxed_timeout or reset_host_address: + 
conn.reset_tmp_settings( + reset_host_address=reset_host_address, + reset_relaxed_timeout=reset_relaxed_timeout, + ) + conn.update_current_socket_timeout(relaxed_timeout) -def parse_url(url): - if not ( - url.startswith("redis://") - or url.startswith("rediss://") - or url.startswith("unix://") + def update_connections_settings( + self, + state: Optional["MaintenanceState"] = None, + maintenance_notification_hash: Optional[int] = None, + host_address: Optional[str] = None, + relaxed_timeout: Optional[float] = None, + matching_address: Optional[str] = None, + matching_notification_hash: Optional[int] = None, + matching_pattern: Literal[ + "connected_address", "configured_address", "notification_hash" + ] = "connected_address", + update_notification_hash: bool = False, + reset_host_address: bool = False, + reset_relaxed_timeout: bool = False, + include_free_connections: bool = True, ): - raise ValueError( - "Redis URL must specify one of the following " - "schemes (redis://, rediss://, unix://)" - ) - - url = urlparse(url) - kwargs = {} + """ + Update the settings for all matching connections in the pool. - for name, value in parse_qs(url.query).items(): - if value and len(value) > 0: - value = unquote(value[0]) - parser = URL_QUERY_ARGUMENT_PARSERS.get(name) - if parser: - try: - kwargs[name] = parser(value) - except (TypeError, ValueError): - raise ValueError(f"Invalid value for '{name}' in connection URL.") - else: - kwargs[name] = value + This method does not create new connections. + This method does not affect the connection kwargs. - if url.username: - kwargs["username"] = unquote(url.username) - if url.password: - kwargs["password"] = unquote(url.password) + :param state: The maintenance state to set for the connection. + :param maintenance_notification_hash: The hash of the maintenance notification + to set for the connection. + :param host_address: The host address to set for the connection. 
+ :param relaxed_timeout: The relaxed timeout to set for the connection. + :param matching_address: The address to match for the connection. + :param matching_notification_hash: The notification hash to match for the connection. + :param matching_pattern: The pattern to match for the connection. + :param update_notification_hash: Whether to update the notification hash for the connection. + :param reset_host_address: Whether to reset the host address to the original address. + :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout. + :param include_free_connections: Whether to include free/available connections. + """ + with self._get_pool_lock(): + for conn in self._get_in_use_connections(): + if self._should_update_connection( + conn, + matching_pattern, + matching_address, + matching_notification_hash, + ): + self.update_connection_settings( + conn, + state=state, + maintenance_notification_hash=maintenance_notification_hash, + host_address=host_address, + relaxed_timeout=relaxed_timeout, + update_notification_hash=update_notification_hash, + reset_host_address=reset_host_address, + reset_relaxed_timeout=reset_relaxed_timeout, + ) - # We only support redis://, rediss:// and unix:// schemes. 
- if url.scheme == "unix": - if url.path: - kwargs["path"] = unquote(url.path) - kwargs["connection_class"] = UnixDomainSocketConnection + if include_free_connections: + for conn in self._get_free_connections(): + if self._should_update_connection( + conn, + matching_pattern, + matching_address, + matching_notification_hash, + ): + self.update_connection_settings( + conn, + state=state, + maintenance_notification_hash=maintenance_notification_hash, + host_address=host_address, + relaxed_timeout=relaxed_timeout, + update_notification_hash=update_notification_hash, + reset_host_address=reset_host_address, + reset_relaxed_timeout=reset_relaxed_timeout, + ) - else: # implied: url.scheme in ("redis", "rediss"): - if url.hostname: - kwargs["host"] = unquote(url.hostname) - if url.port: - kwargs["port"] = int(url.port) + def update_connection_kwargs( + self, + **kwargs, + ): + """ + Update the connection kwargs for all future connections. - # If there's a path argument, use it as the db argument if a - # querystring value wasn't specified - if url.path and "db" not in kwargs: - try: - kwargs["db"] = int(unquote(url.path).replace("/", "")) - except (AttributeError, ValueError): - pass + This method updates the connection kwargs for all future connections created by the pool. + Existing connections are not affected. + """ + self.connection_kwargs.update(kwargs) - if url.scheme == "rediss": - kwargs["connection_class"] = SSLConnection + def update_active_connections_for_reconnect( + self, + moving_address_src: Optional[str] = None, + ): + """ + Mark all active connections for reconnect. + This is used when a cluster node is migrated to a different address. - return kwargs + :param moving_address_src: The address of the node that is being moved. 
+ """ + with self._get_pool_lock(): + for conn in self._get_in_use_connections(): + if self._should_update_connection( + conn, "connected_address", moving_address_src + ): + conn.mark_for_reconnect() + def disconnect_free_connections( + self, + moving_address_src: Optional[str] = None, + ): + """ + Disconnect all free/available connections. + This is used when a cluster node is migrated to a different address. -_CP = TypeVar("_CP", bound="ConnectionPool") + :param moving_address_src: The address of the node that is being moved. + """ + with self._get_pool_lock(): + for conn in self._get_free_connections(): + if self._should_update_connection( + conn, "connected_address", moving_address_src + ): + conn.disconnect() -class ConnectionPool: +class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface): """ Create a connection pool. ``If max_connections`` is set, then this object raises :py:class:`~redis.exceptions.ConnectionError` when the pool's @@ -1692,6 +2373,12 @@ class ConnectionPool: unix sockets. :py:class:`~redis.SSLConnection` can be used for SSL enabled connections. + If ``maint_notifications_config`` is provided, the connection pool will support + maintenance notifications. + Maintenance notifications are supported only with RESP3. + If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3, + the maintenance notifications will be enabled by default. + Any additional keyword arguments are passed to the constructor of ``connection_class``. 
""" @@ -1750,6 +2437,7 @@ def __init__( connection_class=Connection, max_connections: Optional[int] = None, cache_factory: Optional[CacheFactoryInterface] = None, + maint_notifications_config: Optional[MaintNotificationsConfig] = None, **connection_kwargs, ): max_connections = max_connections or 2**31 @@ -1757,16 +2445,16 @@ def __init__( raise ValueError('"max_connections" must be a positive integer') self.connection_class = connection_class - self.connection_kwargs = connection_kwargs + self._connection_kwargs = connection_kwargs self.max_connections = max_connections self.cache = None self._cache_factory = cache_factory if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"): - if self.connection_kwargs.get("protocol") not in [3, "3"]: + if self._connection_kwargs.get("protocol") not in [3, "3"]: raise RedisError("Client caching is only supported with RESP version 3") - cache = self.connection_kwargs.get("cache") + cache = self._connection_kwargs.get("cache") if cache is not None: if not isinstance(cache, CacheInterface): @@ -1778,29 +2466,13 @@ def __init__( self.cache = self._cache_factory.get_cache() else: self.cache = CacheFactory( - self.connection_kwargs.get("cache_config") + self._connection_kwargs.get("cache_config") ).get_cache() connection_kwargs.pop("cache", None) connection_kwargs.pop("cache_config", None) - if self.connection_kwargs.get( - "maint_notifications_pool_handler" - ) or self.connection_kwargs.get("maint_notifications_config"): - if self.connection_kwargs.get("protocol") not in [3, "3"]: - raise RedisError( - "Push handlers on connection are only supported with RESP version 3" - ) - config = self.connection_kwargs.get("maint_notifications_config", None) or ( - self.connection_kwargs.get("maint_notifications_pool_handler").config - if self.connection_kwargs.get("maint_notifications_pool_handler") - else None - ) - - if config and config.enabled: - self._update_connection_kwargs_for_maint_notifications() - - 
self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) + self._event_dispatcher = self._connection_kwargs.get("event_dispatcher", None) if self._event_dispatcher is None: self._event_dispatcher = EventDispatcher() @@ -1816,6 +2488,12 @@ def __init__( self._fork_lock = threading.RLock() self._lock = threading.RLock() + MaintNotificationsAbstractConnectionPool.__init__( + self, + maint_notifications_config=maint_notifications_config, + **connection_kwargs, + ) + self.reset() def __repr__(self) -> str: @@ -1826,76 +2504,21 @@ def __repr__(self) -> str: f"({conn_kwargs})>)>" ) - def get_protocol(self): - """ - Returns: - The RESP protocol version, or ``None`` if the protocol is not specified, - in which case the server default will be used. - """ - return self.connection_kwargs.get("protocol", None) - - def maint_notifications_pool_handler_enabled(self): - """ - Returns: - True if the maintenance notifications pool handler is enabled, False otherwise. - """ - maint_notifications_config = self.connection_kwargs.get( - "maint_notifications_config", None - ) - - return maint_notifications_config and maint_notifications_config.enabled - - def set_maint_notifications_pool_handler( - self, maint_notifications_pool_handler: MaintNotificationsPoolHandler - ): - self.connection_kwargs.update( - { - "maint_notifications_pool_handler": maint_notifications_pool_handler, - "maint_notifications_config": maint_notifications_pool_handler.config, - } - ) - self._update_connection_kwargs_for_maint_notifications() - - self._update_maint_notifications_configs_for_connections( - maint_notifications_pool_handler - ) + @property + def connection_kwargs(self) -> Dict[str, Any]: + return self._connection_kwargs - def _update_maint_notifications_configs_for_connections( - self, maint_notifications_pool_handler - ): - """Update the maintenance notifications config for all connections in the pool.""" - with self._lock: - for conn in self._available_connections: - 
conn.set_maint_notifications_pool_handler( - maint_notifications_pool_handler - ) - conn.maint_notifications_config = ( - maint_notifications_pool_handler.config - ) - for conn in self._in_use_connections: - conn.set_maint_notifications_pool_handler( - maint_notifications_pool_handler - ) - conn.maint_notifications_config = ( - maint_notifications_pool_handler.config - ) + @connection_kwargs.setter + def connection_kwargs(self, value: Dict[str, Any]): + self._connection_kwargs = value - def _update_connection_kwargs_for_maint_notifications(self): - """Store original connection parameters for maintenance notifications.""" - if self.connection_kwargs.get("orig_host_address", None) is None: - # If orig_host_address is None it means we haven't - # configured the original values yet - self.connection_kwargs.update( - { - "orig_host_address": self.connection_kwargs.get("host"), - "orig_socket_timeout": self.connection_kwargs.get( - "socket_timeout", None - ), - "orig_socket_connect_timeout": self.connection_kwargs.get( - "socket_connect_timeout", None - ), - } - ) + def get_protocol(self): + """ + Returns: + The RESP protocol version, or ``None`` if the protocol is not specified, + in which case the server default will be used. + """ + return self.connection_kwargs.get("protocol", None) def reset(self) -> None: self._created_connections = 0 @@ -1987,7 +2610,7 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": if ( connection.can_read() and self.cache is None - and not self.maint_notifications_pool_handler_enabled() + and not self.maint_notifications_enabled() ): raise ConnectionError("Connection has data") except (ConnectionError, TimeoutError, OSError): @@ -2059,7 +2682,7 @@ def disconnect(self, inuse_connections: bool = True) -> None: Disconnects connections in the pool If ``inuse_connections`` is True, disconnect connections that are - current in use, potentially by other threads. 
Otherwise only disconnect + currently in use, potentially by other threads. Otherwise only disconnect connections that are idle in the pool. """ self._checkpid() @@ -2100,185 +2723,16 @@ def re_auth_callback(self, token: TokenInterface): for conn in self._in_use_connections: conn.set_re_auth_token(token) - def _should_update_connection( - self, - conn: "Connection", - matching_pattern: Literal[ - "connected_address", "configured_address", "notification_hash" - ] = "connected_address", - matching_address: Optional[str] = None, - matching_notification_hash: Optional[int] = None, - ) -> bool: - """ - Check if the connection should be updated based on the matching criteria. - """ - if matching_pattern == "connected_address": - if matching_address and conn.getpeername() != matching_address: - return False - elif matching_pattern == "configured_address": - if matching_address and conn.host != matching_address: - return False - elif matching_pattern == "notification_hash": - if ( - matching_notification_hash - and conn.maintenance_notification_hash != matching_notification_hash - ): - return False - return True - - def update_connection_settings( - self, - conn: "Connection", - state: Optional["MaintenanceState"] = None, - maintenance_notification_hash: Optional[int] = None, - host_address: Optional[str] = None, - relaxed_timeout: Optional[float] = None, - update_notification_hash: bool = False, - reset_host_address: bool = False, - reset_relaxed_timeout: bool = False, - ): - """ - Update the settings for a single connection. 
- """ - if state: - conn.maintenance_state = state - - if update_notification_hash: - # update the notification hash only if requested - conn.maintenance_notification_hash = maintenance_notification_hash - - if host_address is not None: - conn.set_tmp_settings(tmp_host_address=host_address) - - if relaxed_timeout is not None: - conn.set_tmp_settings(tmp_relaxed_timeout=relaxed_timeout) - - if reset_relaxed_timeout or reset_host_address: - conn.reset_tmp_settings( - reset_host_address=reset_host_address, - reset_relaxed_timeout=reset_relaxed_timeout, - ) - - conn.update_current_socket_timeout(relaxed_timeout) - - def update_connections_settings( - self, - state: Optional["MaintenanceState"] = None, - maintenance_notification_hash: Optional[int] = None, - host_address: Optional[str] = None, - relaxed_timeout: Optional[float] = None, - matching_address: Optional[str] = None, - matching_notification_hash: Optional[int] = None, - matching_pattern: Literal[ - "connected_address", "configured_address", "notification_hash" - ] = "connected_address", - update_notification_hash: bool = False, - reset_host_address: bool = False, - reset_relaxed_timeout: bool = False, - include_free_connections: bool = True, - ): - """ - Update the settings for all matching connections in the pool. - - This method does not create new connections. - This method does not affect the connection kwargs. - - :param state: The maintenance state to set for the connection. - :param maintenance_notification_hash: The hash of the maintenance notification - to set for the connection. - :param host_address: The host address to set for the connection. - :param relaxed_timeout: The relaxed timeout to set for the connection. - :param matching_address: The address to match for the connection. - :param matching_notification_hash: The notification hash to match for the connection. - :param matching_pattern: The pattern to match for the connection. 
- :param update_notification_hash: Whether to update the notification hash for the connection. - :param reset_host_address: Whether to reset the host address to the original address. - :param reset_relaxed_timeout: Whether to reset the relaxed timeout to the original timeout. - :param include_free_connections: Whether to include free/available connections. - """ - with self._lock: - for conn in self._in_use_connections: - if self._should_update_connection( - conn, - matching_pattern, - matching_address, - matching_notification_hash, - ): - self.update_connection_settings( - conn, - state=state, - maintenance_notification_hash=maintenance_notification_hash, - host_address=host_address, - relaxed_timeout=relaxed_timeout, - update_notification_hash=update_notification_hash, - reset_host_address=reset_host_address, - reset_relaxed_timeout=reset_relaxed_timeout, - ) - - if include_free_connections: - for conn in self._available_connections: - if self._should_update_connection( - conn, - matching_pattern, - matching_address, - matching_notification_hash, - ): - self.update_connection_settings( - conn, - state=state, - maintenance_notification_hash=maintenance_notification_hash, - host_address=host_address, - relaxed_timeout=relaxed_timeout, - update_notification_hash=update_notification_hash, - reset_host_address=reset_host_address, - reset_relaxed_timeout=reset_relaxed_timeout, - ) - - def update_connection_kwargs( - self, - **kwargs, - ): - """ - Update the connection kwargs for all future connections. - - This method updates the connection kwargs for all future connections created by the pool. - Existing connections are not affected. - """ - self.connection_kwargs.update(kwargs) - - def update_active_connections_for_reconnect( - self, - moving_address_src: Optional[str] = None, - ): - """ - Mark all active connections for reconnect. - This is used when a cluster node is migrated to a different address. 
+ def _get_pool_lock(self): + return self._lock - :param moving_address_src: The address of the node that is being moved. - """ + def _get_free_connections(self): with self._lock: - for conn in self._in_use_connections: - if self._should_update_connection( - conn, "connected_address", moving_address_src - ): - conn.mark_for_reconnect() - - def disconnect_free_connections( - self, - moving_address_src: Optional[str] = None, - ): - """ - Disconnect all free/available connections. - This is used when a cluster node is migrated to a different address. + return self._available_connections - :param moving_address_src: The address of the node that is being moved. - """ + def _get_in_use_connections(self): with self._lock: - for conn in self._available_connections: - if self._should_update_connection( - conn, "connected_address", moving_address_src - ): - conn.disconnect() + return self._in_use_connections async def _mock(self, error: RedisError): """ @@ -2391,7 +2845,7 @@ def make_connection(self): ) else: connection = self.connection_class(**self.connection_kwargs) - self._connections.append(connection) + self._connections.append(connection) return connection finally: if self._locked: @@ -2520,124 +2974,19 @@ def disconnect(self): pass self._locked = False - def update_connections_settings( - self, - state: Optional["MaintenanceState"] = None, - maintenance_notification_hash: Optional[int] = None, - relaxed_timeout: Optional[float] = None, - host_address: Optional[str] = None, - matching_address: Optional[str] = None, - matching_notification_hash: Optional[int] = None, - matching_pattern: Literal[ - "connected_address", "configured_address", "notification_hash" - ] = "connected_address", - update_notification_hash: bool = False, - reset_host_address: bool = False, - reset_relaxed_timeout: bool = False, - include_free_connections: bool = True, - ): - """ - Override base class method to work with BlockingConnectionPool's structure. 
- """ + def _get_free_connections(self): with self._lock: - if include_free_connections: - for conn in tuple(self._connections): - if self._should_update_connection( - conn, - matching_pattern, - matching_address, - matching_notification_hash, - ): - self.update_connection_settings( - conn, - state=state, - maintenance_notification_hash=maintenance_notification_hash, - host_address=host_address, - relaxed_timeout=relaxed_timeout, - update_notification_hash=update_notification_hash, - reset_host_address=reset_host_address, - reset_relaxed_timeout=reset_relaxed_timeout, - ) - else: - connections_in_queue = {conn for conn in self.pool.queue if conn} - for conn in self._connections: - if conn not in connections_in_queue: - if self._should_update_connection( - conn, - matching_pattern, - matching_address, - matching_notification_hash, - ): - self.update_connection_settings( - conn, - state=state, - maintenance_notification_hash=maintenance_notification_hash, - host_address=host_address, - relaxed_timeout=relaxed_timeout, - update_notification_hash=update_notification_hash, - reset_host_address=reset_host_address, - reset_relaxed_timeout=reset_relaxed_timeout, - ) - - def update_active_connections_for_reconnect( - self, - moving_address_src: Optional[str] = None, - ): - """ - Mark all active connections for reconnect. - This is used when a cluster node is migrated to a different address. + return {conn for conn in self.pool.queue if conn} - :param moving_address_src: The address of the node that is being moved. 
- """ + def _get_in_use_connections(self): with self._lock: + # free connections connections_in_queue = {conn for conn in self.pool.queue if conn} - for conn in self._connections: - if conn not in connections_in_queue: - if self._should_update_connection( - conn, - matching_pattern="connected_address", - matching_address=moving_address_src, - ): - conn.mark_for_reconnect() - - def disconnect_free_connections( - self, - moving_address_src: Optional[str] = None, - ): - """ - Disconnect all free/available connections. - This is used when a cluster node is migrated to a different address. - - :param moving_address_src: The address of the node that is being moved. - """ - with self._lock: - existing_connections = self.pool.queue - - for conn in existing_connections: - if conn: - if self._should_update_connection( - conn, "connected_address", moving_address_src - ): - conn.disconnect() - - def _update_maint_notifications_config_for_connections( - self, maint_notifications_config - ): - for conn in tuple(self._connections): - conn.maint_notifications_config = maint_notifications_config - - def _update_maint_notifications_configs_for_connections( - self, maint_notifications_pool_handler - ): - """Update the maintenance notifications config for all connections in the pool.""" - with self._lock: - for conn in tuple(self._connections): - conn.set_maint_notifications_pool_handler( - maint_notifications_pool_handler - ) - conn.maint_notifications_config = ( - maint_notifications_pool_handler.config - ) + # in self._connections we keep all created connections + # so the ones that are not in the queue are the in use ones + return { + conn for conn in self._connections if conn not in connections_in_queue + } def set_in_maintenance(self, in_maintenance: bool): """ diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index 37e4f93a3f..5b8b08c1be 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -32,9 +32,8 @@ def __str__(self): if 
TYPE_CHECKING: from redis.connection import ( - BlockingConnectionPool, - ConnectionInterface, - ConnectionPool, + MaintNotificationsAbstractConnection, + MaintNotificationsAbstractConnectionPool, ) @@ -501,7 +500,7 @@ def is_relaxed_timeouts_enabled(self) -> bool: return self.relaxed_timeout != -1 def get_endpoint_type( - self, host: str, connection: "ConnectionInterface" + self, host: str, connection: "MaintNotificationsAbstractConnection" ) -> EndpointType: """ Determine the appropriate endpoint type for CLIENT MAINT_NOTIFICATIONS command. @@ -558,7 +557,7 @@ def get_endpoint_type( class MaintNotificationsPoolHandler: def __init__( self, - pool: Union["ConnectionPool", "BlockingConnectionPool"], + pool: "MaintNotificationsAbstractConnectionPool", config: MaintNotificationsConfig, ) -> None: self.pool = pool @@ -567,9 +566,19 @@ def __init__( self._lock = threading.RLock() self.connection = None - def set_connection(self, connection: "ConnectionInterface"): + def set_connection(self, connection: "MaintNotificationsAbstractConnection"): self.connection = connection + def get_handler_for_connection(self): + # Copy all data that should be shared between connections + # but each connection should have its own pool handler + # since each connection can be in a different state + copy = MaintNotificationsPoolHandler(self.pool, self.config) + copy._processed_notifications = self._processed_notifications + copy._lock = self._lock + copy.connection = None + return copy + def remove_expired_notifications(self): with self._lock: for notification in tuple(self._processed_notifications): @@ -751,7 +760,9 @@ class MaintNotificationsConnectionHandler: } def __init__( - self, connection: "ConnectionInterface", config: MaintNotificationsConfig + self, + connection: "MaintNotificationsAbstractConnection", + config: MaintNotificationsConfig, ) -> None: self.connection = connection self.config = config diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 
8b8e0cfc2c..29140b9e05 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -11,7 +11,10 @@ from redis import AuthenticationError, DataError, Redis, ResponseError from redis.auth.err import RequestTokenErr from redis.backoff import NoBackoff -from redis.connection import ConnectionInterface, ConnectionPool +from redis.connection import ( + ConnectionInterface, + ConnectionPool, +) from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.exceptions import ConnectionError, RedisError from redis.retry import Retry @@ -428,6 +431,7 @@ def re_auth_callback(token): def test_re_auth_pub_sub_in_resp3(self, credential_provider): mock_pubsub_connection = Mock(spec=ConnectionInterface) mock_pubsub_connection.get_protocol.return_value = 3 + mock_pubsub_connection.should_reconnect = Mock(return_value=False) mock_pubsub_connection.credential_provider = credential_provider mock_pubsub_connection.retry = Retry(NoBackoff(), 3) mock_another_connection = Mock(spec=ConnectionInterface) @@ -488,6 +492,7 @@ def re_auth_callback(token): def test_do_not_re_auth_pub_sub_in_resp2(self, credential_provider): mock_pubsub_connection = Mock(spec=ConnectionInterface) mock_pubsub_connection.get_protocol.return_value = 2 + mock_pubsub_connection.should_reconnect = Mock(return_value=False) mock_pubsub_connection.credential_provider = credential_provider mock_pubsub_connection.retry = Retry(NoBackoff(), 3) mock_another_connection = Mock(spec=ConnectionInterface) diff --git a/tests/test_maint_notifications.py b/tests/test_maint_notifications.py index 08ac15368f..85aa671390 100644 --- a/tests/test_maint_notifications.py +++ b/tests/test_maint_notifications.py @@ -2,7 +2,7 @@ from unittest.mock import Mock, call, patch, MagicMock import pytest -from redis.connection import ConnectionInterface +from redis.connection import ConnectionInterface, MaintNotificationsAbstractConnection from redis.maint_notifications import ( MaintenanceNotification, 
@@ -758,13 +758,16 @@ def __init__(self, resolved_ip): def getpeername(self): return (self.resolved_ip, 6379) - class MockConnection(ConnectionInterface): + class MockConnection(MaintNotificationsAbstractConnection, ConnectionInterface): def __init__(self, host, resolved_ip=None, is_ssl=False): self.host = host self.port = 6379 self._sock = MockSocket(resolved_ip) if resolved_ip else None self.__class__.__name__ = "SSLConnection" if is_ssl else "Connection" + def _get_socket(self): + return self._sock + def get_resolved_ip(self): # Call the actual method from AbstractConnection from redis.connection import AbstractConnection diff --git a/tests/test_maint_notifications_handling.py b/tests/test_maint_notifications_handling.py index 54b6e2dff7..556b63d7e1 100644 --- a/tests/test_maint_notifications_handling.py +++ b/tests/test_maint_notifications_handling.py @@ -7,8 +7,10 @@ from time import sleep from redis import Redis +from redis.cache import CacheConfig from redis.connection import ( AbstractConnection, + Connection, ConnectionPool, BlockingConnectionPool, MaintenanceState, @@ -68,7 +70,7 @@ def validate_in_use_connections_state( # and timeout is updated for connection in in_use_connections: if expected_should_reconnect != "any": - assert connection._should_reconnect == expected_should_reconnect + assert connection.should_reconnect() == expected_should_reconnect assert connection.host == expected_host_address assert connection.socket_timeout == expected_socket_timeout assert connection.socket_connect_timeout == expected_socket_connect_timeout @@ -78,13 +80,12 @@ def validate_in_use_connections_state( connection.orig_socket_connect_timeout == expected_orig_socket_connect_timeout ) - if connection._sock is not None: - assert connection._sock.gettimeout() == expected_current_socket_timeout - assert connection._sock.connected is True + conn_socket = connection._get_socket() + if conn_socket is not None: + assert conn_socket.gettimeout() == 
expected_current_socket_timeout + assert conn_socket.connected is True if expected_current_peername != "any": - assert ( - connection._sock.getpeername()[0] == expected_current_peername - ) + assert conn_socket.getpeername()[0] == expected_current_peername assert connection.maintenance_state == expected_state @staticmethod @@ -112,7 +113,7 @@ def validate_free_connections_state( connected_count = 0 for connection in free_connections: - assert connection._should_reconnect is False + assert connection.should_reconnect() is False assert connection.host == expected_host_address assert connection.socket_timeout == expected_socket_timeout assert connection.socket_connect_timeout == expected_socket_connect_timeout @@ -126,10 +127,11 @@ def validate_free_connections_state( if expected_state == MaintenanceState.NONE: assert connection.maintenance_notification_hash is None - if connection._sock is not None: - assert connection._sock.connected is True + conn_socket = connection._get_socket() + if conn_socket is not None: + assert conn_socket.connected is True if connected_to_tmp_address and tmp_address != "any": - assert connection._sock.getpeername()[0] == tmp_address + assert conn_socket.getpeername()[0] == tmp_address connected_count += 1 assert connected_count == should_be_connected_count @@ -201,7 +203,7 @@ def send(self, data): # Analyze the command and prepare appropriate response if b"HELLO" in data: - response = b"%7\r\n$6\r\nserver\r\n$5\r\nredis\r\n$7\r\nversion\r\n$5\r\n7.0.0\r\n$5\r\nproto\r\n:3\r\n$2\r\nid\r\n:1\r\n$4\r\nmode\r\n$10\r\nstandalone\r\n$4\r\nrole\r\n$6\r\nmaster\r\n$7\r\nmodules\r\n*0\r\n" + response = b"%7\r\n$6\r\nserver\r\n$5\r\nredis\r\n$7\r\nversion\r\n$5\r\n7.4.0\r\n$5\r\nproto\r\n:3\r\n$2\r\nid\r\n:1\r\n$4\r\nmode\r\n$10\r\nstandalone\r\n$4\r\nrole\r\n$6\r\nmaster\r\n$7\r\nmodules\r\n*0\r\n" self.pending_responses.append(response) elif b"MAINT_NOTIFICATIONS" in data and b"internal-ip" in data: # Simulate error response - activate it only for 
internal-ip tests @@ -302,6 +304,38 @@ def recv(self, bufsize): raise BlockingIOError(errno.EAGAIN, "Resource temporarily unavailable") + def recv_into(self, buffer, nbytes=0): + """ + Receive data from Redis and write it into the provided buffer. + Returns the number of bytes written. + + This method is used by the hiredis parser for efficient data reading. + """ + if self.closed: + raise ConnectionError("Socket is closed") + + # Use pending responses that were prepared when commands were sent + if self.pending_responses: + response = self.pending_responses.pop(0) + if b"MOVING" in response: + self.moving_sent = True + + # Determine how many bytes to write + if nbytes == 0: + nbytes = len(buffer) + + # Write data into the buffer (up to nbytes or response length) + bytes_to_write = min(len(response), nbytes, len(buffer)) + buffer[:bytes_to_write] = response[:bytes_to_write] + + return bytes_to_write + else: + # No data available - this should block or raise an exception + # For can_read checks, we should indicate no data is available + import errno + + raise BlockingIOError(errno.EAGAIN, "Resource temporarily unavailable") + def fileno(self): """Return a fake file descriptor for select/poll operations.""" return 1 # Fake file descriptor @@ -392,9 +426,10 @@ def teardown_method(self): def _get_client( self, pool_class, + connection_class=Connection, + enable_cache=False, max_connections=10, maint_notifications_config=None, - setup_pool_handler=False, ): """Helper method to create a pool and Redis client with maintenance notifications configuration. 
@@ -413,25 +448,21 @@ def _get_client( if maint_notifications_config is not None else self.config ) + pool_kwargs = {} + if enable_cache: + pool_kwargs = {"cache_config": CacheConfig()} test_pool = pool_class( + connection_class=connection_class, host=DEFAULT_ADDRESS.split(":")[0], port=int(DEFAULT_ADDRESS.split(":")[1]), max_connections=max_connections, protocol=3, # Required for maintenance notifications maint_notifications_config=config, + **pool_kwargs, ) test_redis_client = Redis(connection_pool=test_pool) - # Set up pool handler for moving notifications if requested - if setup_pool_handler: - pool_handler = MaintNotificationsPoolHandler( - test_redis_client.connection_pool, config - ) - test_redis_client.connection_pool.set_maint_notifications_pool_handler( - pool_handler - ) - return test_redis_client @@ -490,6 +521,9 @@ def test_handshake_failure_when_enabled(self): ) try: with pytest.raises(ResponseError): + # handshake should fail + # socket mock will return error when enabling maint notifications + # for internal-ip test_redis_client.set("hello", "world") finally: @@ -506,7 +540,13 @@ def _validate_connection_handlers(self, conn, pool_handler, config): assert parser_handler is not None assert hasattr(parser_handler, "__self__") assert hasattr(parser_handler, "__func__") - assert parser_handler.__self__ is pool_handler + assert parser_handler.__self__.connection is conn + assert parser_handler.__self__.pool is pool_handler.pool + assert parser_handler.__self__._lock is pool_handler._lock + assert ( + parser_handler.__self__._processed_notifications + is pool_handler._processed_notifications + ) assert parser_handler.__func__ is pool_handler.handle_notification.__func__ # Test that the maintenance handler function is correctly set @@ -576,36 +616,12 @@ def test_client_initialization(self): assert pool_handler.config == self.config conn = test_redis_client.connection_pool.get_connection() - assert conn._should_reconnect is False + + assert 
conn.should_reconnect() is False assert conn.orig_host_address == "localhost" assert conn.orig_socket_timeout is None - # Test that the node moving handler function is correctly set by - # comparing the underlying function and instance - parser_handler = conn._parser.node_moving_push_handler_func - assert parser_handler is not None - assert hasattr(parser_handler, "__self__") - assert hasattr(parser_handler, "__func__") - assert parser_handler.__self__ is pool_handler - assert parser_handler.__func__ is pool_handler.handle_notification.__func__ - - # Test that the maintenance handler function is correctly set - maintenance_handler = conn._parser.maintenance_push_handler_func - assert maintenance_handler is not None - assert hasattr(maintenance_handler, "__self__") - assert hasattr(maintenance_handler, "__func__") - # The maintenance handler should be bound to the connection's - # maintenance notification connection handler - assert ( - maintenance_handler.__self__ is conn._maint_notifications_connection_handler - ) - assert ( - maintenance_handler.__func__ - is conn._maint_notifications_connection_handler.handle_notification.__func__ - ) - - # Validate that the connection's maintenance handler has the same config object - assert conn._maint_notifications_connection_handler.config is self.config + self._validate_connection_handlers(conn, pool_handler, self.config) def test_maint_handler_init_for_existing_connections(self): """Test that maintenance notification handlers are properly set on existing and new connections @@ -630,13 +646,13 @@ def test_maint_handler_init_for_existing_connections(self): enabled_config = MaintNotificationsConfig( enabled=True, proactive_reconnect=True, relaxed_timeout=30 ) - pool_handler = MaintNotificationsPoolHandler( - test_redis_client.connection_pool, enabled_config - ) - test_redis_client.connection_pool.set_maint_notifications_pool_handler( - pool_handler + test_redis_client.connection_pool.update_maint_notifications_config( + 
enabled_config ) + pool_handler = ( + test_redis_client.connection_pool._maint_notifications_pool_handler + ) # Validate the existing connection after enabling maintenance notifications # Both existing and new connections should now have full handler setup self._validate_connection_handlers(existing_conn, pool_handler, enabled_config) @@ -644,6 +660,7 @@ def test_maint_handler_init_for_existing_connections(self): # Create a new connection and validate it has full handlers new_conn = test_redis_client.connection_pool.get_connection() self._validate_connection_handlers(new_conn, pool_handler, enabled_config) + self._validate_connection_handlers(existing_conn, pool_handler, enabled_config) # Clean up connections test_redis_client.connection_pool.release(existing_conn) @@ -665,11 +682,11 @@ def test_connection_pool_creation_with_maintenance_notifications(self, pool_clas == self.config ) # Pool should have maintenance notifications enabled - assert test_pool.maint_notifications_pool_handler_enabled() is True + assert test_pool.maint_notifications_enabled() is True # Create and set a pool handler - pool_handler = MaintNotificationsPoolHandler(test_pool, self.config) - test_pool.set_maint_notifications_pool_handler(pool_handler) + test_pool.update_maint_notifications_config(self.config) + pool_handler = test_pool._maint_notifications_pool_handler # Validate that the handler is properly set on the pool assert ( @@ -1056,9 +1073,7 @@ def test_moving_related_notifications_handling_integration(self, pool_class): 5. 
Tests both ConnectionPool and BlockingConnectionPool implementations """ # Create a pool and Redis client with maintenance notifications and pool handler - test_redis_client = self._get_client( - pool_class, max_connections=10, setup_pool_handler=True - ) + test_redis_client = self._get_client(pool_class, max_connections=10) try: # Create several connections and return them in the pool @@ -1190,9 +1205,7 @@ def test_moving_none_notifications_handling_integration(self, pool_class): 5. Tests both ConnectionPool and BlockingConnectionPool implementations """ # Create a pool and Redis client with maintenance notifications and pool handler - test_redis_client = self._get_client( - pool_class, max_connections=10, setup_pool_handler=True - ) + test_redis_client = self._get_client(pool_class, max_connections=10) try: # Create several connections and return them in the pool @@ -1339,9 +1352,7 @@ def test_create_new_conn_while_moving_not_expired(self, pool_class): 3. Pool configuration is properly applied to newly created connections """ # Create a pool and Redis client with maintenance notifications and pool handler - test_redis_client = self._get_client( - pool_class, max_connections=10, setup_pool_handler=True - ) + test_redis_client = self._get_client(pool_class, max_connections=10) try: # Create several connections and return them in the pool @@ -1399,13 +1410,15 @@ def test_create_new_conn_while_moving_not_expired(self, pool_class): assert new_connection.host == AFTER_MOVING_ADDRESS.split(":")[0] assert new_connection.socket_timeout is self.config.relaxed_timeout # New connections should be connected to the temporary address - assert new_connection._sock is not None - assert new_connection._sock.connected is True + assert new_connection._get_socket() is not None + assert new_connection._get_socket().connected is True assert ( - new_connection._sock.getpeername()[0] + new_connection._get_socket().getpeername()[0] == AFTER_MOVING_ADDRESS.split(":")[0] ) - assert 
new_connection._sock.gettimeout() == self.config.relaxed_timeout + assert ( + new_connection._get_socket().gettimeout() == self.config.relaxed_timeout + ) finally: if hasattr(test_redis_client.connection_pool, "disconnect"): @@ -1422,9 +1435,7 @@ def test_create_new_conn_after_moving_expires(self, pool_class): 3. New connections don't inherit temporary settings """ # Create a pool and Redis client with maintenance notifications and pool handler - test_redis_client = self._get_client( - pool_class, max_connections=10, setup_pool_handler=True - ) + test_redis_client = self._get_client(pool_class, max_connections=10) try: # Create several connections and return them in the pool @@ -1465,10 +1476,10 @@ def test_create_new_conn_after_moving_expires(self, pool_class): assert new_connection.orig_host_address == DEFAULT_ADDRESS.split(":")[0] assert new_connection.orig_socket_timeout is None # New connections should be connected to the original address - assert new_connection._sock is not None - assert new_connection._sock.connected is True + assert new_connection._get_socket() is not None + assert new_connection._get_socket().connected is True # Socket timeout should be None (original timeout) - assert new_connection._sock.gettimeout() is None + assert new_connection._get_socket().gettimeout() is None finally: if hasattr(test_redis_client.connection_pool, "disconnect"): @@ -1489,9 +1500,7 @@ def test_receive_migrated_after_moving(self, pool_class): it should not decrease timeouts (future refactoring consideration). 
""" # Create a pool and Redis client with maintenance notifications and pool handler - test_redis_client = self._get_client( - pool_class, max_connections=10, setup_pool_handler=True - ) + test_redis_client = self._get_client(pool_class, max_connections=10) try: # Create several connections and return them in the pool @@ -1575,8 +1584,8 @@ def test_receive_migrated_after_moving(self, pool_class): # Note: New connections may not inherit the exact relaxed timeout value # but they should have the temporary host address # New connections should be connected - if connection._sock is not None: - assert connection._sock.connected is True + if connection._get_socket() is not None: + assert connection._get_socket().connected is True # Release the new connections for connection in new_connections: @@ -1597,9 +1606,7 @@ def test_overlapping_moving_notifications(self, pool_class): Ensures that the second MOVING notification updates the pool and connections as expected, and that expiry/cleanup works. 
""" global AFTER_MOVING_ADDRESS - test_redis_client = self._get_client( - pool_class, max_connections=5, setup_pool_handler=True - ) + test_redis_client = self._get_client(pool_class, max_connections=5) try: # Create and release some connections in_use_connections = [] @@ -1708,7 +1715,6 @@ def test_overlapping_moving_notifications(self, pool_class): expected_current_socket_timeout=self.config.relaxed_timeout, expected_current_peername=orig_after_moving.split(":")[0], ) - # print(test_redis_client.connection_pool._available_connections) Helpers.validate_free_connections_state( test_redis_client.connection_pool, should_be_connected_count=1, @@ -1751,9 +1757,7 @@ def test_thread_safety_concurrent_notification_handling(self, pool_class): """ import threading - test_redis_client = self._get_client( - pool_class, max_connections=5, setup_pool_handler=True - ) + test_redis_client = self._get_client(pool_class, max_connections=5) results = [] errors = [] @@ -1790,8 +1794,18 @@ def worker(idx): if hasattr(test_redis_client.connection_pool, "disconnect"): test_redis_client.connection_pool.disconnect() - @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) - def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): + @pytest.mark.parametrize( + "pool_class,enable_cache", + [ + (ConnectionPool, False), + (ConnectionPool, True), + (BlockingConnectionPool, False), + (BlockingConnectionPool, True), + ], + ) + def test_moving_migrating_migrated_moved_state_transitions( + self, pool_class, enable_cache + ): """ Test moving configs are not lost if the per connection notifications get picked up after moving is handled. Sequence of notifications: MOVING, MIGRATING, MIGRATED, FAILING_OVER, FAILED_OVER, MOVED. 
@@ -1800,17 +1814,18 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): """ # Setup test_redis_client = self._get_client( - pool_class, max_connections=5, setup_pool_handler=True + pool_class, + max_connections=5, + enable_cache=enable_cache, ) pool = test_redis_client.connection_pool - pool_handler = pool.connection_kwargs["maint_notifications_pool_handler"] # Create and release some connections in_use_connections = [] for _ in range(3): in_use_connections.append(pool.get_connection()) - pool_handler.set_connection(in_use_connections[0]) + pool_handler = in_use_connections[0]._maint_notifications_pool_handler while len(in_use_connections) > 0: pool.release(in_use_connections.pop()) @@ -2019,10 +2034,8 @@ def test_migrating_after_moving_multiple_proxies(self, pool_class): protocol=3, # Required for maintenance notifications maint_notifications_config=self.config, ) - pool.set_maint_notifications_pool_handler( - MaintNotificationsPoolHandler(pool, self.config) - ) - pool_handler = pool.connection_kwargs["maint_notifications_pool_handler"] + + pool_handler = pool._maint_notifications_pool_handler # Create and release some connections key1 = "1.2.3.4"