keras/requirements-jax-cuda.txt
Matt Watson 149d1e1468
Bump to 2.16 nightlies for requirements (#18726)
These nightlies are compatible with the Keras 3 nightlies, which makes for a
friendlier dev environment overall.

Tested with a quick MNIST training script to verify CUDA was working.
2023-11-03 13:45:30 -07:00

# TensorFlow CPU-only version (needed for testing).
tf-nightly-cpu==2.16.0.dev20231103 # Pin a working nightly until rc0.
# Torch CPU-only version (needed for testing).
--extra-index-url https://download.pytorch.org/whl/cpu
torch>=2.1.0
torchvision>=0.16.0
# JAX with CUDA 12 support (CUDA libraries installed via pip wheels).
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda12_pip]
-r requirements-common.txt
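The MNIST script used to verify CUDA is not included here. A lighter-weight sanity check is sketched below. It assumes the environment was built with `pip install -r requirements-jax-cuda.txt` and only confirms that the CPU-only TensorFlow and Torch packages are importable and that JAX can see a CUDA device (JAX reports CUDA devices with the platform name `"gpu"`). The helper names `installed`, `jax_platforms`, and `report` are hypothetical, chosen for this sketch.

```python
"""Hypothetical sanity check for an environment built from requirements-jax-cuda.txt.

Assumes `pip install -r requirements-jax-cuda.txt` has already been run.
TensorFlow and Torch are the CPU-only builds, so they are only checked for
presence; JAX should list at least one "gpu" device if CUDA is working.
"""
import importlib.util


def installed(module_name: str) -> bool:
    # find_spec returns None when a top-level module cannot be located;
    # it does not import the module, so this stays cheap for TensorFlow/Torch.
    return importlib.util.find_spec(module_name) is not None


def jax_platforms() -> list:
    """Return the sorted platform names of devices JAX can see, or [] if JAX is missing."""
    if not installed("jax"):
        return []
    import jax  # imported lazily so the check also runs where JAX is absent

    return sorted({device.platform for device in jax.devices()})


def report() -> dict:
    # tf-nightly-cpu installs the `tensorflow` module; torch installs `torch`.
    return {
        "tensorflow": installed("tensorflow"),
        "torch": installed("torch"),
        "jax_platforms": jax_platforms(),
    }


if __name__ == "__main__":
    status = report()
    print(status)
    if "gpu" not in status["jax_platforms"]:
        print("JAX does not see a CUDA device; check the jax[cuda12_pip] install.")
```

A full training run (such as the MNIST script mentioned in the commit) exercises far more of the stack; this check only catches the most common failure, JAX silently falling back to CPU.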