diff --git a/pyproject.toml b/pyproject.toml index 9262945276..b4c5e523fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,10 +176,10 @@ pin_pytorch_gpu = [ "torch==2.10.0", ] pin_jax_cpu = [ - "jax==0.5.0;python_version>='3.10'", + "jax==0.6.2;python_version>='3.10'", ] pin_jax_gpu = [ - "jax[cuda12]==0.5.0;python_version>='3.10'", + "jax[cuda12]==0.6.2;python_version>='3.10'", ] [tool.setuptools_scm]