Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,7 @@ async def run(
*,
exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
poll_timeout: float = 1.0,
pubsub = None
) -> None:
"""Process pub/sub messages using registered callbacks.

Expand All @@ -1215,9 +1216,14 @@ async def run(
await self.connect()
while True:
try:
await self.get_message(
ignore_subscribe_messages=True, timeout=poll_timeout
)
if pubsub is None:
await self.get_message(
ignore_subscribe_messages=True, timeout=poll_timeout
)
else:
await pubsub.get_message(
ignore_subscribe_messages=True, timeout=poll_timeout
)
except asyncio.CancelledError:
raise
except BaseException as e:
Expand Down
135 changes: 133 additions & 2 deletions redis/asyncio/multidb/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from typing import Callable, Optional, Coroutine, Any, List, Union, Awaitable

from redis.asyncio.client import PubSubHandler
from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
from redis.asyncio.multidb.database import AsyncDatabase, Databases
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
Expand All @@ -10,7 +11,7 @@
from redis.background import BackgroundScheduler
from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands
from redis.multidb.exception import NoValidDatabaseException
from redis.typing import KeyT
from redis.typing import KeyT, EncodableT, ChannelT


class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands):
Expand Down Expand Up @@ -222,6 +223,17 @@ async def transaction(
watch_delay=watch_delay,
)

async def pubsub(self, **kwargs):
"""
Return a Publish/Subscribe object. With this object, you can
subscribe to channels and listen for messages that get published to
them.
"""
if not self.initialized:
await self.initialize()

return PubSub(self, **kwargs)

async def _check_databases_health(
self,
on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None,
Expand Down Expand Up @@ -340,4 +352,123 @@ async def execute(self) -> List[Any]:
try:
return await self._client.command_executor.execute_pipeline(tuple(self._command_stack))
finally:
await self.reset()
await self.reset()

class PubSub:
"""
PubSub object for multi database client.
"""
def __init__(self, client: MultiDBClient, **kwargs):
"""Initialize the PubSub object for a multi-database client.

Args:
client: MultiDBClient instance to use for pub/sub operations
**kwargs: Additional keyword arguments to pass to the underlying pubsub implementation
"""

self._client = client
self._client.command_executor.pubsub(**kwargs)

async def __aenter__(self) -> "PubSub":
return self

async def __aexit__(self, exc_type, exc_value, traceback) -> None:
await self.aclose()

async def aclose(self):
return await self._client.command_executor.execute_pubsub_method('aclose')

@property
def subscribed(self) -> bool:
return self._client.command_executor.active_pubsub.subscribed

async def execute_command(self, *args: EncodableT):
return await self._client.command_executor.execute_pubsub_method('execute_command', *args)

async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
"""
Subscribe to channel patterns. Patterns supplied as keyword arguments
expect a pattern name as the key and a callable as the value. A
pattern's callable will be invoked automatically when a message is
received on that pattern rather than producing a message via
``listen()``.
"""
return await self._client.command_executor.execute_pubsub_method(
'psubscribe',
*args,
**kwargs
)

async def punsubscribe(self, *args: ChannelT):
"""
Unsubscribe from the supplied patterns. If empty, unsubscribe from
all patterns.
"""
return await self._client.command_executor.execute_pubsub_method(
'punsubscribe',
*args
)

async def subscribe(self, *args: ChannelT, **kwargs: Callable):
"""
Subscribe to channels. Channels supplied as keyword arguments expect
a channel name as the key and a callable as the value. A channel's
callable will be invoked automatically when a message is received on
that channel rather than producing a message via ``listen()`` or
``get_message()``.
"""
return await self._client.command_executor.execute_pubsub_method(
'subscribe',
*args,
**kwargs
)

async def unsubscribe(self, *args):
"""
Unsubscribe from the supplied channels. If empty, unsubscribe from
all channels
"""
return await self._client.command_executor.execute_pubsub_method(
'unsubscribe',
*args
)

async def get_message(
self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
):
"""
Get the next message if one is available, otherwise None.

If timeout is specified, the system will wait for `timeout` seconds
before returning. Timeout should be specified as a floating point
number or None to wait indefinitely.
"""
return await self._client.command_executor.execute_pubsub_method(
'get_message',
ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout
)

async def run(
self,
*,
exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
poll_timeout: float = 1.0,
) -> None:
"""Process pub/sub messages using registered callbacks.

