|
11 | 11 | from datetime import timedelta
|
12 | 12 | from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional
|
13 | 13 |
|
| 14 | +from airbyte_cdk.models import ( |
| 15 | + AirbyteStateBlob, |
| 16 | + AirbyteStateMessage, |
| 17 | + AirbyteStateType, |
| 18 | + AirbyteStreamState, |
| 19 | + StreamDescriptor, |
| 20 | +) |
14 | 21 | from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
|
15 | 22 | from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import (
|
16 | 23 | Timer,
|
@@ -48,7 +55,7 @@ class ConcurrentPerPartitionCursor(Cursor):
|
48 | 55 | Manages state per partition when a stream has many partitions, preventing data loss or duplication.
|
49 | 56 |
|
50 | 57 | Attributes:
|
51 |
| - DEFAULT_MAX_PARTITIONS_NUMBER (int): Maximum number of partitions to retain in memory (default is 10,000). |
| 58 | + DEFAULT_MAX_PARTITIONS_NUMBER (int): Maximum number of partitions to retain in memory (default is 10,000). This limit needs to be higher than the number of threads we might enqueue (represented by ThreadPoolManager.DEFAULT_MAX_QUEUE_SIZE). Otherwise, partitions could be generated and submitted to the ThreadPool but deleted from the ConcurrentPerPartitionCursor, and closing them would raise a KeyError. |
52 | 59 |
|
53 | 60 | - **Partition Limitation Logic**
|
54 | 61 | Ensures the number of tracked partitions does not exceed the specified limit to prevent memory overuse. Oldest partitions are removed when the limit is reached.
|
@@ -128,6 +135,7 @@ def __init__(
|
128 | 135 |
|
129 | 136 | # FIXME this is a temporary field the time of the migration from declarative cursors to concurrent ones
|
130 | 137 | self._attempt_to_create_cursor_if_not_provided = attempt_to_create_cursor_if_not_provided
|
| 138 | + self._synced_some_data = False |
131 | 139 |
|
132 | 140 | @property
|
133 | 141 | def cursor_field(self) -> CursorField:
|
@@ -168,8 +176,8 @@ def close_partition(self, partition: Partition) -> None:
|
168 | 176 | with self._lock:
|
169 | 177 | self._semaphore_per_partition[partition_key].acquire()
|
170 | 178 | if not self._use_global_cursor:
|
171 |
| - self._cursor_per_partition[partition_key].close_partition(partition=partition) |
172 | 179 | cursor = self._cursor_per_partition[partition_key]
|
| 180 | + cursor.close_partition(partition=partition) |
173 | 181 | if (
|
174 | 182 | partition_key in self._partitions_done_generating_stream_slices
|
175 | 183 | and self._semaphore_per_partition[partition_key]._value == 0
|
@@ -213,8 +221,10 @@ def ensure_at_least_one_state_emitted(self) -> None:
|
213 | 221 | if not any(
|
214 | 222 | semaphore_item[1]._value for semaphore_item in self._semaphore_per_partition.items()
|
215 | 223 | ):
|
216 |
| - self._global_cursor = self._new_global_cursor |
217 |
| - self._lookback_window = self._timer.finish() |
| 224 | + if self._synced_some_data: |
| 225 | + # we only update those if we actually synced some data |
| 226 | + self._global_cursor = self._new_global_cursor |
| 227 | + self._lookback_window = self._timer.finish() |
218 | 228 | self._parent_state = self._partition_router.get_stream_state()
|
219 | 229 | self._emit_state_message(throttle=False)
|
220 | 230 |
|
@@ -422,9 +432,6 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
|
422 | 432 | if stream_state.get("parent_state"):
|
423 | 433 | self._parent_state = stream_state["parent_state"]
|
424 | 434 |
|
425 |
| - # Set parent state for partition routers based on parent streams |
426 |
| - self._partition_router.set_initial_state(stream_state) |
427 |
| - |
428 | 435 | def _set_global_state(self, stream_state: Mapping[str, Any]) -> None:
|
429 | 436 | """
|
430 | 437 | Initializes the global cursor state from the provided stream state.
|
@@ -458,6 +465,7 @@ def observe(self, record: Record) -> None:
|
458 | 465 | except ValueError:
|
459 | 466 | return
|
460 | 467 |
|
| 468 | + self._synced_some_data = True |
461 | 469 | record_cursor = self._connector_state_converter.output_format(
|
462 | 470 | self._connector_state_converter.parse_value(record_cursor_value)
|
463 | 471 | )
|
@@ -541,3 +549,45 @@ def _get_cursor(self, record: Record) -> ConcurrentCursor:
|
541 | 549 |
|
def limit_reached(self) -> bool:
    """Whether enough partitions have been observed to switch over to the global cursor."""
    threshold = self.SWITCH_TO_GLOBAL_LIMIT
    return self._number_of_partitions > threshold
|
| 552 | + |
@staticmethod
def get_parent_state(
    stream_state: Optional[StreamState], parent_stream_name: str
) -> Optional[AirbyteStateMessage]:
    """
    Extract the state of a single parent stream from a per-partition stream state.

    Args:
        stream_state: Full stream state, expected to hold a "parent_state" mapping
            keyed by parent stream name. May be None or empty on a first sync.
        parent_stream_name: Name of the parent stream whose state to extract.

    Returns:
        A STREAM-typed AirbyteStateMessage for the parent stream, or None when the
        state is empty, has no "parent_state" entry, or does not track this parent.
    """
    if not stream_state:
        return None

    if "parent_state" not in stream_state:
        # A populated per-partition state without any parent state is unexpected; surface it loudly.
        logger.warning(
            f"Trying to get_parent_state for stream `{parent_stream_name}` when there is no parent state in the state"
        )
        return None
    elif parent_stream_name not in stream_state["parent_state"]:
        logger.info(
            f"Could not find parent state for stream `{parent_stream_name}`. Only parents available are {list(stream_state['parent_state'].keys())}"
        )
        return None

    return AirbyteStateMessage(
        type=AirbyteStateType.STREAM,
        stream=AirbyteStreamState(
            stream_descriptor=StreamDescriptor(name=parent_stream_name, namespace=None),
            stream_state=AirbyteStateBlob(stream_state["parent_state"][parent_stream_name]),
        ),
    )
| 578 | + |
@staticmethod
def get_global_state(
    stream_state: Optional[StreamState], parent_stream_name: str
) -> Optional[AirbyteStateMessage]:
    """
    Build a STREAM-typed state message from the global "state" entry of a stream state.

    Args:
        stream_state: Full stream state, possibly holding a global "state" mapping.
        parent_stream_name: Stream name to put in the emitted stream descriptor.

    Returns:
        An AirbyteStateMessage wrapping the global state, or None when the stream
        state is empty or has no "state" entry.
    """
    if not stream_state or "state" not in stream_state:
        return None

    return AirbyteStateMessage(
        type=AirbyteStateType.STREAM,
        stream=AirbyteStreamState(
            stream_descriptor=StreamDescriptor(parent_stream_name, None),
            stream_state=AirbyteStateBlob(stream_state["state"]),
        ),
    )
0 commit comments