From 2d024de3a1e93f5faa0f0f3d0b0a2c7f0cbc1154 Mon Sep 17 00:00:00 2001 From: "njzjz-bot (driven by OpenClaw (model: gpt-5.5))[bot]" <48687836+njzjz-bot@users.noreply.github.com> Date: Fri, 8 May 2026 14:49:28 +0000 Subject: [PATCH] build(jax): split pinned cpu and gpu groups Keep the public jax extra unchanged and only split the pinned dependency groups used by CPU and CUDA CI. CPU uses plain jax, while GPU uses jax[cuda12]. Authored by OpenClaw (model: gpt-5.5) --- .github/workflows/test_cc.yml | 2 +- .github/workflows/test_cuda.yml | 2 +- .github/workflows/test_python.yml | 2 +- pyproject.toml | 5 ++++- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_cc.yml b/.github/workflows/test_cc.yml index 9452aa6edf..b5c9166a9b 100644 --- a/.github/workflows/test_cc.yml +++ b/.github/workflows/test_cc.yml @@ -44,7 +44,7 @@ jobs: - run: python -m pip install uv - name: Install Python dependencies run: | - source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_cpu --group pin_pytorch_cpu --group pin_jax --torch-backend cpu + source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_cpu --group pin_pytorch_cpu --group pin_jax_cpu --torch-backend cpu export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)') source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py mpich - name: Convert models diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 65773ccbfe..4be035ba21 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -43,7 +43,7 @@ jobs: && sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3 if: false # skip as we use nvidia image - run: python -m pip install -U uv - - run: source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_gpu --group pin_pytorch_gpu --group pin_jax "jax[cuda12]" + - run: source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_gpu --group pin_pytorch_gpu --group pin_jax_gpu - run: | export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)') diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index 3811c5c689..461d972f57 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -31,7 +31,7 @@ jobs: source/install/uv_with_retry.sh pip install --system openmpi --group pin_tensorflow_cpu --group pin_pytorch_cpu --torch-backend cpu export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)') export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') - source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py --group pin_jax + source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py --group pin_jax_cpu source/install/uv_with_retry.sh pip install --system --find-links "https://www.paddlepaddle.org.cn/packages/nightly/cpu/paddlepaddle/" --index-url https://pypi.org/simple --trusted-host www.paddlepaddle.org.cn --trusted-host paddlepaddle.org.cn paddlepaddle==3.4.0.dev20260310 env: # Please note that uv has some issues with finding diff --git a/pyproject.toml b/pyproject.toml index 555806d904..9262945276 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -175,9 +175,12 @@ pin_pytorch_cpu = [ pin_pytorch_gpu = [ "torch==2.10.0", ] -pin_jax = [ +pin_jax_cpu = [ "jax==0.5.0;python_version>='3.10'", ] +pin_jax_gpu = [ + "jax[cuda12]==0.5.0;python_version>='3.10'", +] [tool.setuptools_scm]