Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 173 additions & 8 deletions redis/_parsers/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import logging
import sys
from abc import ABC
from asyncio import IncompleteReadError, StreamReader, TimeoutError
from typing import Callable, List, Optional, Protocol, Union
from typing import Awaitable, Callable, List, Optional, Protocol, Union

from redis.maintenance_events import (
MaintenanceEvent,
NodeFailedOverEvent,
NodeFailingOverEvent,
NodeMigratedEvent,
NodeMigratingEvent,
NodeMovingEvent,
)

if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
Expand Down Expand Up @@ -50,6 +60,8 @@
"Client sent AUTH, but no password is set": AuthenticationError,
}

logger = logging.getLogger(__name__)


class BaseParser(ABC):
EXCEPTION_CLASSES = {
Expand Down Expand Up @@ -158,48 +170,195 @@ async def read_response(
raise NotImplementedError()


_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]
class MaintenanceNotificationsParser:
"""Protocol defining maintenance push notification parsing functionality"""

@staticmethod
def parse_maintenance_start_msg(response, notification_type):
# Expected message format is: <event_type> <seq_number> <time>
id = response[1]
ttl = response[2]
return notification_type(id, ttl)

@staticmethod
def parse_maintenance_completed_msg(response, notification_type):
# Expected message format is: <event_type> <seq_number>
id = response[1]
return notification_type(id)

@staticmethod
def parse_moving_msg(response):
# Expected message format is: MOVING <seq_number> <time> <endpoint>
id = response[1]
ttl = response[2]
if response[3] in [b"null", "null"]:
host, port = None, None
else:
value = response[3]
if isinstance(value, bytes):
value = value.decode()
host, port = value.split(":")
port = int(port) if port is not None else None

return NodeMovingEvent(id, host, port, ttl)


_INVALIDATION_MESSAGE = "invalidate"
_MOVING_MESSAGE = "MOVING"
_MIGRATING_MESSAGE = "MIGRATING"
_MIGRATED_MESSAGE = "MIGRATED"
_FAILING_OVER_MESSAGE = "FAILING_OVER"
_FAILED_OVER_MESSAGE = "FAILED_OVER"

_MAINTENANCE_MESSAGES = (
_MIGRATING_MESSAGE,
_MIGRATED_MESSAGE,
_FAILING_OVER_MESSAGE,
_FAILED_OVER_MESSAGE,
)

MSG_TYPE_TO_EVENT_PARSER_MAPPING: dict[str, tuple[type[MaintenanceEvent], Callable]] = {
Copy link

Copilot AI Aug 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type annotation uses Callable without specifying the signature, making it unclear what parameters the parser functions should accept. Consider using Callable[[List[Any]], MaintenanceEvent] or a Protocol for better type safety.

Copilot uses AI. Check for mistakes.

_MIGRATING_MESSAGE: (
NodeMigratingEvent,
MaintenanceNotificationsParser.parse_maintenance_start_msg,
),
_MIGRATED_MESSAGE: (
NodeMigratedEvent,
MaintenanceNotificationsParser.parse_maintenance_completed_msg,
),
_FAILING_OVER_MESSAGE: (
NodeFailingOverEvent,
MaintenanceNotificationsParser.parse_maintenance_start_msg,
),
_FAILED_OVER_MESSAGE: (
NodeFailedOverEvent,
MaintenanceNotificationsParser.parse_maintenance_completed_msg,
),
_MOVING_MESSAGE: (
NodeMovingEvent,
MaintenanceNotificationsParser.parse_moving_msg,
),
}


class PushNotificationsParser(Protocol):
"""Protocol defining RESP3-specific parsing functionality"""

pubsub_push_handler_func: Callable
invalidation_push_handler_func: Optional[Callable] = None
node_moving_push_handler_func: Optional[Callable] = None
maintenance_push_handler_func: Optional[Callable] = None

def handle_pubsub_push_response(self, response):
"""Handle pubsub push responses"""
raise NotImplementedError()

def handle_push_response(self, response, **kwargs):
if response[0] not in _INVALIDATION_MESSAGE:
msg_type = response[0]
if isinstance(msg_type, bytes):
msg_type = msg_type.decode()

if msg_type not in (
_INVALIDATION_MESSAGE,
*_MAINTENANCE_MESSAGES,
_MOVING_MESSAGE,
):
return self.pubsub_push_handler_func(response)
if self.invalidation_push_handler_func:
return self.invalidation_push_handler_func(response)

try:
if (
msg_type == _INVALIDATION_MESSAGE
and self.invalidation_push_handler_func
):
return self.invalidation_push_handler_func(response)

if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
parser_function = MSG_TYPE_TO_EVENT_PARSER_MAPPING[msg_type][1]

notification = parser_function(response)
return self.node_moving_push_handler_func(notification)

if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
parser_function = MSG_TYPE_TO_EVENT_PARSER_MAPPING[msg_type][1]
notification_type = MSG_TYPE_TO_EVENT_PARSER_MAPPING[msg_type][0]
notification = parser_function(response, notification_type)

if notification is not None:
return self.maintenance_push_handler_func(notification)
except Exception as e:
logger.error(
"Error handling {} message ({}): {}".format(msg_type, response, e)
)

return None

def set_pubsub_push_handler(self, pubsub_push_handler_func):
self.pubsub_push_handler_func = pubsub_push_handler_func

def set_invalidation_push_handler(self, invalidation_push_handler_func):
self.invalidation_push_handler_func = invalidation_push_handler_func

def set_node_moving_push_handler(self, node_moving_push_handler_func):
self.node_moving_push_handler_func = node_moving_push_handler_func

def set_maintenance_push_handler(self, maintenance_push_handler_func):
self.maintenance_push_handler_func = maintenance_push_handler_func


class AsyncPushNotificationsParser(Protocol):
"""Protocol defining async RESP3-specific parsing functionality"""

pubsub_push_handler_func: Callable
invalidation_push_handler_func: Optional[Callable] = None
node_moving_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None
maintenance_push_handler_func: Optional[Callable[..., Awaitable[None]]] = None

async def handle_pubsub_push_response(self, response):
"""Handle pubsub push responses asynchronously"""
raise NotImplementedError()

async def handle_push_response(self, response, **kwargs):
"""Handle push responses asynchronously"""
if response[0] not in _INVALIDATION_MESSAGE:

msg_type = response[0]
if isinstance(msg_type, bytes):
msg_type = msg_type.decode()

if msg_type not in (
_INVALIDATION_MESSAGE,
*_MAINTENANCE_MESSAGES,
_MOVING_MESSAGE,
):
return await self.pubsub_push_handler_func(response)
if self.invalidation_push_handler_func:
return await self.invalidation_push_handler_func(response)

try:
if (
msg_type == _INVALIDATION_MESSAGE
and self.invalidation_push_handler_func
):
return await self.invalidation_push_handler_func(response)

if isinstance(msg_type, bytes):
msg_type = msg_type.decode()

if msg_type == _MOVING_MESSAGE and self.node_moving_push_handler_func:
parser_function = MSG_TYPE_TO_EVENT_PARSER_MAPPING[msg_type][1]
notification = parser_function(response)
return await self.node_moving_push_handler_func(notification)

if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
parser_function = MSG_TYPE_TO_EVENT_PARSER_MAPPING[msg_type][1]
notification_type = MSG_TYPE_TO_EVENT_PARSER_MAPPING[msg_type][0]
notification = parser_function(response, notification_type)

if notification is not None:
return await self.maintenance_push_handler_func(notification)
except Exception as e:
logger.error(
"Error handling {} message ({}): {}".format(msg_type, response, e)
)

return None

def set_pubsub_push_handler(self, pubsub_push_handler_func):
"""Set the pubsub push handler function"""
Expand All @@ -209,6 +368,12 @@ def set_invalidation_push_handler(self, invalidation_push_handler_func):
"""Set the invalidation push handler function"""
self.invalidation_push_handler_func = invalidation_push_handler_func

def set_node_moving_push_handler(self, node_moving_push_handler_func):
self.node_moving_push_handler_func = node_moving_push_handler_func

def set_maintenance_push_handler(self, maintenance_push_handler_func):
self.maintenance_push_handler_func = maintenance_push_handler_func


class _AsyncRESPBase(AsyncBaseParser):
"""Base class for async resp parsing"""
Expand Down
26 changes: 16 additions & 10 deletions redis/_parsers/hiredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def __init__(self, socket_read_size):
self.socket_read_size = socket_read_size
self._buffer = bytearray(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.node_moving_push_handler_func = None
self.maintenance_push_handler_func = None
self.invalidation_push_handler_func = None
self._hiredis_PushNotificationType = None

Expand Down Expand Up @@ -141,12 +143,15 @@ def read_response(self, disable_decoding=False, push_request=False):
response, self._hiredis_PushNotificationType
):
response = self.handle_push_response(response)
if not push_request:
return self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:

# if this is a push request return the push response
if push_request:
return response

return self.read_response(
disable_decoding=disable_decoding,
push_request=push_request,
)
return response

if disable_decoding:
Expand All @@ -169,12 +174,13 @@ def read_response(self, disable_decoding=False, push_request=False):
response, self._hiredis_PushNotificationType
):
response = self.handle_push_response(response)
if not push_request:
return self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
if push_request:
return response
return self.read_response(
disable_decoding=disable_decoding,
push_request=push_request,
)

elif (
isinstance(response, list)
and response
Expand Down
16 changes: 11 additions & 5 deletions redis/_parsers/resp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.node_moving_push_handler_func = None
self.maintenance_push_handler_func = None
self.invalidation_push_handler_func = None

def handle_pubsub_push_response(self, response):
Expand Down Expand Up @@ -117,17 +119,21 @@ def _read_response(self, disable_decoding=False, push_request=False):
for _ in range(int(response))
]
response = self.handle_push_response(response)
if not push_request:
return self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:

# if this is a push request return the push response
if push_request:
return response

return self._read_response(
disable_decoding=disable_decoding,
push_request=push_request,
)
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)

return response


Expand Down
Loading