Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions xarray_beam/_src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import operator
import os.path
import tempfile
from typing import Any, Callable
from typing import Any, Callable, Literal

import apache_beam as beam
import xarray
Expand Down Expand Up @@ -175,10 +175,44 @@ def from_zarr(cls, path: str, split_vars: bool = False) -> Dataset:
result.ptransform = _get_label('from_zarr') >> result.ptransform
return result

def to_zarr(self, path: str) -> beam.PTransform:
def _check_shards_or_chunks(
self,
zarr_chunks: Mapping[str, int],
chunks_name: Literal['shards', 'chunks'],
) -> None:
if any(self.chunks[k] % zarr_chunks[k] for k in self.chunks):
raise ValueError(
f'cannot write a dataset with chunks {self.chunks} to Zarr with '
f'{chunks_name} {zarr_chunks}, which do not divide evenly into '
f'{chunks_name}'
)

def to_zarr(
self,
path: str,
zarr_chunks: Mapping[str, int] | None = None,
zarr_shards: Mapping[str, int] | None = None,
zarr_format: int | None = None,
) -> beam.PTransform:
"""Write to a Zarr file."""
if zarr_chunks is None:
if zarr_shards is not None:
raise ValueError('cannot supply zarr_shards without zarr_chunks')
zarr_chunks = {}

zarr_chunks = {**self.chunks, **zarr_chunks}
if zarr_shards is not None:
zarr_shards = {**self.chunks, **zarr_shards}
self._check_shards_or_chunks(zarr_shards, 'shards')
else:
self._check_shards_or_chunks(zarr_chunks, 'chunks')

return self.ptransform | _get_label('to_zarr') >> zarr.ChunksToZarr(
path, self.template, self.chunks
path,
self.template,
zarr_chunks=zarr_chunks,
zarr_shards=zarr_shards,
zarr_format=zarr_format,
)

def collect_with_direct_runner(self) -> xarray.Dataset:
Expand Down
70 changes: 66 additions & 4 deletions xarray_beam/_src/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,77 @@ def test_from_zarr(self, split_vars):

def test_to_zarr(self):
temp_dir = self.create_tempdir().full_path
ds = xarray.Dataset({'foo': ('x', np.arange(10))})
beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 5})
to_zarr = beam_ds.to_zarr(temp_dir)
ds = xarray.Dataset({'foo': ('x', np.arange(12))})
beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 6})

with self.subTest('same_chunks'):
to_zarr = beam_ds.to_zarr(temp_dir)
self.assertRegex(to_zarr.label, r'^from_xarray_\d+|to_zarr_\d+$')
with beam.Pipeline() as p:
p |= to_zarr
opened, chunks = xbeam.open_zarr(temp_dir)
xarray.testing.assert_identical(ds, opened)
self.assertEqual(chunks, {'x': 6})

with self.subTest('smaller_chunks'):
temp_dir = self.create_tempdir().full_path
with beam.Pipeline() as p:
p |= beam_ds.to_zarr(temp_dir, zarr_chunks={'x': 3})
opened, chunks = xbeam.open_zarr(temp_dir)
xarray.testing.assert_identical(ds, opened)
self.assertEqual(chunks, {'x': 3})

with self.subTest('larger_chunks'):
with self.assertRaisesWithLiteralMatch(
ValueError,
"cannot write a dataset with chunks {'x': 6} to Zarr with chunks "
"{'x': 9}, which do not divide evenly into chunks",
):
beam_ds.to_zarr(temp_dir, zarr_chunks={'x': 9})

with self.subTest('shards_without_chunks'):
with self.assertRaisesWithLiteralMatch(
ValueError, 'cannot supply zarr_shards without zarr_chunks'
):
beam_ds.to_zarr(temp_dir, zarr_shards={'x': -1})

def test_to_zarr_shards(self):
temp_dir = self.create_tempdir().full_path
ds = xarray.Dataset({'foo': ('x', np.arange(12))})
beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 6})

with self.subTest('same_shards_as_chunks'):
with beam.Pipeline() as p:
p |= beam_ds.to_zarr(
temp_dir, zarr_chunks={'x': 3}, zarr_shards={'x': 6}, zarr_format=3
)
opened, chunks = xbeam.open_zarr(temp_dir)
xarray.testing.assert_identical(ds, opened)
self.assertEqual(chunks, {'x': 3})
self.assertEqual(opened['foo'].encoding['shards'], (6,))

with self.subTest('larger_shards'):
with self.assertRaisesWithLiteralMatch(
ValueError,
"cannot write a dataset with chunks {'x': 6} to Zarr with shards "
"{'x': 9}, which do not divide evenly into shards",
):
beam_ds.to_zarr(
temp_dir, zarr_chunks={'x': 3}, zarr_shards={'x': 9}, zarr_format=3
)

def test_to_zarr_default_chunks(self):
temp_dir = self.create_tempdir().full_path
ds = xarray.Dataset({'foo': (('x', 'y'), np.arange(20).reshape(10, 2))})
beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 4})
to_zarr = beam_ds.to_zarr(temp_dir, zarr_chunks={'x': 2})

self.assertRegex(to_zarr.label, r'^from_xarray_\d+|to_zarr_\d+$')
with beam.Pipeline() as p:
p |= to_zarr
opened = xarray.open_zarr(temp_dir).compute()
opened, chunks = xbeam.open_zarr(temp_dir)
xarray.testing.assert_identical(ds, opened)
self.assertEqual(chunks, {'x': 2, 'y': 2})

@parameterized.named_parameters(
dict(testcase_name='getitem', call=lambda x: x[['foo']]),
Expand Down