From 11167932f249e85422fc130e73382ec1b5e69a20 Mon Sep 17 00:00:00 2001
From: Stephan Hoyer
Date: Mon, 29 Sep 2025 12:45:30 -0700
Subject: [PATCH] Add support for zarr_shards to xbeam.Dataset

PiperOrigin-RevId: 812903711
---
 xarray_beam/_src/dataset.py      | 40 ++++++++++++++++--
 xarray_beam/_src/dataset_test.py | 70 ++++++++++++++++++++++++++++++--
 2 files changed, 103 insertions(+), 7 deletions(-)

diff --git a/xarray_beam/_src/dataset.py b/xarray_beam/_src/dataset.py
index c583327..eabac39 100644
--- a/xarray_beam/_src/dataset.py
+++ b/xarray_beam/_src/dataset.py
@@ -36,7 +36,7 @@
 import operator
 import os.path
 import tempfile
-from typing import Any, Callable
+from typing import Any, Callable, Literal
 
 import apache_beam as beam
 import xarray
@@ -175,10 +175,44 @@ def from_zarr(cls, path: str, split_vars: bool = False) -> Dataset:
     result.ptransform = _get_label('from_zarr') >> result.ptransform
     return result
 
-  def to_zarr(self, path: str) -> beam.PTransform:
+  def _check_shards_or_chunks(
+      self,
+      zarr_chunks: Mapping[str, int],
+      chunks_name: Literal['shards', 'chunks'],
+  ) -> None:
+    if any(self.chunks[k] % zarr_chunks[k] for k in self.chunks):
+      raise ValueError(
+          f'cannot write a dataset with chunks {self.chunks} to Zarr with '
+          f'{chunks_name} {zarr_chunks}, which do not divide evenly into '
+          f'{chunks_name}'
+      )
+
+  def to_zarr(
+      self,
+      path: str,
+      zarr_chunks: Mapping[str, int] | None = None,
+      zarr_shards: Mapping[str, int] | None = None,
+      zarr_format: int | None = None,
+  ) -> beam.PTransform:
     """Write to a Zarr file."""
+    if zarr_chunks is None:
+      if zarr_shards is not None:
+        raise ValueError('cannot supply zarr_shards without zarr_chunks')
+      zarr_chunks = {}
+
+    zarr_chunks = {**self.chunks, **zarr_chunks}
+    if zarr_shards is not None:
+      zarr_shards = {**self.chunks, **zarr_shards}
+      self._check_shards_or_chunks(zarr_shards, 'shards')
+    else:
+      self._check_shards_or_chunks(zarr_chunks, 'chunks')
+
     return self.ptransform | _get_label('to_zarr') >> zarr.ChunksToZarr(
-        path, self.template, self.chunks
+        path,
+        self.template,
+        zarr_chunks=zarr_chunks,
+        zarr_shards=zarr_shards,
+        zarr_format=zarr_format,
     )
 
   def collect_with_direct_runner(self) -> xarray.Dataset:
diff --git a/xarray_beam/_src/dataset_test.py b/xarray_beam/_src/dataset_test.py
index ac4b367..1569ee7 100644
--- a/xarray_beam/_src/dataset_test.py
+++ b/xarray_beam/_src/dataset_test.py
@@ -81,15 +81,77 @@ def test_from_zarr(self, split_vars):
 
   def test_to_zarr(self):
     temp_dir = self.create_tempdir().full_path
-    ds = xarray.Dataset({'foo': ('x', np.arange(10))})
-    beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5})
-    to_zarr = beam_ds.to_zarr(temp_dir)
+    ds = xarray.Dataset({'foo': ('x', np.arange(12))})
+    beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 6})
+
+    with self.subTest('same_chunks'):
+      to_zarr = beam_ds.to_zarr(temp_dir)
+      self.assertRegex(to_zarr.label, r'^from_xarray_\d+|to_zarr_\d+$')
+      with beam.Pipeline() as p:
+        p |= to_zarr
+      opened, chunks = xbeam.open_zarr(temp_dir)
+      xarray.testing.assert_identical(ds, opened)
+      self.assertEqual(chunks, {'x': 6})
+
+    with self.subTest('smaller_chunks'):
+      temp_dir = self.create_tempdir().full_path
+      with beam.Pipeline() as p:
+        p |= beam_ds.to_zarr(temp_dir, zarr_chunks={'x': 3})
+      opened, chunks = xbeam.open_zarr(temp_dir)
+      xarray.testing.assert_identical(ds, opened)
+      self.assertEqual(chunks, {'x': 3})
+
+    with self.subTest('larger_chunks'):
+      with self.assertRaisesWithLiteralMatch(
+          ValueError,
+          "cannot write a dataset with chunks {'x': 6} to Zarr with chunks "
+          "{'x': 9}, which do not divide evenly into chunks",
+      ):
+        beam_ds.to_zarr(temp_dir, zarr_chunks={'x': 9})
+
+    with self.subTest('shards_without_chunks'):
+      with self.assertRaisesWithLiteralMatch(
+          ValueError, 'cannot supply zarr_shards without zarr_chunks'
+      ):
+        beam_ds.to_zarr(temp_dir, zarr_shards={'x': -1})
+
+  def test_to_zarr_shards(self):
+    temp_dir = self.create_tempdir().full_path
+    ds = xarray.Dataset({'foo': ('x', np.arange(12))})
+    beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 6})
+
+    with self.subTest('same_shards_as_chunks'):
+      with beam.Pipeline() as p:
+        p |= beam_ds.to_zarr(
+            temp_dir, zarr_chunks={'x': 3}, zarr_shards={'x': 6}, zarr_format=3
+        )
+      opened, chunks = xbeam.open_zarr(temp_dir)
+      xarray.testing.assert_identical(ds, opened)
+      self.assertEqual(chunks, {'x': 3})
+      self.assertEqual(opened['foo'].encoding['shards'], (6,))
+
+    with self.subTest('larger_shards'):
+      with self.assertRaisesWithLiteralMatch(
+          ValueError,
+          "cannot write a dataset with chunks {'x': 6} to Zarr with shards "
+          "{'x': 9}, which do not divide evenly into shards",
+      ):
+        beam_ds.to_zarr(
+            temp_dir, zarr_chunks={'x': 3}, zarr_shards={'x': 9}, zarr_format=3
+        )
+
+  def test_to_zarr_default_chunks(self):
+    temp_dir = self.create_tempdir().full_path
+    ds = xarray.Dataset({'foo': (('x', 'y'), np.arange(20).reshape(10, 2))})
+    beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 4})
+    to_zarr = beam_ds.to_zarr(temp_dir, zarr_chunks={'x': 2})
     self.assertRegex(to_zarr.label, r'^from_xarray_\d+|to_zarr_\d+$')
     with beam.Pipeline() as p:
       p |= to_zarr
 
-    opened = xarray.open_zarr(temp_dir).compute()
+    opened, chunks = xbeam.open_zarr(temp_dir)
     xarray.testing.assert_identical(ds, opened)
+    self.assertEqual(chunks, {'x': 2, 'y': 2})
 
   @parameterized.named_parameters(
       dict(testcase_name='getitem', call=lambda x: x[['foo']]),