From 3e9596f9a5846f33813bc83e6b074b3d8e6fa204 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 14 May 2026 00:00:44 +0800 Subject: [PATCH 1/3] CI(deps): Update jax (cpu) version to 0.10.0 in pyproject.toml Signed-off-by: Jinzhe Zeng --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9262945276..95189b1821 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,7 +176,7 @@ pin_pytorch_gpu = [ "torch==2.10.0", ] pin_jax_cpu = [ - "jax==0.5.0;python_version>='3.10'", + "jax==0.10.0;python_version>='3.10'", ] pin_jax_gpu = [ "jax[cuda12]==0.5.0;python_version>='3.10'", From f5cbd4100239131ff22a7a6323f1e83e83025e9c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 14 May 2026 00:14:54 +0800 Subject: [PATCH 2/3] Update JAX version constraints in pyproject.toml Signed-off-by: Jinzhe Zeng --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 95189b1821..127d29f5c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,7 +176,8 @@ pin_pytorch_gpu = [ "torch==2.10.0", ] pin_jax_cpu = [ - "jax==0.10.0;python_version>='3.10'", + "jax>=0.6.2;python_version>='3.10'", + "jax==0.10.0;python_version>='3.11'", ] pin_jax_gpu = [ "jax[cuda12]==0.5.0;python_version>='3.10'", From ff02c1eff1f3bc5912589af4336c2ab3b6b8e0d1 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 14 May 2026 16:58:02 +0800 Subject: [PATCH 3/3] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jinzhe Zeng --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 127d29f5c5..6c55e504b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,8 +176,8 @@ pin_pytorch_gpu = [ "torch==2.10.0", ] pin_jax_cpu = [ - "jax>=0.6.2;python_version>='3.10'", - "jax==0.10.0;python_version>='3.11'", + "jax==0.6.2; python_version=='3.10'", + "jax==0.10.0; python_version>='3.11'", ] pin_jax_gpu = [ "jax[cuda12]==0.5.0;python_version>='3.10'",