Fix JAX CI Cuda error by fixing JAX version (#19161)

This commit is contained in:
Ramesh Sampath 2024-02-08 18:32:32 -06:00 committed by GitHub
parent 8d58b790e4
commit 830bea69e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -7,7 +7,8 @@ torch>=2.1.0
torchvision>=0.16.0
# Jax with cuda support.
# TODO: 0.4.24 has an updated Cuda version breaks Jax CI.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12_pip]
jax[cuda12_pip]==0.4.23
-r requirements-common.txt