This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
redis-py, but it is a coroutine. To launch it as a separate task, use
``asyncio.create_task``:

>>> task = asyncio.create_task(pubsub.run())

To shut it down, use asyncio cancellation:

>>> task.cancel()
>>> await task
"""
return await self._client.command_executor.execute_pubsub_run(
exception_handler=exception_handler,
sleep_time=poll_timeout,
pubsub=self
)
18 changes: 9 additions & 9 deletions redis/asyncio/multidb/command_executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import abstractmethod
from asyncio import iscoroutinefunction
from datetime import datetime
from typing import List, Optional, Callable, Any, Union, Awaitable

Expand Down Expand Up @@ -178,14 +179,10 @@ def failover_strategy(self) -> AsyncFailoverStrategy:
def command_retry(self) -> Retry:
return self._command_retry

async def pubsub(self, **kwargs):
async def callback():
if self._active_pubsub is None:
self._active_pubsub = self._active_database.client.pubsub(**kwargs)
self._active_pubsub_kwargs = kwargs
return None

return await self._execute_with_failure_detection(callback)
def pubsub(self, **kwargs):
if self._active_pubsub is None:
self._active_pubsub = self._active_database.client.pubsub(**kwargs)
self._active_pubsub_kwargs = kwargs

async def execute_command(self, *args, **options):
async def callback():
Expand Down Expand Up @@ -225,7 +222,10 @@ async def callback():
async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
async def callback():
method = getattr(self.active_pubsub, method_name)
return await method(*args, **kwargs)
if iscoroutinefunction(method):
return await method(*args, **kwargs)
else:
return method(*args, **kwargs)

return await self._execute_with_failure_detection(callback, *args)

Expand Down
8 changes: 4 additions & 4 deletions redis/multidb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,6 @@ def __init__(self, client: MultiDBClient, **kwargs):
def __enter__(self) -> "PubSub":
return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
self.reset()

def __del__(self) -> None:
try:
# if this object went out of scope prior to shutting down
Expand All @@ -350,7 +347,7 @@ def __del__(self) -> None:
pass

def reset(self) -> None:
pass
return self._client.command_executor.execute_pubsub_method('reset')

def close(self) -> None:
self.reset()
Expand All @@ -359,6 +356,9 @@ def close(self) -> None:
def subscribed(self) -> bool:
return self._client.command_executor.active_pubsub.subscribed

def execute_command(self, *args):
return self._client.command_executor.execute_pubsub_method('execute_command', *args)

def psubscribe(self, *args, **kwargs):
"""
Subscribe to channel patterns. Patterns supplied as keyword arguments
Expand Down
43 changes: 42 additions & 1 deletion tests/test_asyncio/test_scenario/test_active_active.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
import logging
from time import sleep

Expand Down Expand Up @@ -186,4 +187,44 @@ async def callback(pipe: Pipeline):
# Execute transaction until database failover
while not listener.is_changed_flag:
await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3']
await asyncio.sleep(0.5)
await asyncio.sleep(0.5)

@pytest.mark.asyncio
@pytest.mark.parametrize(
"r_multi_db",
[{"failure_threshold": 2}],
indirect=True
)
@pytest.mark.timeout(50)
async def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client):
r_multi_db, listener, config = r_multi_db

event = asyncio.Event()
asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event))

data = json.dumps({'message': 'test'})
messages_count = 0

async def handler(message):
nonlocal messages_count
messages_count += 1

pubsub = await r_multi_db.pubsub()

# Assign a handler and run in a separate thread.
await pubsub.subscribe(**{'test-channel': handler})
task = asyncio.create_task(pubsub.run(poll_timeout=0.1))

# Execute publish before network failure
while not event.is_set():
await r_multi_db.publish('test-channel', data)
await asyncio.sleep(0.5)

# Execute publish until database failover
while not listener.is_changed_flag:
await r_multi_db.publish('test-channel', data)
await asyncio.sleep(0.5)

task.cancel()
await pubsub.unsubscribe('test-channel') is True
assert messages_count > 1