From af6a4ca3d388599f776010919aca21edaa05feac Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Thu, 26 Mar 2026 17:46:10 -0700 Subject: [PATCH 1/5] feat: Add MLX test_utils.py entries from https://github.com/DBraun/sequence-layers/commit/80daa69bcb5a5580ff9fb73d13e416a1813b1462 --- sequence_layers/mlx/test_utils_mlx.py | 174 ++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 sequence_layers/mlx/test_utils_mlx.py diff --git a/sequence_layers/mlx/test_utils_mlx.py b/sequence_layers/mlx/test_utils_mlx.py new file mode 100644 index 0000000..5f8dd09 --- /dev/null +++ b/sequence_layers/mlx/test_utils_mlx.py @@ -0,0 +1,174 @@ +"""Test utilities for MLX sequence layers (legacy from branch).""" + +import mlx.core as mx +import numpy as np + +from sequence_layers.mlx import basic_types as bt + +Sequence = bt.Sequence +MaskedSequence = bt.MaskedSequence +ShapeDType = bt.ShapeDType + + +def random_sequence( + batch: int, + time: int, + channels: int | tuple[int, ...], + *, + dtype=mx.float32, + mask: mx.array | None = None, + masked: bool = True, +) -> Sequence: + """Create a random Sequence for testing. + + Args: + batch: Batch size. + time: Sequence length. + channels: Channel size (int) or channel shape (tuple). + dtype: Values dtype. + mask: Optional explicit mask. If None, all-valid mask is used. + masked: If True, returns a MaskedSequence. If False, a Sequence. + + Returns: + A random Sequence or MaskedSequence. + """ + if isinstance(channels, int): + channels = (channels,) + shape = (batch, time) + channels + values = mx.random.normal(shape=shape).astype(dtype) + if mask is None: + mask = mx.ones((batch, time), dtype=mx.bool_) + if masked: + return MaskedSequence(values, mask) + return Sequence(values, mask) + + +def step_by_step( + layer, + x: Sequence, + *, + block_size: int = 1, + constants=None, + stream_constants=None, +) -> tuple[Sequence, object]: + """Run a layer step-by-step and concatenate outputs. + + Args: + layer: A SequenceLayer with supports_step. + x: Input sequence [batch, time, ...]. + block_size: Number of timesteps per step. + constants: Optional constants dict (static, passed as-is each step). + stream_constants: Optional dict of source_name -> Sequence. These are + sliced at the same block_size as input for each step, merging into + the constants dict. Use this for streaming cross-attention sources. + + Returns: + (output_sequence, final_state) + """ + batch = x.shape[0] + time = x.shape[1] + spec = x.channel_spec + + # Build initial constants with full stream sources for get_initial_state. + init_constants = dict(constants) if constants else {} + if stream_constants: + init_constants.update(stream_constants) + + state = layer.get_initial_state(batch, spec, constants=init_constants or None) + + outputs_values = [] + outputs_masks = [] + + for t in range(0, time, block_size): + x_block = Sequence( + x.values[:, t : t + block_size], + x.mask[:, t : t + block_size], + ) + + # Build per-step constants with sliced stream sources. + step_constants = dict(constants) if constants else {} + if stream_constants: + for name, seq in stream_constants.items(): + step_constants[name] = Sequence( + seq.values[:, t : t + block_size], + seq.mask[:, t : t + block_size], + ) + + y_block, state = layer.step( + x_block, + state, + constants=step_constants or None, + ) + outputs_values.append(y_block.values) + outputs_masks.append(y_block.mask) + + y_values = mx.concatenate(outputs_values, axis=1) + y_mask = mx.concatenate(outputs_masks, axis=1) + return Sequence(y_values, y_mask), state + + +def verify_contract( + test_case, + layer, + input_shape, + *, + batch_size: int = 2, + time: int = 8, + dtype=mx.float32, + constants=None, + atol: float = 1e-5, + rtol: float = 1e-5, + test_step: bool = True, +): + """Verify that a layer's layer() and step() outputs are consistent. + + Checks: + 1. layer() runs without error and produces correct output shape. + 2. step() runs without error and produces correct output shape. + 3. layer() and step() produce approximately equal outputs. + + Args: + test_case: An absltest.TestCase (or similar) with assertion methods. + layer: The SequenceLayer to test. + input_shape: Channel shape (tuple), e.g. (16,). + batch_size: Batch size for test inputs. + time: Sequence length for test inputs. + dtype: Input dtype. + constants: Optional constants dict. + atol: Absolute tolerance for output comparison. + rtol: Relative tolerance for output comparison. + test_step: Whether to test step() and compare with layer(). + """ + x = random_sequence(batch_size, time, input_shape, dtype=dtype) + + # Test layer(). + y_layer = layer.layer(x, constants=constants) + + # Check output shape. + expected_shape = layer.get_output_shape(input_shape, constants=constants) + test_case.assertEqual(y_layer.channel_shape, expected_shape) + + # Check output dtype. + expected_dtype = layer.get_output_dtype(dtype, constants=constants) + test_case.assertEqual(y_layer.dtype, expected_dtype) + + if not test_step or not layer.supports_step: + return + + # Test step(). + block_size = layer.block_size + y_step, _ = step_by_step(layer, x, block_size=block_size, constants=constants) + + # Check shapes match. + test_case.assertEqual(y_step.shape, y_layer.shape) + + # Check values match. + y_layer_np = np.array(y_layer.values) + y_step_np = np.array(y_step.values) + np.testing.assert_allclose( + y_step_np, + y_layer_np, + atol=atol, + rtol=rtol, + err_msg=f'{layer.__class__.__name__}: step() and layer() outputs differ', + ) From 6d25087e65b5246daa61c63e95022a663924deae Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Thu, 9 Apr 2026 21:18:25 -0700 Subject: [PATCH 2/5] refactor(test_utils): Abstract into spec and implementations. --- sequence_layers/jax/__init__.py | 9 +- sequence_layers/jax/backend.py | 4 + sequence_layers/jax/backend_test.py | 14 + sequence_layers/jax/test_utils.py | 116 ++----- sequence_layers/jax/test_utils_test.py | 170 ++++----- sequence_layers/jax/types_test.py | 21 +- sequence_layers/mlx/__init__.py | 9 +- sequence_layers/mlx/backend.py | 5 + sequence_layers/mlx/backend_test.py | 14 + sequence_layers/mlx/test_utils.py | 211 +++++++++++- sequence_layers/mlx/test_utils_mlx.py | 174 ---------- sequence_layers/mlx/test_utils_test.py | 24 ++ sequence_layers/mlx/types.py | 12 +- sequence_layers/mlx/types_test.py | 4 +- sequence_layers/specs/__init__.py | 15 +- sequence_layers/specs/backend.py | 4 + sequence_layers/specs/backend_behaviors.py | 16 + sequence_layers/specs/test_utils.py | 226 +++++++++++- sequence_layers/specs/test_utils_behaviors.py | 324 ++++++++++++++++++ sequence_layers/specs/types_behaviors.py | 66 +++- 20 files changed, 1026 insertions(+), 412 deletions(-) create mode 100644 sequence_layers/jax/backend_test.py create mode 100644 sequence_layers/mlx/backend_test.py delete mode 100644 sequence_layers/mlx/test_utils_mlx.py create mode 100644 sequence_layers/mlx/test_utils_test.py create mode 100644 sequence_layers/specs/backend_behaviors.py create mode 100644 sequence_layers/specs/test_utils_behaviors.py diff --git a/sequence_layers/jax/__init__.py b/sequence_layers/jax/__init__.py index dcbb3a0..fe862a9 100644 --- a/sequence_layers/jax/__init__.py +++ b/sequence_layers/jax/__init__.py @@ -13,6 +13,11 @@ # limitations under the License. """Sequence layers in JAX.""" +# (re-export the names for typechecking) +from . import backend as backend +from . import types as types +from . import test_utils as test_utils +from .test_utils import SequenceLayerTest # pylint: disable=wildcard-import from sequence_layers.jax.attention import * from sequence_layers.jax.combinators import * @@ -28,6 +33,4 @@ from sequence_layers.jax.time_varying import * from sequence_layers.jax.types import * -# (re-export the names for typechecking) -from . import backend as backend -from . import types as types + diff --git a/sequence_layers/jax/backend.py b/sequence_layers/jax/backend.py index f29d4e1..320f495 100644 --- a/sequence_layers/jax/backend.py +++ b/sequence_layers/jax/backend.py @@ -13,6 +13,7 @@ class BackendWrapper(spec.xp): bool_ = jnp.bool_ int32 = jnp.int32 + float32 = jnp.float32 @override def array(self, a, dtype=None) -> types_spec.Array: @@ -22,5 +23,8 @@ def array(self, a, dtype=None) -> types_spec.Array: def zeros(self, shape, dtype=None) -> types_spec.Array: return jnp.zeros(shape, dtype=dtype) + def concatenate(self, arrays, axis=0) -> types_spec.Array: + return jnp.concatenate(arrays, axis=axis) + xp: spec.xp = BackendWrapper() diff --git a/sequence_layers/jax/backend_test.py b/sequence_layers/jax/backend_test.py new file mode 100644 index 0000000..b627230 --- /dev/null +++ b/sequence_layers/jax/backend_test.py @@ -0,0 +1,14 @@ +"""Tests for JAX backend utilities.""" + +from absl.testing import absltest + +from sequence_layers.jax import test_utils +from sequence_layers.specs import backend_behaviors as spec + + +class ModuleSpecTest(test_utils.SequenceLayerTest, spec.ModuleSpecTest): + pass + + +if __name__ == '__main__': + absltest.main() diff --git a/sequence_layers/jax/test_utils.py b/sequence_layers/jax/test_utils.py index 13add8f..274e20b 100644 --- a/sequence_layers/jax/test_utils.py +++ b/sequence_layers/jax/test_utils.py @@ -275,33 +275,7 @@ def zip_longest( zipped argument internally sorted (target, source). If either input sequence was longer, the last element of the shorter input sequence is repeated. """ - - results = [] - prev_source, prev_target = None, None - for source, target in itertools.zip_longest(sources, targets): - # If either runs out ahead-of-time, we repeat the final non-None element. - # (This is safest as we cannot inspect the function's defaults.) - if source is None: - source = prev_source - elif target is None: - target = prev_target - - if isinstance(target, Mapping): - assert isinstance(source, Mapping) - results.append({**target, **source}) - elif isinstance(sources, Iterable): - # target is a non-mapping iterable, like tuple or list. - if isinstance(source, Mapping): - # To match the target, we replace the source with its unlabeled values. - source = source.values() - results.append((*target, *source)) - prev_source, prev_target = source, target - else: - raise NotImplementedError( - f'Targets of type {type(target)=} are unsupported.' - ) - - return results + return spec.zip_longest(targets, sources) def named_product( @@ -326,62 +300,8 @@ def named_product( `{first_item_name}_{second_item_name}`. If both iterators' items are mappings, the product's items are mappings; otherwise they are ordered tuples. - - For example, if `first` is - `[{**foo, 'testcase_name': 'foo'}, {**bar, 'testcase_name': 'bar'}]` and - `second` is `[['baz', *baz], ['qux', *qux]]`, the items will be - `('foo_baz', *foo.values(), *baz), ('foo_qux', *foo.values(), *qux), ...` - - Raises: - ValueError: A testcase_name is missing; either an iterator item is empty, or - one is a mapping without a `testcase_name` key. """ - - results = [] - - for p1, p2 in itertools.product(first, second): - - for source, parameters in enumerate([p1, p2]): - if isinstance(parameters, Mapping): - if 'testcase_name' not in parameters: - raise ValueError( - f'Mapping {parameters} from iterable #{source+1} does not have' - ' key `testcase_name`.' - ) - elif not parameters: - raise ValueError( - f'An sequence from iterable #{source+1} is empty; the first entry' - ' is expected to be a testcase name.' - ) - - # When both are mappings, we merge by key: - if isinstance(p1, Mapping) and isinstance(p2, Mapping): - testcase_name = f'{p1["testcase_name"]}_{p2["testcase_name"]}' - p1 = {k: v for k, v in p1.items() if k != 'testcase_name'} - p2 = {k: v for k, v in p2.items() if k != 'testcase_name'} - results.append({**p1, **p2, 'testcase_name': testcase_name}) - - # Else, we return an ordered tuple based on each parameter set's order: - else: - - if isinstance(p1, Mapping): - p1_name = p1['testcase_name'] - p1 = tuple(v for k, v in p1.items() if k != 'testcase_name') - else: - p1_name = p1[0] - p1 = p1[1:] - - if isinstance(p2, Mapping): - p2_name = p2['testcase_name'] - p2 = tuple(v for k, v in p2.items() if k != 'testcase_name') - else: - p2_name = p2[0] - p2 = p2[1:] - - testcase_name = f'{p1_name}_{p2_name}' - results.append((testcase_name, *p1, *p2)) - - return parameterized.named_parameters(*results) + return spec.named_product(first, second) def get_grad_tols( @@ -786,6 +706,7 @@ class SequenceLayerTest(spec.SequenceLayerTest): sl = sl + @override def setUp(self): super().setUp() # To avoid flakes, fix random seeds. @@ -1204,6 +1125,30 @@ def _pad(x: types.Sequence, pad_back: int) -> types.Sequence: self.assertEqual(receptive_field, expected_receptive_field) return y_layer + @override + def random_sequence( + self, + *dims: int, + dtype=jnp.float32, + random_mask: bool = False, + random_lengths: bool | None = None, + low: int | None = 0, + high: int | None = 10, + low_length: int = 0, + high_length: int | None = None, + ) -> types.Sequence: + return random_sequence( + *dims, + dtype=dtype, + random_mask=random_mask, + random_lengths=random_lengths, + low=low, + high=high, + low_length=low_length, + high_length=high_length, + ) + + @override def assertSequencesClose( # pylint: disable=invalid-name self, a: types.Sequence, @@ -1297,11 +1242,13 @@ class Config(types.SequenceLayerConfig): expected_constant: str = 'test' name: str | None = None + @override def make(self) -> 'AssertConstantsLayer': return AssertConstantsLayer(self, name=self.name) config: Config + @override def get_initial_state( self, batch_size: int, @@ -1316,6 +1263,7 @@ def get_initial_state( batch_size, input_spec, training=training, constants=constants ) + @override def get_output_shape( self, input_shape: types.ShapeLike, @@ -1326,6 +1274,7 @@ def get_output_shape( raise ValueError(f'{self.config.expected_constant=} not present') return super().get_output_shape(input_shape, constants=constants) + @override def layer( self, x: types.Sequence, @@ -1346,15 +1295,18 @@ class NonSteppableLayer(types.PreservesType, types.StatelessPointwise): class Config(types.SequenceLayerConfig): name: str | None = None + @override def make(self) -> 'NonSteppableLayer': return NonSteppableLayer(self, name=self.name) config: Config @property + @override def supports_step(self): return False + @override def layer( self, x: types.Sequence, diff --git a/sequence_layers/jax/test_utils_test.py b/sequence_layers/jax/test_utils_test.py index 5c113e3..8bc6fd8 100644 --- a/sequence_layers/jax/test_utils_test.py +++ b/sequence_layers/jax/test_utils_test.py @@ -13,17 +13,47 @@ # limitations under the License. """Tests for the test utilities.""" -from unittest import mock from absl.testing import parameterized -import numpy as np +import jax +import jax.numpy as jnp +import sequence_layers.jax as sl from sequence_layers.jax import test_utils +from sequence_layers.specs import test_utils_behaviors as spec + + +class ModuleSpecTest(test_utils.SequenceLayerTest, spec.ModuleSpecTest): + pass + + +class VerifyContractTest(test_utils.SequenceLayerTest, spec.VerifyContractTest): + + def get_dummy_layer(self, mismatch: bool): + l = super().get_dummy_layer(mismatch) + key = jax.random.PRNGKey(1234) + x = test_utils.random_sequence(2, 5, 10) + l = self.init_and_bind_layer(key, l, x) + return l + + def test_verify_contract_with_jax_flags(self): + """Tests that disabling optional JAX features (gradients, batching) doesn't crash. + + Default paths (with these flags as True) are tested in all other tests. + """ + layer = self.get_dummy_layer(mismatch=False) + x = sl.Sequence( + jnp.ones((2, 5, 10)), + jnp.ones((2, 5), dtype=bool), + ) + self.verify_contract( + layer, x, training=False, test_gradients=False, test_batching=False + ) class StandardDtypeConfigsTest(test_utils.SequenceLayerTest): @parameterized.parameters( ( - dict(), + {}, { 'p-fp32_i-fp32_c-None', # default 'p-bf16_i-bf16_c-bf16', # praxis @@ -33,7 +63,7 @@ class StandardDtypeConfigsTest(test_utils.SequenceLayerTest): }, ), ( - dict(param=True, compute=True), + {'param': True, 'compute': True}, { 'p-fp32_c-None', # default 'p-bf16_c-bf16', # praxis @@ -43,7 +73,7 @@ class StandardDtypeConfigsTest(test_utils.SequenceLayerTest): }, ), ( - dict(praxis_only=True), + {'praxis_only': True}, { 'p-fp32_i-fp32_c-None', # default 'p-bf16_i-bf16_c-bf16', # praxis @@ -58,113 +88,55 @@ def test_standard_dtype_configs_returns_names(self, kwargs, expected): self.assertEqual(expected, names) -class NamedProductTest(test_utils.SequenceLayerTest): +class NamedProductTest(test_utils.SequenceLayerTest, spec.NamedProductTest): + pass - @parameterized.parameters( - dict( - first=[('a', 'alpha'), ('b', 'beta')], - second=[('1', 1), ('2', 2), ('3', 3)], - expected=[ - ('a_1', 'alpha', 1), - ('a_2', 'alpha', 2), - ('a_3', 'alpha', 3), - ('b_1', 'beta', 1), - ('b_2', 'beta', 2), - ('b_3', 'beta', 3), - ], - ), - dict( - first=[{'a': 'alpha', 'testcase_name': 'test'}], - second=[('1', 1), ('2', 2)], - expected=[ - ('test_1', 'alpha', 1), - ('test_2', 'alpha', 2), - ], - ), - dict( - first=[ - {'letter': 'a', 'testcase_name': 'alpha'}, - {'testcase_name': 'beta', 'letter': 'b'}, - ], - second=[ - {'testcase_name': 'one', 'number': 1}, - {'number': 2, 'testcase_name': 'two'}, - ], - expected=[ - {'letter': 'a', 'number': 1, 'testcase_name': 'alpha_one'}, - {'letter': 'a', 'number': 2, 'testcase_name': 'alpha_two'}, - {'letter': 'b', 'number': 1, 'testcase_name': 'beta_one'}, - {'letter': 'b', 'number': 2, 'testcase_name': 'beta_two'}, - ], - ), - ) - @mock.patch.object(parameterized, 'named_parameters', autospec=True) - def test_builds_named_products(self, mock_fn, first, second, expected): - test_utils.named_product(first, second) - self.assertSequenceEqual(mock_fn.call_args.args, expected) - @parameterized.parameters( - dict( - first=[{'testcase_name': 'alpha', 'letter': 'a'}, {'letter': 'b'}], - second=[('1', 1), ('2', 2), ('3', 3)], - iterator_without_testcase_name=1, - ), - dict( - first=[{'testcase_name': 'alpha', 'letter': 'a'}], - second=[('1', 1), ()], - iterator_without_testcase_name=2, - ), - ) - def test_raises_on_missing_testcase_names( - self, first, second, iterator_without_testcase_name - ): - with self.assertRaisesRegex( - ValueError, str(iterator_without_testcase_name) - ): - test_utils.named_product(first, second) +class ZipLongestTest(test_utils.SequenceLayerTest, spec.ZipLongestTest): + pass class Shear2dTest(test_utils.SequenceLayerTest): @parameterized.named_parameters( - dict( - testcase_name='basic_3x3', - input_array=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], - expected_output=[ + { + 'testcase_name': 'basic_3x3', + 'input_array': [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + 'expected_output': [ [0, 0, 1, 2, 3], [0, 4, 5, 6, 0], [7, 8, 9, 0, 0], ], - ), - dict( - testcase_name='rect_more_rows', - input_array=[[1, 2], [3, 4], [5, 6]], - expected_output=[[0, 0, 1, 2], [0, 3, 4, 0], [5, 6, 0, 0]], - ), - dict( - testcase_name='rect_more_cols', - input_array=[[1, 2, 3, 4], [5, 6, 7, 8]], - expected_output=[[0, 1, 2, 3, 4], [5, 6, 7, 8, 0]], - ), - dict( - testcase_name='single_row', - input_array=[[1, 2, 3]], - expected_output=[[1, 2, 3]], - ), - dict( - testcase_name='single_col', - input_array=[[1], [2], [3]], - expected_output=[[0, 0, 1], [0, 2, 0], [3, 0, 0]], - ), - dict( - testcase_name='with_zeros', - input_array=[[0, 1], [0, 0]], - expected_output=[[0, 0, 1], [0, 0, 0]], - ), + }, + { + 'testcase_name': 'rect_more_rows', + 'input_array': [[1, 2], [3, 4], [5, 6]], + 'expected_output': [[0, 0, 1, 2], [0, 3, 4, 0], [5, 6, 0, 0]], + }, + { + 'testcase_name': 'rect_more_cols', + 'input_array': [[1, 2, 3, 4], [5, 6, 7, 8]], + 'expected_output': [[0, 1, 2, 3, 4], [5, 6, 7, 8, 0]], + }, + { + 'testcase_name': 'single_row', + 'input_array': [[1, 2, 3]], + 'expected_output': [[1, 2, 3]], + }, + { + 'testcase_name': 'single_col', + 'input_array': [[1], [2], [3]], + 'expected_output': [[0, 0, 1], [0, 2, 0], [3, 0, 0]], + }, + { + 'testcase_name': 'with_zeros', + 'input_array': [[0, 1], [0, 0]], + 'expected_output': [[0, 0, 1], [0, 0, 0]], + }, ) def test_shear_2d(self, input_array, expected_output): - output = test_utils._shear_2d(np.array(input_array)) - self.assertAllEqual(output, np.array(expected_output)) + output = test_utils._shear_2d(jnp.array(input_array)) # pylint: disable=protected-access + self.assertAllEqual(output, jnp.array(expected_output)) if __name__ == '__main__': diff --git a/sequence_layers/jax/types_test.py b/sequence_layers/jax/types_test.py index 7b7b5cf..89df70e 100644 --- a/sequence_layers/jax/types_test.py +++ b/sequence_layers/jax/types_test.py @@ -14,7 +14,7 @@ """Types test.""" import dataclasses -import typing +from typing import Sequence import chex import flax.linen as nn @@ -30,9 +30,7 @@ from sequence_layers.specs import types_behaviors as spec -class ModuleInterfaceTest( - test_utils.SequenceLayerTest, spec.ModuleInterfaceTest -): +class ModuleSpecTest(test_utils.SequenceLayerTest, spec.ModuleSpecTest): pass @@ -197,16 +195,13 @@ def fn(x: types.Sequence) -> types.Sequence: self.assertSequencesEqual(y, x) -class SequenceLayerConfigTest( - test_utils.SequenceLayerTest, spec.SequenceLayerConfigTest -): - pass +class SequenceLayerConfigTest(test_utils.SequenceLayerTest, spec.SequenceLayerConfigTest): def test_copy_raises_on_mutable_attribute(self): @dataclasses.dataclass(slots=True) class ConfigWithSequence(types.SequenceLayerConfig): - seq: typing.Sequence[int] + seq: Sequence[int] def make(self) -> simple.Identity: return simple.Identity.Config().make() @@ -261,15 +256,11 @@ class EmittingTest(test_utils.SequenceLayerTest, spec.EmittingTest): pass -class StatelessEmittingTest( - test_utils.SequenceLayerTest, spec.StatelessEmittingTest -): +class StatelessEmittingTest(test_utils.SequenceLayerTest, spec.StatelessEmittingTest): pass -class StatelessPointwiseFunctorTest( - test_utils.SequenceLayerTest, spec.StatelessPointwiseFunctorTest -): +class StatelessPointwiseFunctorTest(test_utils.SequenceLayerTest, spec.StatelessPointwiseFunctorTest): pass diff --git a/sequence_layers/mlx/__init__.py b/sequence_layers/mlx/__init__.py index 95f38af..4c861f5 100644 --- a/sequence_layers/mlx/__init__.py +++ b/sequence_layers/mlx/__init__.py @@ -13,7 +13,10 @@ # limitations under the License. """Sequence layers in MLX.""" -from sequence_layers.mlx.types import * +# (re-export the names for typechecking) +from . import backend as backend +from . import types as types +from . import test_utils as test_utils +from .test_utils import SequenceLayerTest -from . import backend -from . import types +from sequence_layers.mlx.types import * diff --git a/sequence_layers/mlx/backend.py b/sequence_layers/mlx/backend.py index f73847a..4dd649c 100644 --- a/sequence_layers/mlx/backend.py +++ b/sequence_layers/mlx/backend.py @@ -13,6 +13,7 @@ class BackendWrapper(spec.xp): bool_ = mx.bool_ int32 = mx.int32 + float32 = mx.float32 @override def array(self, a, dtype=None) -> types_spec.Array: @@ -22,5 +23,9 @@ def array(self, a, dtype=None) -> types_spec.Array: def zeros(self, shape, dtype=None) -> types_spec.Array: return mx.zeros(shape, dtype=dtype) + @override + def concatenate(self, arrays, axis=0) -> types_spec.Array: + return mx.concatenate(arrays, axis=axis) + xp: spec.xp = BackendWrapper() diff --git a/sequence_layers/mlx/backend_test.py b/sequence_layers/mlx/backend_test.py new file mode 100644 index 0000000..4c8ab5f --- /dev/null +++ b/sequence_layers/mlx/backend_test.py @@ -0,0 +1,14 @@ +"""Tests for MLX backend utilities.""" + +from absl.testing import absltest + +from sequence_layers.mlx import test_utils +from sequence_layers.specs import backend_behaviors as spec + + +class ModuleSpecTest(test_utils.SequenceLayerTest, spec.ModuleSpecTest): + pass + + +if __name__ == '__main__': + absltest.main() diff --git a/sequence_layers/mlx/test_utils.py b/sequence_layers/mlx/test_utils.py index 9953b74..3af90ca 100644 --- a/sequence_layers/mlx/test_utils.py +++ b/sequence_layers/mlx/test_utils.py @@ -1,14 +1,70 @@ """Test utilities for MLX sequence layers.""" -from typing import override +from typing import Any, Callable, Iterable, Mapping, override +from typing import Sequence as TypingSequence +from typing import TypeVar +from absl.testing import absltest import mlx.core as mx import numpy as np +from sequence_layers import specs from sequence_layers.mlx import types import sequence_layers.mlx as sl from sequence_layers.specs import test_utils as spec +Sequence = types.Sequence +MaskedSequence = types.MaskedSequence +ShapeDType = types.ShapeDType + +_T = TypeVar('_T') +_TestFnT = Callable[..., None] + + +def zip_longest( + targets: Iterable[Iterable[Any]], + sources: Iterable[_T], +) -> list[_T]: + """Applies zip_longest, specialized to @parameterized's argument format. + + Args: + targets: Iterable of parameterized test arguments. + sources: Iterable of parameterized test arguments. If `targets` is a mapping + `sources` must be a mapping as well. + + Returns: + A list of the zipped arguments, of the type of `targets` and with each + zipped argument internally sorted (target, source). If either input sequence + was longer, the last element of the shorter input sequence is repeated. + """ + return spec.zip_longest(targets, sources) + + +def named_product( + first: Iterable[TypingSequence[Any] | Mapping[str, Any]], + second: Iterable[TypingSequence[Any] | Mapping[str, Any]], +) -> Callable[[_TestFnT], _TestFnT]: + """Builds named parameters from the product of iterators of named parameters. + + As in parameterized.named_parameters, if an iterator's items are sequences, + the first element is interpreted as the name. If an iterator's items are + mappings, the `testcase_name` key is used. + + Args: + first: Iterable of named parameters, whose names will be the first part of + the named product's test names. + second: Iterable of named parameters, whose names will be the second part of + the named product's test names. + + Returns: + A decorator that calls the test function with the cartesian product of the + given iterators, whose items are named parameters with names of the form + `{first_item_name}_{second_item_name}`. If both iterators' items are + mappings, the product's items are mappings; otherwise they are ordered + tuples. + """ + return spec.named_product(first, second) + def _mask_and_pad_to_max_length( a: types.Sequence, b: types.Sequence @@ -30,6 +86,41 @@ class SequenceLayerTest(spec.SequenceLayerTest): sl = sl # pyrefly: ignore[bad-assignment] # module-as-protocol + @override + def setUp(self): + super().setUp() + # To avoid flakes, fix random seeds. + # MLX doesn't have a global seed, but we can set numpy seed. + np.random.seed(123456789) + + @override + def random_sequence( + self, + *dims: int, + dtype=None, + random_mask: bool = False, + random_lengths: bool | None = None, + low: int | None = 0, + high: int | None = 10, + low_length: int = 0, + high_length: int | None = None, + ) -> types.Sequence: + if len(dims) < 2: + raise ValueError('dims must be at least (batch, time)') + batch_size = dims[0] + time = dims[1] + shape = dims[2:] + + values_np = np.random.normal(size=(batch_size, time) + shape).astype( + np.float32 + ) + values = mx.array(values_np, dtype=dtype or mx.float32) + + mask_np = np.ones((batch_size, time), dtype=bool) + mask = mx.array(mask_np, dtype=mx.bool_) + + return types.Sequence(values, mask) + @override def assertAllEqual(self, x, y): """Asserts that two arrays are equal.""" @@ -45,3 +136,121 @@ def assertSequencesEqual( # pyrefly: ignore[bad-override] x, y = _mask_and_pad_to_max_length(x, y) self.assertAllEqual(x.values, y.values) self.assertAllEqual(x.mask, y.mask) + + @override + # pyrefly: ignore[bad-override] + def _step_by_step( + self, + layer: types.SequenceLayer, + x: types.Sequence, + *, + block_size: int = 1, + constants=None, + stream_constants: bool = False, + stream_constants_list: list[Any] | None = None, + ) -> tuple[types.Sequence, Any]: + batch = x.values.shape[0] if hasattr(x, 'values') else x.shape[0] + time = x.values.shape[1] if hasattr(x, 'values') else x.shape[1] + + input_spec = types.ShapeDType(x.channel_shape, x.dtype) + + init_constants = dict(constants) if constants else {} + + state = layer.get_initial_state( + batch, input_spec, constants=init_constants or None, training=False + ) + + outputs_values = [] + outputs_masks = [] + + for t in range(0, time, block_size): + x_block = sl.Sequence( + x.values[:, t : t + block_size], + x.mask[:, t : t + block_size], + ) + + step_constants = dict(constants) if constants else {} + if stream_constants and stream_constants_list: + step_idx = t // block_size + if step_idx < len(stream_constants_list): + step_constants.update(stream_constants_list[step_idx]) + + y_block, state = layer.step( + x_block, state, constants=step_constants or None, training=False + ) + outputs_values.append(y_block.values) + outputs_masks.append(y_block.mask) + + y_values = mx.concatenate(outputs_values, axis=1) + y_mask = mx.concatenate(outputs_masks, axis=1) + + return sl.Sequence(y_values, y_mask), state + + @override + # pyrefly: ignore[bad-override] + def verify_contract( + self, + l: types.SequenceLayer, + x: types.Sequence, + *, + training: bool = False, + constants=None, + stream_constants: bool = False, + stream_constants_list: list[Any] | None = None, + atol: float = 1e-5, + rtol: float = 1e-5, + **kwargs, + ) -> types.Sequence: + if hasattr(x, 'channel_shape'): + input_shape = x.channel_shape + elif hasattr(x, 'shape'): + input_shape = x.shape[2:] + else: + raise ValueError(f'Cannot determine input shape from {x}') + dtype = x.dtype if hasattr(x, 'dtype') else self.xp.float32 + + y_layer = l.layer(x, training=training, constants=constants) + + expected_shape = l.get_output_shape(input_shape, constants=constants) + self.assertEqual(y_layer.channel_shape, expected_shape) + + expected_dtype = l.get_output_dtype(dtype, constants=constants) + self.assertEqual(y_layer.dtype, expected_dtype) + + if not l.supports_step: + return y_layer + + block_size = l.block_size + y_step, _ = self._step_by_step( + l, + x, + block_size=block_size, + constants=constants, + stream_constants=stream_constants, + stream_constants_list=stream_constants_list, + ) + + self.assertEqual(y_step.shape, y_layer.shape) + self.assertSequencesClose(y_layer, y_step, atol=atol, rtol=rtol) + + return y_layer + + @override + def assertSequencesClose(self, x: Any, y: Any, **kwargs) -> None: + x_np = np.array(x.values) if hasattr(x, 'values') else np.array(x) + y_np = np.array(y.values) if hasattr(y, 'values') else np.array(y) + np.testing.assert_allclose(x_np, y_np, **kwargs) + if hasattr(x, 'mask') and hasattr(y, 'mask'): + mask_x = np.array(x.mask) + mask_y = np.array(y.mask) + np.testing.assert_array_equal(mask_x, mask_y) + + +class ModuleSpecTest(SequenceLayerTest, spec.ModuleSpecTest): + + @override + def module_spec_pairs(self, backend_sl: specs.ModuleSpec): + return {backend_sl.test_utils: spec.ModuleSpec} + + +main = absltest.main diff --git a/sequence_layers/mlx/test_utils_mlx.py b/sequence_layers/mlx/test_utils_mlx.py deleted file mode 100644 index 5f8dd09..0000000 --- a/sequence_layers/mlx/test_utils_mlx.py +++ /dev/null @@ -1,174 +0,0 @@ -"""Test utilities for MLX sequence layers (legacy from branch).""" - -import mlx.core as mx -import numpy as np - -from sequence_layers.mlx import basic_types as bt - -Sequence = bt.Sequence -MaskedSequence = bt.MaskedSequence -ShapeDType = bt.ShapeDType - - -def random_sequence( - batch: int, - time: int, - channels: int | tuple[int, ...], - *, - dtype=mx.float32, - mask: mx.array | None = None, - masked: bool = True, -) -> Sequence: - """Create a random Sequence for testing. - - Args: - batch: Batch size. - time: Sequence length. - channels: Channel size (int) or channel shape (tuple). - dtype: Values dtype. - mask: Optional explicit mask. If None, all-valid mask is used. - masked: If True, returns a MaskedSequence. If False, a Sequence. - - Returns: - A random Sequence or MaskedSequence. - """ - if isinstance(channels, int): - channels = (channels,) - shape = (batch, time) + channels - values = mx.random.normal(shape=shape).astype(dtype) - if mask is None: - mask = mx.ones((batch, time), dtype=mx.bool_) - if masked: - return MaskedSequence(values, mask) - return Sequence(values, mask) - - -def step_by_step( - layer, - x: Sequence, - *, - block_size: int = 1, - constants=None, - stream_constants=None, -) -> tuple[Sequence, object]: - """Run a layer step-by-step and concatenate outputs. - - Args: - layer: A SequenceLayer with supports_step. - x: Input sequence [batch, time, ...]. - block_size: Number of timesteps per step. - constants: Optional constants dict (static, passed as-is each step). - stream_constants: Optional dict of source_name -> Sequence. These are - sliced at the same block_size as input for each step, merging into - the constants dict. Use this for streaming cross-attention sources. - - Returns: - (output_sequence, final_state) - """ - batch = x.shape[0] - time = x.shape[1] - spec = x.channel_spec - - # Build initial constants with full stream sources for get_initial_state. - init_constants = dict(constants) if constants else {} - if stream_constants: - init_constants.update(stream_constants) - - state = layer.get_initial_state(batch, spec, constants=init_constants or None) - - outputs_values = [] - outputs_masks = [] - - for t in range(0, time, block_size): - x_block = Sequence( - x.values[:, t : t + block_size], - x.mask[:, t : t + block_size], - ) - - # Build per-step constants with sliced stream sources. - step_constants = dict(constants) if constants else {} - if stream_constants: - for name, seq in stream_constants.items(): - step_constants[name] = Sequence( - seq.values[:, t : t + block_size], - seq.mask[:, t : t + block_size], - ) - - y_block, state = layer.step( - x_block, - state, - constants=step_constants or None, - ) - outputs_values.append(y_block.values) - outputs_masks.append(y_block.mask) - - y_values = mx.concatenate(outputs_values, axis=1) - y_mask = mx.concatenate(outputs_masks, axis=1) - return Sequence(y_values, y_mask), state - - -def verify_contract( - test_case, - layer, - input_shape, - *, - batch_size: int = 2, - time: int = 8, - dtype=mx.float32, - constants=None, - atol: float = 1e-5, - rtol: float = 1e-5, - test_step: bool = True, -): - """Verify that a layer's layer() and step() outputs are consistent. - - Checks: - 1. layer() runs without error and produces correct output shape. - 2. step() runs without error and produces correct output shape. - 3. layer() and step() produce approximately equal outputs. - - Args: - test_case: An absltest.TestCase (or similar) with assertion methods. - layer: The SequenceLayer to test. - input_shape: Channel shape (tuple), e.g. (16,). - batch_size: Batch size for test inputs. - time: Sequence length for test inputs. - dtype: Input dtype. - constants: Optional constants dict. - atol: Absolute tolerance for output comparison. - rtol: Relative tolerance for output comparison. - test_step: Whether to test step() and compare with layer(). - """ - x = random_sequence(batch_size, time, input_shape, dtype=dtype) - - # Test layer(). - y_layer = layer.layer(x, constants=constants) - - # Check output shape. - expected_shape = layer.get_output_shape(input_shape, constants=constants) - test_case.assertEqual(y_layer.channel_shape, expected_shape) - - # Check output dtype. - expected_dtype = layer.get_output_dtype(dtype, constants=constants) - test_case.assertEqual(y_layer.dtype, expected_dtype) - - if not test_step or not layer.supports_step: - return - - # Test step(). - block_size = layer.block_size - y_step, _ = step_by_step(layer, x, block_size=block_size, constants=constants) - - # Check shapes match. - test_case.assertEqual(y_step.shape, y_layer.shape) - - # Check values match. - y_layer_np = np.array(y_layer.values) - y_step_np = np.array(y_step.values) - np.testing.assert_allclose( - y_step_np, - y_layer_np, - atol=atol, - rtol=rtol, - err_msg=f'{layer.__class__.__name__}: step() and layer() outputs differ', - ) diff --git a/sequence_layers/mlx/test_utils_test.py b/sequence_layers/mlx/test_utils_test.py new file mode 100644 index 0000000..733bf3c --- /dev/null +++ b/sequence_layers/mlx/test_utils_test.py @@ -0,0 +1,24 @@ +"""Tests for the test utilities.""" + +from sequence_layers.mlx import test_utils +from sequence_layers.specs import test_utils_behaviors as spec + + +class ModuleSpecTest(test_utils.SequenceLayerTest, spec.ModuleSpecTest): + pass + + +class NamedProductTest(test_utils.SequenceLayerTest, spec.NamedProductTest): + pass + + +class ZipLongestTest(test_utils.SequenceLayerTest, spec.ZipLongestTest): + pass + + +class VerifyContractTest(test_utils.SequenceLayerTest, spec.VerifyContractTest): + pass + + +if __name__ == '__main__': + test_utils.main() diff --git a/sequence_layers/mlx/types.py b/sequence_layers/mlx/types.py index 80121d9..2c82766 100644 --- a/sequence_layers/mlx/types.py +++ b/sequence_layers/mlx/types.py @@ -6,16 +6,8 @@ import functools import math import types -from typing import ( - Any, - Callable, - cast, - Iterable, - MutableMapping, - override, - Self, - TypeVar, -) +from typing import (Any, Callable, cast, Iterable, MutableMapping, override, + Self, TypeVar) import jaxtyping as jt from mlx import nn diff --git a/sequence_layers/mlx/types_test.py b/sequence_layers/mlx/types_test.py index da6163c..1b7aadd 100644 --- a/sequence_layers/mlx/types_test.py +++ b/sequence_layers/mlx/types_test.py @@ -6,9 +6,7 @@ from sequence_layers.specs import types_behaviors as spec -class ModuleInterfaceTest( - test_utils.SequenceLayerTest, spec.ModuleInterfaceTest -): +class ModuleSpecTest(test_utils.SequenceLayerTest, spec.ModuleSpecTest): pass diff --git a/sequence_layers/specs/__init__.py b/sequence_layers/specs/__init__.py index 0060178..d523988 100644 --- a/sequence_layers/specs/__init__.py +++ b/sequence_layers/specs/__init__.py @@ -1,10 +1,15 @@ # https://typing.python.org/en/latest/spec/protocol.html#modules-as-implementations-of-protocols -from typing import Protocol, runtime_checkable +from typing import Protocol, runtime_checkable, TYPE_CHECKING from . import backend as _backend from . import types as _types +# Import test_utils only for type checking to avoid circular imports, +# as test_utils.py imports specs.ModuleSpec defined below. +if TYPE_CHECKING: + from . import test_utils as _test_utils + @runtime_checkable class ModuleSpec(Protocol): @@ -18,6 +23,10 @@ def backend(self) -> _backend.ModuleSpec: def types(self) -> _types.ModuleSpec: ... + @property + def test_utils(self) -> '_test_utils.ModuleSpec': + ... + # Identifiers that backend-specific implementations should expose at top level. # Demonstrating read-only allows for covariance (subclasses of types_module.Sequence to satisfy the protocol). @@ -36,3 +45,7 @@ def SequenceLayer(self) -> type[_types.SequenceLayer]: @property def SequenceLayerConfig(self) -> type[_types.SequenceLayerConfig]: ... + + @property + def SequenceLayerTest(self) -> type: + ... diff --git a/sequence_layers/specs/backend.py b/sequence_layers/specs/backend.py index fb64595..d65edfe 100644 --- a/sequence_layers/specs/backend.py +++ b/sequence_layers/specs/backend.py @@ -17,6 +17,7 @@ class xp(Protocol): bool_: Any int32: Any + float32: Any def array(self, a: Any, dtype: Any = None) -> Array: """Creates an array.""" @@ -24,6 +25,9 @@ def array(self, a: Any, dtype: Any = None) -> Array: def zeros(self, shape: tuple[int, ...], dtype: Any = None) -> Array: """Creates an array of zeros.""" + def concatenate(self, arrays: list[Array], axis: int = 0) -> Array: + ... + @runtime_checkable class ModuleSpec(Protocol): diff --git a/sequence_layers/specs/backend_behaviors.py b/sequence_layers/specs/backend_behaviors.py new file mode 100644 index 0000000..d72826c --- /dev/null +++ b/sequence_layers/specs/backend_behaviors.py @@ -0,0 +1,16 @@ +"""Abstract tests for backend utilities.""" + +# pylint: disable=abstract-method + +from typing import override + +from sequence_layers import specs +from sequence_layers.specs import backend as backend_spec +from sequence_layers.specs import test_utils as test_utils_spec + + +class ModuleSpecTest(test_utils_spec.ModuleSpecTest): + + @override + def module_spec_pairs(self, backend_sl: specs.ModuleSpec): + return {backend_sl.backend: backend_spec.ModuleSpec} diff --git a/sequence_layers/specs/test_utils.py b/sequence_layers/specs/test_utils.py index 26b1d8d..b20e529 100644 --- a/sequence_layers/specs/test_utils.py +++ b/sequence_layers/specs/test_utils.py @@ -1,22 +1,137 @@ -"""Test utilities for sequence layers.""" +"""Utilities for testing sequence layers.""" import abc -from typing import Any +import itertools +from typing import Any, Callable, Iterable, Mapping, Protocol, runtime_checkable +from typing import Sequence as TypingSequence +from typing import TypeVar from absl.testing import parameterized +import typeguard from sequence_layers import specs from sequence_layers.specs import backend as backend_spec from sequence_layers.specs import types as types_spec +_T = TypeVar('_T') class _AbcParameterizedTestCaseMeta(abc.ABCMeta, type(parameterized.TestCase)): """Metaclass for abstract parameterized test cases.""" -class SequenceLayerTest[SequenceT: types_spec.Sequence = types_spec.Sequence]( - parameterized.TestCase, metaclass=_AbcParameterizedTestCaseMeta -): + +def zip_longest( + targets: Iterable[Iterable[Any]], + sources: Iterable[_T], +) -> list[_T]: + """Applies zip_longest, specialized to @parameterized's argument format. + + Args: + targets: Iterable of parameterized test arguments. + sources: Iterable of parameterized test arguments. If `targets` is a mapping + `sources` must be a mapping as well. + + Returns: + A list of the zipped arguments, of the type of `targets` and with each + zipped argument internally sorted (target, source). If either input sequence + was longer, the last element of the shorter input sequence is repeated. + """ + results: list[Any] = [] + prev_source, prev_target = None, None + for source, target in itertools.zip_longest(sources, targets): + if source is None: + source = prev_source + if target is None: + target = prev_target + + if isinstance(target, Mapping): + assert isinstance(source, Mapping) + results.append({**target, **source}) + elif isinstance(target, Iterable) and not isinstance(target, (str, bytes)): + if isinstance(source, Mapping): + raise ValueError('Cannot zip mapping source with non-mapping target') + assert isinstance(source, Iterable) + results.append(tuple(target) + tuple(source)) + else: + results.append((target, source)) + + prev_source, prev_target = source, target + + return results + + +_TestFnT = Callable[..., None] + + +def named_product( + first: Iterable[TypingSequence[Any] | Mapping[str, Any]], + second: Iterable[TypingSequence[Any] | Mapping[str, Any]], +) -> Callable[[_TestFnT], _TestFnT]: + """Builds named parameters from the product of iterators of named parameters. + + As in parameterized.named_parameters, if an iterator's items are sequences, + the first element is interpreted as the name. If an iterator's items are + mappings, the `testcase_name` key is used. + + Args: + first: Iterable of named parameters, whose names will be the first part of + the named product's test names. + second: Iterable of named parameters, whose names will be the second part of + the named product's test names. + + Returns: + A decorator that calls the test function with the cartesian product of the + given iterators, whose items are named parameters with names of the form + `{first_item_name}_{second_item_name}`. If both iterators' items are + mappings, the product's items are mappings; otherwise they are ordered + tuples. + """ + results: list[Any] = [] + + for p1, p2 in itertools.product(first, second): + for source, parameters in enumerate([p1, p2]): + if isinstance(parameters, Mapping): + if 'testcase_name' not in parameters: + raise ValueError( + f'Mapping {parameters} from iterable #{source+1} does not have' + ' key `testcase_name`.' + ) + elif not parameters: + raise ValueError( + f'An sequence from iterable #{source+1} is empty; the first entry' + ' is expected to be a testcase name.' + ) + + if isinstance(p1, Mapping) and isinstance(p2, Mapping): + testcase_name = f'{p1["testcase_name"]}_{p2["testcase_name"]}' + p1 = {k: v for k, v in p1.items() if k != 'testcase_name'} + p2 = {k: v for k, v in p2.items() if k != 'testcase_name'} + results.append({**p1, **p2, 'testcase_name': testcase_name}) + else: + if isinstance(p1, Mapping): + p1_name = p1['testcase_name'] + p1 = tuple(v for k, v in p1.items() if k != 'testcase_name') + else: + p1_name = p1[0] + p1 = p1[1:] + + if isinstance(p2, Mapping): + p2_name = p2['testcase_name'] + p2 = tuple(v for k, v in p2.items() if k != 'testcase_name') + else: + p2_name = p2[0] + p2 = p2[1:] + + testcase_name = f'{p1_name}_{p2_name}' + results.append((testcase_name, *p1, *p2)) + + return parameterized.named_parameters(*results) + + +class SequenceLayerTest[ + SequenceT: types_spec.Sequence = types_spec.Sequence, + SequenceLayerT: types_spec.SequenceLayer = types_spec.SequenceLayer, +](parameterized.TestCase, metaclass=_AbcParameterizedTestCaseMeta): """Base test class providing common sequence testing assertions. Binds a backend implementation to tests. @@ -27,13 +142,108 @@ class SequenceLayerTest[SequenceT: types_spec.Sequence = types_spec.Sequence]( @property def xp(self) -> backend_spec.xp: - """Returns the backend module.""" + """Returns the backend wrapper.""" return self.sl.backend.xp @abc.abstractmethod def assertSequencesEqual(self, x: SequenceT, y: SequenceT) -> None: # pylint: disable=invalid-name - """After padding, checks sequence values are equal and masks are equal.""" + """Asserts that two sequences are equal.""" @abc.abstractmethod def assertAllEqual(self, x: Any, y: Any) -> None: # pylint: disable=invalid-name - """Asserts that two arrays are equal.""" + """Asserts that all elements are equal.""" + + @abc.abstractmethod + def random_sequence( + self, + *dims: int, + dtype=None, + random_mask: bool = False, + random_lengths: bool | None = None, + low: int | None = 0, + high: int | None = 10, + low_length: int = 0, + high_length: int | None = None, + ) -> SequenceT: + """Generates a random sequence.""" + + @abc.abstractmethod + def _step_by_step( + self, + layer: types_spec.SequenceLayer, + x: types_spec.Sequence, + *, + block_size: int = 1, + constants=None, + stream_constants=None, + ) -> tuple[types_spec.Sequence, Any]: + """Runs a layer step by step.""" + + @abc.abstractmethod + def verify_contract( + self, + l: SequenceLayerT, + x: SequenceT, + *, + training: bool = False, + constants=None, + stream_constants: bool = False, + stream_constants_list: list[Any] | None = None, + atol: float = 1e-5, + rtol: float = 1e-5, + **kwargs, + ) -> SequenceT: + """Verifies that a layer satisfies the contract.""" + + @abc.abstractmethod + def assertSequencesClose(self, x: Any, y: Any, **kwargs) -> None: # pylint: disable=invalid-name + """Asserts that two sequences are close.""" + + +class ModuleSpecTest(SequenceLayerTest): + """Test that a backend-specific module implements the ModuleSpec protocol.""" + + @abc.abstractmethod + def module_spec_pairs(self, backend_sl: specs.ModuleSpec) -> dict[Any, Any]: + """Returns a mapping of module to protocol to be verified.""" + + def test_backend_specific_module_has_interface(self) -> None: + pairs = self.module_spec_pairs(self.sl) + for mod, protocol in pairs.items(): + self.assertIsInstance(mod, protocol) + + def test_module_spec_with_typeguard(self) -> None: + pairs = self.module_spec_pairs(self.sl) + for mod, protocol in pairs.items(): + typeguard.check_type('backend_module', mod, protocol) + + +@runtime_checkable +class ModuleSpec(Protocol): + """Specification for sequence_layers..test_utils""" + + def zip_longest( + self, + targets: Iterable[Iterable[Any]], + sources: Iterable[Any], + ) -> list[Any]: + """Zips targets and sources.""" + + def named_product( + self, + first: Iterable[Any], + second: Iterable[Any], + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Creates a named product.""" + + @property + def SequenceLayerTest(self) -> type: # pylint: disable=invalid-name + ... + + +__all__ = [ + name + for name, attr in ModuleSpec.__dict__.items() + if isinstance(attr, property) + or (callable(attr) and not name.startswith('__')) +] diff --git a/sequence_layers/specs/test_utils_behaviors.py b/sequence_layers/specs/test_utils_behaviors.py new file mode 100644 index 0000000..a80b161 --- /dev/null +++ b/sequence_layers/specs/test_utils_behaviors.py @@ -0,0 +1,324 @@ +"""Abstract tests for test utilities.""" + +# pylint: disable=abstract-method + +import fractions +from typing import Any, override +from unittest import mock + +from absl.testing import parameterized +import numpy as np + +from sequence_layers import specs +from sequence_layers.specs import test_utils as test_utils_spec +from sequence_layers.specs import types as types_spec + + +class ModuleSpecTest(test_utils_spec.ModuleSpecTest): + + @override + def module_spec_pairs(self, backend_sl: specs.ModuleSpec): + return {backend_sl.test_utils: test_utils_spec.ModuleSpec} + + +class NamedProductTest(test_utils_spec.SequenceLayerTest): + + @parameterized.parameters( + { + 'first': [('a', 'alpha'), ('b', 'beta')], + 'second': [('1', 1), ('2', 2), ('3', 3)], + 'expected': [ + ('a_1', 'alpha', 1), + ('a_2', 'alpha', 2), + ('a_3', 'alpha', 3), + ('b_1', 'beta', 1), + ('b_2', 'beta', 2), + ('b_3', 'beta', 3), + ], + }, + { + 'first': [{'a': 'alpha', 'testcase_name': 'test'}], + 'second': [('1', 1), ('2', 2)], + 'expected': [ + ('test_1', 'alpha', 1), + ('test_2', 'alpha', 2), + ], + }, + { + 'first': [ + {'letter': 'a', 'testcase_name': 'alpha'}, + {'testcase_name': 'beta', 'letter': 'b'}, + ], + 'second': [ + {'testcase_name': 'one', 'number': 1}, + {'number': 2, 'testcase_name': 'two'}, + ], + 'expected': [ + {'letter': 'a', 'number': 1, 'testcase_name': 'alpha_one'}, + {'letter': 'a', 'number': 2, 'testcase_name': 'alpha_two'}, + {'letter': 'b', 'number': 1, 'testcase_name': 'beta_one'}, + {'letter': 'b', 'number': 2, 'testcase_name': 'beta_two'}, + ], + }, + ) + @mock.patch.object(parameterized, 'named_parameters', autospec=True) + def test_builds_named_products(self, mock_fn, first, second, expected): + self.sl.test_utils.named_product(first, second) + self.assertSequenceEqual(mock_fn.call_args.args, expected) + + @parameterized.parameters( + { + 'first': [{'testcase_name': 'alpha', 'letter': 'a'}, {'letter': 'b'}], + 'second': [('1', 1), ('2', 2), ('3', 3)], + 'iterator_without_testcase_name': 1, + }, + { + 'first': [{'testcase_name': 'alpha', 'letter': 'a'}], + 'second': [('1', 1), ()], + 'iterator_without_testcase_name': 2, + }, + ) + def test_raises_on_missing_testcase_names( + self, first, second, iterator_without_testcase_name + ): + with self.assertRaisesRegex( + ValueError, str(iterator_without_testcase_name) + ): + self.sl.test_utils.named_product(first, second) + + +class ZipLongestTest(test_utils_spec.SequenceLayerTest): + + @parameterized.parameters( + { + 'targets': [('a',), ('b',)], + 'sources': [(1,), (2,)], + 'expected': [('a', 1), ('b', 2)], + }, + { + 'targets': [('a',), ('b',)], + 'sources': [(1,)], + 'expected': [('a', 1), ('b', 1)], + }, + { + 'targets': [('a',)], + 'sources': [(1,), (2,)], + 'expected': [('a', 1), ('a', 2)], + }, + { + 'targets': [{'testcase_name': 'a'}], + 'sources': [{'val': 1}], + 'expected': [{'testcase_name': 'a', 'val': 1}], + }, + ) + def test_zip_longest(self, targets, sources, expected): + results = self.sl.test_utils.zip_longest(targets, sources) + self.assertEqual(results, expected) + + +class GenericDummyLayer(types_spec.SequenceLayer): + """Generic dummy layer for testing verify_contract.""" + + @override + def layer( + self, + x: types_spec.Sequence, + *, + training: bool, + constants: types_spec.Constants | None = None, + ) -> types_spec.Sequence: + return x + + @override + def step( + self, + x: types_spec.Sequence, + state: types_spec.State, + *, + training: bool, + constants: types_spec.Constants | None = None, + ) -> tuple[types_spec.Sequence, types_spec.State]: + return x, state + + @override + def step_with_emits( + self, + x: types_spec.Sequence, + state: types_spec.State, + *, + training: bool, + constants: types_spec.Constants | None = None, + ) -> tuple[types_spec.Sequence, types_spec.State, types_spec.Emits]: + y, state = self.step(x, state, constants=constants, training=training) + return y, state, () + + @override + def get_initial_state( + self, + batch_size: int, + input_spec: types_spec.ChannelSpec, + *, + training: bool, + constants: types_spec.Constants | None = None, + ) -> types_spec.State: + return None + + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) + + @property + @override + def block_size(self) -> int: + return 1 + + @property + @override + def output_ratio(self) -> fractions.Fraction: + return fractions.Fraction(1) + + @property + @override + def input_latency(self) -> int: + return 0 + + @property + @override + def output_latency(self) -> int: + return 0 + + @property + @override + def supports_step(self) -> bool: + return True + + @override + def get_accumulated_input_latency(self, input_latency: int) -> int: + return input_latency + + @override + def get_accumulated_output_latency(self, output_latency: int) -> int: + return output_latency + + @override + def layer_with_emits( + self, + x: types_spec.Sequence, + *, + training: bool, + constants: types_spec.Constants | None = None, + ) -> tuple[types_spec.Sequence, types_spec.Emits]: + return self.layer(x, training=training, constants=constants), () + + @override + def get_output_shape( + self, + input_shape: types_spec.ShapeLike, + *, + constants: types_spec.Constants | None = None, + ) -> types_spec.Shape: + return tuple(input_shape) + + @override + def get_output_dtype( + self, + input_dtype: types_spec.DType, + *, + constants: types_spec.Constants | None = None, + ) -> types_spec.DType: + return input_dtype + + @override + def get_output_spec( + self, + input_spec: Any, + *, + constants: types_spec.Constants | None = None, + ) -> Any: + shape = self.get_output_shape(input_spec.shape, constants=constants) + dtype = self.get_output_dtype(input_spec.dtype, constants=constants) + + class Spec: + """Dummy spec class.""" + + def __init__(self, s, d): + self.shape = s + self.dtype = d + + return Spec(shape, dtype) + + +class GenericMismatchedDummyLayer(GenericDummyLayer): + """Dummy layer that induces a mismatch by returning zeros in layer().""" + + @override + def layer( + self, + x: types_spec.Sequence, + *, + training: bool, + constants: types_spec.Constants | None = None, + ) -> types_spec.Sequence: + return x.apply_values(lambda v: v * 0.0) + + +class VerifyContractTest(test_utils_spec.SequenceLayerTest): + """Abstract tests for verify_contract.""" + + def get_dummy_layer(self, mismatch: bool) -> Any: + """Returns a dummy layer for testing.""" + backend_sl = self.sl + + if mismatch: + + class BackendMismatchedDummyLayer( + GenericMismatchedDummyLayer, backend_sl.types.SequenceLayer + ): + """Mismatched dummy layer for backend.""" + + return BackendMismatchedDummyLayer() + + class BackendDummyLayer(GenericDummyLayer, backend_sl.types.SequenceLayer): + """Dummy layer for backend.""" + + return BackendDummyLayer() + + def test_verify_contract_catches_step_mismatch(self): + layer = self.get_dummy_layer(mismatch=True) + + x = self.sl.Sequence( + self.xp.array(np.ones((2, 5, 10))), + self.xp.array(np.ones((2, 5), dtype=bool)), + ) + + with self.assertRaises(AssertionError): + self.verify_contract(layer, x, training=False) + + def test_verify_contract_succeeds_when_equivalent(self): + layer = self.get_dummy_layer(mismatch=False) + + x = self.sl.Sequence( + self.xp.array(np.ones((2, 5, 10))), + self.xp.array(np.ones((2, 5), dtype=bool)), + ) + + self.verify_contract(layer, x, training=False) + + def test_verify_contract_handles_stream_constants(self): + layer = self.get_dummy_layer(mismatch=False) + + x = self.sl.Sequence( + self.xp.array(np.ones((2, 5, 10))), + self.xp.array(np.ones((2, 5), dtype=bool)), + ) + constants = { + 'c': self.sl.Sequence( + self.xp.array(np.ones((2, 5, 1))), + self.xp.array(np.ones((2, 5), dtype=bool)), + ) + } + + self.verify_contract( + layer, x, training=False, constants=constants, stream_constants=True + ) diff --git a/sequence_layers/specs/types_behaviors.py b/sequence_layers/specs/types_behaviors.py index a1bb773..415da51 100644 --- a/sequence_layers/specs/types_behaviors.py +++ b/sequence_layers/specs/types_behaviors.py @@ -1,5 +1,5 @@ # pylint: disable=abstract-method -"""Generic tests for Sequence types.""" +"""Generic tests for Sequence types_spec.""" import dataclasses import fractions @@ -9,10 +9,34 @@ from absl.testing import parameterized import numpy as np +from sequence_layers import specs +from sequence_layers.specs import test_utils as test_utils_spec from sequence_layers.specs import types as types_spec from sequence_layers.specs.test_utils import SequenceLayerTest +class ModuleSpecTest(test_utils_spec.ModuleSpecTest): + + @override + def module_spec_pairs(self, backend_sl: specs.ModuleSpec): + return {backend_sl.types: types_spec.ModuleSpec} + + def test_backend_specific_types_are_subclasses(self) -> None: + pairs = self.module_spec_pairs(self.sl) + for mod, protocol in pairs.items(): + if protocol is types_spec.ModuleSpec: + self.assertTrue(issubclass(mod.Sequence, types_spec.Sequence)) + self.assertTrue( + issubclass(mod.MaskedSequence, types_spec.MaskedSequence) + ) + self.assertTrue(issubclass(mod.SequenceLayer, types_spec.SequenceLayer)) + self.assertTrue( + issubclass(mod.SequenceLayerConfig, types_spec.SequenceLayerConfig) + ) + self.assertTrue(issubclass(mod.Steppable, types_spec.Steppable)) +>>>>>>> 6ad10e5 (refactor(test_utils): Abstract into spec and implementations.) + + class DummyChannelSpec(NamedTuple): """Dummy channel spec for testing.""" @@ -155,7 +179,7 @@ def test_backend_specific_module_has_interface(self) -> None: class SequenceTest(SequenceLayerTest): - """Generic tests for the Sequence class.""" + """Abstract tests for the Sequence class.""" @parameterized.named_parameters( ('mask_value=None', 0.0, None), @@ -377,7 +401,8 @@ def test_from_lengths(self) -> None: self.assertIsInstance(x, self.sl.MaskedSequence) -class SteppableTest(SequenceLayerTest): +class SteppableTest(test_utils_spec.SequenceLayerTest): + """Abstract tests for Steppable layers.""" def create_steppable(self) -> types_spec.Steppable: """Creates a basic Steppable instance.""" @@ -449,7 +474,7 @@ def test_steppable_with_emits_defaults_to_tuple_with_empty_emits( mock_step.assert_called_with(seq, state_in, training=True, constants=None) -class SequenceLayerConfigTest(SequenceLayerTest): +class SequenceLayerConfigTest(test_utils_spec.SequenceLayerTest): def test_copy(self) -> None: backend_sl = self.sl @@ -466,7 +491,12 @@ def make(self) -> Any: """Makes a dummy layer.""" return 'dummy_layer' - config = Config() # type: ignore + @override + def copy(self, **kwargs: Any) -> Any: + """Returns a copy of the config.""" + return dataclasses.replace(self, **kwargs) + + config = Config() new_config = config.copy(b='new string') self.assertEqual(new_config.a, config.a) self.assertEqual(new_config.b, 'new string') @@ -482,7 +512,12 @@ def make(self) -> Any: """Makes a dummy layer.""" return 'dummy_layer' - config = NonDataclassConfig() # type: ignore + @override + def copy(self, **kwargs: Any) -> Any: + """Returns a copy of the config.""" + raise TypeError('Mock non-dataclass config') + + config = NonDataclassConfig() with self.assertRaises(TypeError): new_config = config.copy() del new_config @@ -499,7 +534,12 @@ def make(self) -> Any: """Makes a dummy layer.""" return 'dummy_layer' - config = Config() # type: ignore + @override + def copy(self, **kwargs: Any) -> Any: + """Returns a copy of the config.""" + return dataclasses.replace(self, **kwargs) + + config = Config() # dataclasses.replace raises TypeError for unknown arguments # JAX implementation wraps it in AttributeError with self.assertRaises((TypeError, AttributeError)): @@ -507,7 +547,7 @@ def make(self) -> Any: del new_config -class PreservesTypeTest(SequenceLayerTest): +class PreservesTypeTest(test_utils_spec.SequenceLayerTest): def create_layer(self) -> types_spec.PreservesType: """Creates a preserves type layer.""" @@ -529,7 +569,7 @@ def test_preserves_dtype(self) -> None: self.assertEqual(layer.get_output_dtype('fake_dtype123'), 'fake_dtype123') -class PreservesShapeTest(SequenceLayerTest): +class PreservesShapeTest(test_utils_spec.SequenceLayerTest): def create_layer(self) -> types_spec.PreservesShape: """Creates a preserves shape layer.""" @@ -551,7 +591,7 @@ def test_preserves_shape(self) -> None: self.assertEqual(layer.get_output_shape((1, 2, 3, 5)), (1, 2, 3, 5)) -class StatelessTest(SequenceLayerTest): +class StatelessTest(test_utils_spec.SequenceLayerTest): def create_sequence(self) -> types_spec.Sequence: """Creates a default test sequence.""" @@ -604,7 +644,7 @@ def test_stateless_behaviors(self) -> None: mock_layer.assert_called_once_with(x, training=True, constants={'c': 1}) -class EmittingTest(SequenceLayerTest): +class EmittingTest(test_utils_spec.SequenceLayerTest): def create_sequence(self) -> types_spec.Sequence: """Creates a default test sequence.""" @@ -650,7 +690,7 @@ def test_emitting_drops_emits_on_standard_calls(self) -> None: ) -class StatelessEmittingTest(SequenceLayerTest): +class StatelessEmittingTest(test_utils_spec.SequenceLayerTest): def create_sequence(self) -> types_spec.Sequence: """Creates a default test sequence.""" @@ -702,7 +742,7 @@ def test_stateless_emitting_behaviors(self) -> None: m_layer.assert_called_once_with(x, training=False, constants=None) -class StatelessPointwiseFunctorTest(SequenceLayerTest): +class StatelessPointwiseFunctorTest(test_utils_spec.SequenceLayerTest): def create_layer( self, is_mask_required: bool From ecc7e8233536dfdb540a9be38cab0a42643d16e9 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Thu, 9 Apr 2026 21:18:27 -0700 Subject: [PATCH 3/5] chore(pylint): Silence subjective complexity warnings from pylint --- pyproject.toml | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 49a2a0b..3b9626e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,17 @@ indent-string = " " no-docstring-rgx = "^(_)?test_|^.*Test$|^__.*__$" [tool.pylint.messages_control] -disable = ["too-many-lines", "too-many-ancestors", "too-few-public-methods", "duplicate-code"] +disable = [ + "too-many-lines", + "too-many-ancestors", + "too-few-public-methods", + "duplicate-code", + "too-many-arguments", + "too-many-locals", + "too-many-statements", + "too-many-branches", + "too-many-positional-arguments", +] From 1a465302a7439ac3b55c9d17c544b11ee73c4dcf Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Thu, 9 Apr 2026 21:18:28 -0700 Subject: [PATCH 4/5] chore(jax): Resolve linting errors and add overrides in test_utils --- sequence_layers/jax/test_utils.py | 146 +++++++++++++++++++++--------- 1 file changed, 105 insertions(+), 41 deletions(-) diff --git a/sequence_layers/jax/test_utils.py b/sequence_layers/jax/test_utils.py index 274e20b..b336f15 100644 --- a/sequence_layers/jax/test_utils.py +++ b/sequence_layers/jax/test_utils.py @@ -15,15 +15,13 @@ import dataclasses import functools -import itertools import logging import random -from typing import Any, Callable, Iterable, Mapping +from typing import Any, Callable, Iterable, Mapping, override from typing import Sequence as TypingSequence from typing import TypeVar from absl.testing import absltest -from absl.testing import parameterized import chex import flax.linen as nn import jax @@ -78,7 +76,7 @@ def random_sequence( raise ValueError('Must not specify random_mask and random_lengths.') if len(dims) < 2: raise ValueError( - 'random_sequence expects at least 2 dimensions, got: %s' % (dims,) + f'random_sequence expects at least 2 dimensions, got: {dims}' ) is_complex = dtype in (np.complex64, np.complex128) @@ -326,8 +324,7 @@ def get_grad_tols( compute_dtype is None or compute_dtype == jnp.float32 ): return {'grad_rtol': 1e-5, 'grad_atol': 1e-5} - else: - return {'grad_rtol': 1e-1, 'grad_atol': 1e-1} + return {'grad_rtol': 1e-1, 'grad_atol': 1e-1} def flax_init(layer: nn.Module, *args, **kwargs): @@ -344,6 +341,7 @@ def init_layer(*args, **kwargs): def flax_apply(layer: nn.Module, params, *args, **kwargs): + """Applies a Flax module with the given parameters.""" method = kwargs.pop('method', '__call__') should_jit = kwargs.pop('jit', True) @@ -356,6 +354,7 @@ def layer_fn(params, *args, **kwargs): def sl_init(layer: types.SequenceLayer, *args, **kwargs): + """Initializes a SequenceLayer.""" training = kwargs.pop('training', False) method = kwargs.pop('method', '__call__') should_jit = kwargs.pop('jit', True) @@ -423,11 +422,10 @@ def pad_with_garbage( if isinstance(x, jax.Array): return jnp.pad(x, paddings, constant_values=pad_value) - else: - return type(x)( - jnp.pad(x.values, paddings, constant_values=pad_value), - jnp.pad(x.mask, [(1, 1), (0, 0)], constant_values=True), - ) + return type(x)( + jnp.pad(x.values, paddings, constant_values=pad_value), + jnp.pad(x.mask, [(1, 1), (0, 0)], constant_values=True), + ) return jax.tree.map( pad_with_garbage, tree, is_leaf=lambda x: isinstance(x, types.Sequence) @@ -544,8 +542,12 @@ def fn( # avoid large gradient matrices. return jax.lax.reduce_sum(jnp.abs(y), axes=list(range(2, y.ndim))) - x_real_fn = lambda x_: types.Sequence(x_ + 1j * jnp.imag(x.values), x.mask) - x_imag_fn = lambda x_: types.Sequence(jnp.real(x.values) + 1j * x_, x.mask) + def x_real_fn(x_): + return types.Sequence(x_ + 1j * jnp.imag(x.values), x.mask) + + def x_imag_fn(x_): + return types.Sequence(jnp.real(x.values) + 1j * x_, x.mask) + jac_fn_real_y = functools.partial( fn, x_fn=lambda x_: types.Sequence(x_, x.mask), y_fn=jnp.real ) @@ -690,6 +692,7 @@ def fn( def _mask_and_pad_to_max_length( a: types.Sequence, b: types.Sequence ) -> tuple[types.Sequence, types.Sequence]: + """Masks invalid timesteps and pads two sequences to the same maximum length.""" # Only compare values in non-masked regions. a = a.mask_invalid() b = b.mask_invalid() @@ -762,18 +765,24 @@ def randomize_weights_fn(variables): return layer.bind(variables) + def init_layer(self, layer, x, **kwargs): + """Initialize and bind variables for JAX.""" + key = jax.random.PRNGKey(1234) + return self.init_and_bind_layer(key, layer, x, **kwargs) + def verify_masked(self, x: types.Sequence): """Asserts all invalid timesteps in x have values masked to zero.""" # Manually mask even if x is a MaskedSequence. expected = types.Sequence(x.values, x.mask).mask_invalid() self.assertAllEqual(x.values, expected.values) + @override def verify_contract( self, l: types.SequenceLayer, x: types.Sequence, *, - training: bool, + training: bool = False, constants: types.Constants | None = None, stream_constants: bool = False, stream_constants_list: list[types.Constants] | None = None, @@ -791,6 +800,7 @@ def verify_contract( test_padding_invariance: bool = True, test_receptive_field: bool = True, test_receptive_field_relaxed: bool = False, + **kwargs, ) -> types.Sequence: """Verifies that the provided layer obeys the SequenceLayer contract. @@ -1046,6 +1056,7 @@ def _pad(x: types.Sequence, pad_back: int) -> types.Sequence: # Property 1: Check layer-wise and step-wise equivalence. self.assertSequencesClose(y_layer, y_step, rtol=rtol, atol=atol) if test_2x_step: + assert y_step_2x is not None self.assertSequencesClose(y_layer, y_step_2x, rtol=rtol, atol=atol) # Property 2: Padding invariance. @@ -1062,6 +1073,7 @@ def _pad(x: types.Sequence, pad_back: int) -> types.Sequence: # is an integer type. go/jax-integer-autodiff assert y_layer_x_grad is not None if y_layer_x_grad.dtype != jax.dtypes.float0: + assert y_step_x_grad is not None self.assertSequencesClose( y_layer_x_grad, y_step_x_grad, rtol=grad_rtol, atol=grad_atol ) @@ -1149,17 +1161,69 @@ def random_sequence( ) @override - def assertSequencesClose( # pylint: disable=invalid-name + # pyrefly: ignore[bad-override] + def _step_by_step( self, - a: types.Sequence, - b: types.Sequence, + layer: types.SequenceLayer, + x: types.Sequence, + *, + block_size: int = 1, + constants=None, + stream_constants=None, + ) -> tuple[types.Sequence, Any]: + batch = x.values.shape[0] if hasattr(x, 'values') else x.shape[0] + time = x.values.shape[1] if hasattr(x, 'values') else x.shape[1] + + input_spec = types.ShapeDType(x.channel_shape, x.dtype) + + init_constants = dict(constants) if constants else {} + if stream_constants: + init_constants.update(stream_constants) + + state = layer.get_initial_state( + batch, input_spec, constants=init_constants or None, training=False + ) + + outputs_values = [] + outputs_masks = [] + + for t in range(0, time, block_size): + x_block = types.Sequence( + x.values[:, t : t + block_size], + x.mask[:, t : t + block_size], + ) + + step_constants = dict(constants) if constants else {} + if stream_constants: + for name, seq in stream_constants.items(): + step_constants[name] = types.Sequence( + seq.values[:, t : t + block_size], + seq.mask[:, t : t + block_size], + ) + + y_block, state = layer.step( + x_block, state, constants=step_constants or None, training=False + ) + outputs_values.append(y_block.values) + outputs_masks.append(y_block.mask) + + y_values = jnp.concatenate(outputs_values, axis=1) + y_mask = jnp.concatenate(outputs_masks, axis=1) + + return types.Sequence(y_values, y_mask), state + + @override + def assertSequencesClose( # pylint: disable=arguments-differ # pyrefly: ignore[bad-override] + self, + x: types.Sequence, + y: types.Sequence, atol: float = 1e-6, rtol: float = 1e-6, ): """After padding, checks sequence values are close and masks are equal.""" - a, b = _mask_and_pad_to_max_length(a, b) - self.assertAllClose(a.values, b.values, atol=atol, rtol=rtol) - self.assertAllEqual(a.mask, b.mask) + x, y = _mask_and_pad_to_max_length(x, y) + self.assertAllClose(x.values, y.values, atol=atol, rtol=rtol) + self.assertAllEqual(x.mask, y.mask) def assertSequencesNotClose( # pylint: disable=invalid-name self, @@ -1173,15 +1237,14 @@ def assertSequencesNotClose( # pylint: disable=invalid-name self.assertNotAllClose(a.values, b.values, atol=atol, rtol=rtol) self.assertAllEqual(a.mask, b.mask) - def assertSequencesEqual( # pylint: disable=invalid-name - self, - a: types.Sequence, - b: types.Sequence, - ): + @override + def assertSequencesEqual( # pyrefly: ignore[bad-override] + self, x: types.Sequence, y: types.Sequence + ) -> None: """After padding, checks sequence values are equal and masks are equal.""" - a, b = _mask_and_pad_to_max_length(a, b) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.mask, b.mask) + x, y = _mask_and_pad_to_max_length(x, y) + self.assertAllEqual(x.values, y.values) + self.assertAllEqual(x.mask, y.mask) def assertSequencesNotEqual( # pylint: disable=invalid-name self, @@ -1193,15 +1256,16 @@ def assertSequencesNotEqual( # pylint: disable=invalid-name self.assertNotAllEqual(a.values, b.values) self.assertAllEqual(a.mask, b.mask) - def assertAllEqual(self, a, b): # pylint: disable=invalid-name + @override + def assertAllEqual(self, x, y): # pylint: disable=invalid-name """Asserts that two arrays are equal.""" - if jnp.iscomplexobj(a) or jnp.iscomplexobj(b): - a_real, a_imag = jnp.real(a), jnp.imag(a) - b_real, b_imag = jnp.real(b), jnp.imag(b) - chex.assert_trees_all_equal(a_real, b_real) - chex.assert_trees_all_equal(a_imag, b_imag) + if jnp.iscomplexobj(x) or jnp.iscomplexobj(y): + x_real, x_imag = jnp.real(x), jnp.imag(x) + y_real, y_imag = jnp.real(y), jnp.imag(y) + chex.assert_trees_all_equal(x_real, y_real) + chex.assert_trees_all_equal(x_imag, y_imag) else: - chex.assert_trees_all_equal(a, b) + chex.assert_trees_all_equal(x, y) def assertAllClose(self, a, b, atol: float = 1e-6, rtol: float = 1e-6): # pylint: disable=invalid-name """Asserts that two arrays have close values.""" @@ -1219,9 +1283,7 @@ def assertNotAllEqual(self, a, b): # pylint: disable=invalid-name chex.assert_trees_all_equal(a, b) except AssertionError: return - raise AssertionError( - 'The two values are equal at all elements. %s %s' % (a, b) - ) + raise AssertionError(f'The two values are equal at all elements. {a} {b}') def assertNotAllClose(self, a, b, atol: float = 1e-6, rtol: float = 1e-6): # pylint: disable=invalid-name """Asserts that two arrays do not have close values.""" @@ -1229,9 +1291,7 @@ def assertNotAllClose(self, a, b, atol: float = 1e-6, rtol: float = 1e-6): # py self.assertAllClose(a, b, atol=atol, rtol=rtol) except AssertionError: return - raise AssertionError( - 'The two values are close at all elements. %s %s' % (a, b) - ) + raise AssertionError(f'The two values are close at all elements. {a} {b}') class AssertConstantsLayer(types.PreservesType, types.StatelessPointwise): @@ -1239,6 +1299,8 @@ class AssertConstantsLayer(types.PreservesType, types.StatelessPointwise): @dataclasses.dataclass(frozen=True) class Config(types.SequenceLayerConfig): + """Configuration for AssertConstantsLayer.""" + expected_constant: str = 'test' name: str | None = None @@ -1293,6 +1355,8 @@ class NonSteppableLayer(types.PreservesType, types.StatelessPointwise): @dataclasses.dataclass(frozen=True) class Config(types.SequenceLayerConfig): + """Configuration for NonSteppableLayer.""" + name: str | None = None @override From 185a7e5035d9c96bafda059ea5296416a9fae397 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Fri, 10 Apr 2026 10:29:02 -0700 Subject: [PATCH 5/5] chore: Fix leftover conflict marker in types_behaviors.py --- sequence_layers/specs/types_behaviors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sequence_layers/specs/types_behaviors.py b/sequence_layers/specs/types_behaviors.py index 415da51..891bc56 100644 --- a/sequence_layers/specs/types_behaviors.py +++ b/sequence_layers/specs/types_behaviors.py @@ -34,7 +34,7 @@ def test_backend_specific_types_are_subclasses(self) -> None: issubclass(mod.SequenceLayerConfig, types_spec.SequenceLayerConfig) ) self.assertTrue(issubclass(mod.Steppable, types_spec.Steppable)) ->>>>>>> 6ad10e5 (refactor(test_utils): Abstract into spec and implementations.) + class DummyChannelSpec(NamedTuple):