diff --git a/.github/workflows/test_cc.yml b/.github/workflows/test_cc.yml index 9452aa6edf..b5c9166a9b 100644 --- a/.github/workflows/test_cc.yml +++ b/.github/workflows/test_cc.yml @@ -44,7 +44,7 @@ jobs: - run: python -m pip install uv - name: Install Python dependencies run: | - source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_cpu --group pin_pytorch_cpu --group pin_jax --torch-backend cpu + source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_cpu --group pin_pytorch_cpu --group pin_jax_cpu --torch-backend cpu export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)') source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py mpich - name: Convert models diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 65773ccbfe..4be035ba21 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -43,7 +43,7 @@ jobs: && sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3 if: false # skip as we use nvidia image - run: python -m pip install -U uv - - run: source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_gpu --group pin_pytorch_gpu --group pin_jax "jax[cuda12]" + - run: source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_gpu --group pin_pytorch_gpu --group pin_jax_gpu - run: | export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)') diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index 3811c5c689..461d972f57 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -31,7 +31,7 @@ jobs: source/install/uv_with_retry.sh pip install --system openmpi --group pin_tensorflow_cpu --group pin_pytorch_cpu --torch-backend cpu export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)') export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') - source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py --group pin_jax + source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py --group pin_jax_cpu source/install/uv_with_retry.sh pip install --system --find-links "https://www.paddlepaddle.org.cn/packages/nightly/cpu/paddlepaddle/" --index-url https://pypi.org/simple --trusted-host www.paddlepaddle.org.cn --trusted-host paddlepaddle.org.cn paddlepaddle==3.4.0.dev20260310 env: # Please note that uv has some issues with finding diff --git a/pyproject.toml b/pyproject.toml index 555806d904..9262945276 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -175,9 +175,12 @@ pin_pytorch_cpu = [ pin_pytorch_gpu = [ "torch==2.10.0", ] -pin_jax = [ +pin_jax_cpu = [ "jax==0.5.0;python_version>='3.10'", ] +pin_jax_gpu = [ + "jax[cuda12]==0.5.0;python_version>='3.10'", +] [tool.setuptools_scm]