Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -238,22 +238,24 @@ def undo(self, data: da.Array) -> da.Array:
raise RuntimeError("Shape not set, therefore cannot undo")

def _unflatten(data, shape):
while len(data.shape) > len(shape):
shape = (data[-len(shape)], *shape)
# while len(data.shape) > len(shape):
# shape = (data[-len(shape)], *shape)
return data.reshape(shape)

if self.flatten_dims is None:
raise RuntimeError("`flatten_dims` was not set, and this set hasn't been used. Cannot Unflatten.")

data_shape = data.shape
parsed_shape = data_shape[: -1 * min(1, (self.flatten_dims - 1))] if len(data_shape) > 1 else data_shape
# parsed_shape = data_shape[: -1 * min(1, (self.flatten_dims - 1))] if len(data_shape) > 1 else data_shape
parsed_shape = data_shape[:-1] if len(data_shape) > 1 else []
attempts = [
(*parsed_shape, *self._unflattenshape),
]

if self.shape_attempt:
shape_attempt = self._configure_shape_attempt()
if shape_attempt:
# if self.shape_attempt is truthy then shape_attempt is always truthy.
if shape_attempt: # pragma: no cover
attempts.append((*parsed_shape, *shape_attempt[-1 * self.flatten_dims :])) # type: ignore

for attemp in attempts:
Expand Down Expand Up @@ -330,7 +332,7 @@ def __init__(
"""
super().__init__(
split_tuples=False,
recognised_types=dict(apply=(da.Array,), undo=(da.Array, np.ndarray)),
recognised_types=dict(apply=(da.Array, tuple), undo=(da.Array, np.ndarray, tuple)),
)
self.record_initialisation()

Expand Down
82 changes: 82 additions & 0 deletions packages/pipeline/tests/operations/dask/test_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright Commonwealth of Australia, Bureau of Meteorology 2026.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock
import pytest

from pyearthtools.pipeline.operations.dask.dask import DaskOperation
from pyearthtools.pipeline.operation import Operation
import numpy as np
import dask.array as da


class FakeDaskOperation(DaskOperation):
_numpy_counterpart = "FakeNumpyOperation"

def apply_func(self, sample):
return "dask_apply"

def undo_func(self, sample):
return "dask_undo"


class FakeNumpyOperation(Operation):
def apply_func(self, sample):
return "numpy_apply"

def undo_func(self, sample):
return "numpy_undo"


def _augmented_dynamic_import(*args):
return FakeNumpyOperation


@pytest.mark.parametrize(
("op", "arr_type", "expected", "dispatched"),
[
("apply", np.array, "numpy_apply", True),
("undo", np.array, "numpy_undo", True),
("apply", da.array, "dask_apply", False),
("undo", da.array, "dask_undo", False),
],
)
def test_dask_operation_numpy_dispatch(op, arr_type, expected, dispatched):
sample = arr_type(1)
dask_op = FakeDaskOperation()

# patch dynamic_import to ensure the fake numpy op is used
with mock.patch("pyearthtools.utils.dynamic_import", side_effect=_augmented_dynamic_import) as mock_dynamic_import:

# check correct dispatch depending on sample type
# when sample is np.ndarray, dask_op should dispatch to equivalent numpy op
# when sample is anything else, it should use the inbuilt op
assert getattr(dask_op, op)(sample) == expected
assert mock_dynamic_import.called == dispatched
# when sample is np.ndarray, check that np.ndarray is added to recognised_types for the op
if dispatched:
assert np.ndarray in dask_op.recognised_types[op]
else:
assert dask_op.recognised_types == {}

# run op again, to check that np.ndarry only appears once in recognised_types
assert getattr(dask_op, op)(sample) == expected
if dispatched:
assert dask_op.recognised_types[op].count(np.ndarray) == 1
else:
assert dask_op.recognised_types == {}

# turn off op in _operation to check that input sample is returned
dask_op._operation[op] = False
assert getattr(dask_op, op)(sample) == sample
91 changes: 91 additions & 0 deletions packages/pipeline/tests/operations/dask/test_dask_augment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright Commonwealth of Australia, Bureau of Meteorology 2026.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pyearthtools.pipeline.operations.dask import augment

import numpy as np
import dask.array as da
import pytest


@pytest.mark.parametrize(
# The result depends on the random seed. This one has been manually
# checked to produce a certain number of rotations the first time.
"seed, rotations",
[
(42, 0),
(1, 1),
(4, 2),
(2, 3),
],
)
def test_Rotate(seed, rotations):

original = da.from_array(np.array([[1, 2], [4, 3]]))

match rotations:
case 0:
expected = np.array([[1, 2], [4, 3]])
case 1:
expected = np.array([[4, 1], [3, 2]])
case 2:
expected = np.array([[3, 4], [2, 1]])
case 3:
expected = np.array([[2, 3], [1, 4]])

rotate = augment.Rotate(seed=seed, axis=(1, 0))

result = rotate.apply_func(original)
assert isinstance(result, da.Array)
assert np.array_equal(result.compute(), expected)


def test_Rotate_axis_must_be_tuple():
with pytest.raises(TypeError):
augment.Rotate(axis=0)


@pytest.mark.parametrize(
"seed, should_flip",
[
(0, True),
(1, False),
],
)
def test_Flip(seed, should_flip):

original = da.from_array(np.array([[1, 2], [4, 3]]))

flipped = np.array([[3, 4], [2, 1]])

# The result depends on the random seed. This one has been manually checked
# to produce a single rotation the first time.
expected = flipped if should_flip else original.compute()
flip = augment.Flip(seed=seed, axis=(1, 0))

result = flip.apply_func(original)
assert np.array_equal(result.compute(), expected)


def test_FlipAndRotate():

original = da.from_array(np.array([[1, 2], [4, 3]]))

flip_and_rotate = augment.FlipAndRotate()

result = flip_and_rotate.apply_func(original)
# Don't worry about the number of flips and rotations, just check the
# shape and type returned
assert isinstance(result, da.Array)
assert result.shape == (2, 2)
Loading
Loading