diff --git a/pyproject.toml b/pyproject.toml index 9262945276..6c55e504b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,7 +176,8 @@ pin_pytorch_gpu = [ "torch==2.10.0", ] pin_jax_cpu = [ - "jax==0.5.0;python_version>='3.10'", + "jax==0.6.2; python_version=='3.10'", + "jax==0.10.0; python_version>='3.11'", ] pin_jax_gpu = [ "jax[cuda12]==0.5.0;python_version>='3.10'",