diff --git a/pulsar/__init__.py b/pulsar/__init__.py index fad33cd..48ff664 100644 --- a/pulsar/__init__.py +++ b/pulsar/__init__.py @@ -1576,7 +1576,7 @@ def redeliver_unacknowledged_messages(self): """ self._consumer.redeliver_unacknowledged_messages() - def seek(self, messageid): + def seek(self, messageid: Union[MessageId, _pulsar.MessageId, int]): """ Reset the subscription associated with this consumer to a specific message id or publish timestamp. The message id can either be a specific message or represent the first or last messages in the topic. @@ -1586,10 +1586,10 @@ def seek(self, messageid): Parameters ---------- - messageid: + messageid: MessageId, _pulsar.MessageId or int The message id for seek, OR an integer event time to seek to """ - self._consumer.seek(messageid) + self._consumer.seek(_seek_arg_convert(messageid)) def close(self): """ @@ -1745,7 +1745,7 @@ def has_message_available(self): """ return self._reader.has_message_available(); - def seek(self, messageid): + def seek(self, messageid: Union[MessageId, _pulsar.MessageId, int]): """ Reset this reader to a specific message id or publish timestamp. The message id can either be a specific message or represent the first or last messages in the topic. @@ -1755,10 +1755,10 @@ def seek(self, messageid): Parameters ---------- - messageid: + messageid: MessageId, _pulsar.MessageId or int The message id for seek, OR an integer event time to seek to """ - self._reader.seek(messageid) + self._reader.seek(_seek_arg_convert(messageid)) def close(self): """ @@ -1829,3 +1829,11 @@ def wrapper(consumer, msg): m._schema = schema listener(c, m) return wrapper + +def _seek_arg_convert(seek_arg): + if isinstance(seek_arg, MessageId): + return seek_arg._msg_id + elif isinstance(seek_arg, (_pulsar.MessageId, int)): + return seek_arg + else: + raise ValueError(f"invalid seek_arg type {type(seek_arg)}") diff --git a/tests/pulsar_test.py b/tests/pulsar_test.py index a062bc1..4d5dcb3 100755 --- a/tests/pulsar_test.py +++ b/tests/pulsar_test.py @@ -1019,11 +1019,22 @@ def test_seek(self): msg = consumer.receive(TM) self.assertEqual(msg.data(), b"hello-0") + # seek with wrong type + with self.assertRaises(ValueError, msg="invalid seek_arg type "): + consumer.seek(1.0) + # seek on messageId consumer.seek(ids[50]) msg = consumer.receive(TM) self.assertEqual(msg.data(), b"hello-51") + # seek on a user provided MessageId + msg_id = MessageId(ledger_id=ids[60].ledger_id(), + entry_id=ids[60].entry_id()) + consumer.seek(msg_id) + msg = consumer.receive(TM) + self.assertEqual(msg.data(), b"hello-61") + # ditto, but seek on timestamp consumer.seek(timestamps[42]) msg = consumer.receive(TM) @@ -1034,6 +1045,10 @@ def test_seek(self): with self.assertRaises(pulsar.Timeout): reader.read_next(100) + # seek with wrong type + with self.assertRaises(ValueError, msg="invalid seek_arg type "): + consumer.seek(1.0) + # earliest reader.seek(MessageId.earliest) msg = reader.read_next(TM) @@ -1048,6 +1063,13 @@ def test_seek(self): msg = reader.read_next(TM) self.assertEqual(msg.data(), b"hello-35") + # seek on a user provided MessageId + msg_id = MessageId(ledger_id=ids[44].ledger_id(), + entry_id=ids[44].entry_id()) + reader.seek(msg_id) + msg = reader.read_next(TM) + self.assertEqual(msg.data(), b"hello-45") + # seek on timestamp reader.seek(timestamps[79]) msg = reader.read_next(TM)