Add a requirements file for multi-backend cuda (#472)

Not totally sure if we should merge this now, or wait for tf 2.14, but
figured I could put it up anyway so people could use it. With
https://github.com/tensorflow/tensorflow/pull/59825
tf-nightly can be installed using cuda pip packages. Which means we
can write a recipe for cross framework GPU support.

To install a local development version...
```shell
pip install -r requirements-cuda.txt
python pip_build.py --install
```

To install the official pip version...
```shell
pip install -r requirements-cuda.txt
pip install keras-core --no-deps
```

Note that `--no-deps` is required to avoid pulling in `tensorflow` and
`tf-nightly` at the same time.

This should work in a clean python env, as long nvidia drivers are
>=520.61.05. No conda or cuda shenanigans required!
This commit is contained in:
Matt Watson 2023-07-16 10:20:10 -07:00 committed by Francois Chollet
parent c8953e5a7d
commit 59fca267a7
3 changed files with 42 additions and 17 deletions

15
requirements-common.txt Normal file

@ -0,0 +1,15 @@
namex
black>=22
flake8
isort
pytest
pandas
absl-py
requests
h5py
protobuf
google
tensorboard-plugin-profile
rich
build
dm-tree

18
requirements-cuda.txt Normal file

@ -0,0 +1,18 @@
# Tensorflow.
# Cuda via pip is only on nightly right now.
# We will pin a known working version to avoid breakages (nightly breaks often).
tf-nightly[and-cuda]==2.14.0.dev20230712
# Torch.
# Pin the version used in colab currently (works with tf cuda version).
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.0.1+cu118
torchvision==0.15.2+cu118
# Jax.
# Pin the version used in colab currently (works with tf cuda version).
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda11_pip]==0.4.10
# Common deps.
-r requirements-common.txt

@ -1,21 +1,13 @@
# Tensorflow.
tensorflow
# TODO: Use Torch CPU
# Remove after resolving Cuda version differences with TF
# Torch.
# TODO: Use Torch CPU, remove after resolving Cuda version differences with TF
torch>=2.0.1+cpu
torchvision>=0.15.1
# Jax.
jax[cpu]
namex
black>=22
flake8
isort
pytest
pandas
absl-py
requests
h5py
protobuf
google
tensorboard-plugin-profile
rich
build
dm-tree
# Common deps.
-r requirements-common.txt