Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,17 @@ indent-string = " "
no-docstring-rgx = "^(_)?test_|^.*Test$|^__.*__$"

[tool.pylint.messages_control]
disable = ["too-many-lines", "too-many-ancestors", "too-few-public-methods", "duplicate-code"]
disable = [
"too-many-lines",
"too-many-ancestors",
"too-few-public-methods",
"duplicate-code",
"too-many-arguments",
"too-many-locals",
"too-many-statements",
"too-many-branches",
"too-many-positional-arguments",
]



Expand Down
9 changes: 6 additions & 3 deletions sequence_layers/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
# limitations under the License.
"""Sequence layers in JAX."""

# (re-export the names for typechecking)
from . import backend as backend
from . import types as types
from . import test_utils as test_utils
from .test_utils import SequenceLayerTest
# pylint: disable=wildcard-import
from sequence_layers.jax.attention import *
from sequence_layers.jax.combinators import *
Expand All @@ -28,6 +33,4 @@
from sequence_layers.jax.time_varying import *
from sequence_layers.jax.types import *

# (re-export the names for typechecking)
from . import backend as backend
from . import types as types

4 changes: 4 additions & 0 deletions sequence_layers/jax/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class BackendWrapper(spec.xp):

bool_ = jnp.bool_
int32 = jnp.int32
float32 = jnp.float32

@override
def array(self, a, dtype=None) -> types_spec.Array:
Expand All @@ -22,5 +23,8 @@ def array(self, a, dtype=None) -> types_spec.Array:
def zeros(self, shape, dtype=None) -> types_spec.Array:
return jnp.zeros(shape, dtype=dtype)

def concatenate(self, arrays, axis=0) -> types_spec.Array:
return jnp.concatenate(arrays, axis=axis)


xp: spec.xp = BackendWrapper()
14 changes: 14 additions & 0 deletions sequence_layers/jax/backend_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Tests for JAX backend utilities."""

from absl.testing import absltest

from sequence_layers.jax import test_utils
from sequence_layers.specs import backend_behaviors as spec


class ModuleSpecTest(test_utils.SequenceLayerTest, spec.ModuleSpecTest):
pass


if __name__ == '__main__':
absltest.main()
Loading