diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 597b63906..7998d6734 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -7,7 +7,8 @@ torch>=2.1.0 torchvision>=0.16.0 # Jax with cuda support. +# TODO: 0.4.24 has an updated Cuda version breaks Jax CI. --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -jax[cuda12_pip] +jax[cuda12_pip]==0.4.23 -r requirements-common.txt