# TensorFlow cpu-only version (needed for testing).
tf-nightly-cpu==2.16.0.dev20240101  # Pin a working nightly until rc0.

# Torch with CUDA support.
# The extra index serves the cu121 (CUDA 12.1) wheels; it must be on its own
# line because pip only recognizes options at the start of a line.
--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.1.2
torchvision==0.16.2

# JAX cpu-only version (needed for testing).
jax[cpu]

# Requirements listed in the common file are included as well.
-r requirements-common.txt