Add a requirements file for multi-backend cuda (#472)
Not totally sure if we should merge this now, or wait for tf 2.14, but figured I could put it up anyway so people could use it. With https://github.com/tensorflow/tensorflow/pull/59825 tf-nightly can be installed using cuda pip packages. Which means we can write a recipe for cross framework GPU support. To install a local development version... ```shell pip install -r requirements-cuda.txt python pip_build.py --install ``` To install the official pip version... ```shell pip install -r requirements-cuda.txt pip install keras-core --no-deps ``` Note that `--no-deps` is required to avoid pulling in `tensorflow` and `tf-nightly` at the same time. This should work in a clean python env, as long nvidia drivers are >=520.61.05. No conda or cuda shenanigans required!
This commit is contained in:
parent
c8953e5a7d
commit
59fca267a7
15
requirements-common.txt
Normal file
15
requirements-common.txt
Normal file
@ -0,0 +1,15 @@
|
||||
namex
|
||||
black>=22
|
||||
flake8
|
||||
isort
|
||||
pytest
|
||||
pandas
|
||||
absl-py
|
||||
requests
|
||||
h5py
|
||||
protobuf
|
||||
google
|
||||
tensorboard-plugin-profile
|
||||
rich
|
||||
build
|
||||
dm-tree
|
18
requirements-cuda.txt
Normal file
18
requirements-cuda.txt
Normal file
@ -0,0 +1,18 @@
|
||||
# Tensorflow.
|
||||
# Cuda via pip is only on nightly right now.
|
||||
# We will pin a known working version to avoid breakages (nightly breaks often).
|
||||
tf-nightly[and-cuda]==2.14.0.dev20230712
|
||||
|
||||
# Torch.
|
||||
# Pin the version used in colab currently (works with tf cuda version).
|
||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
torch==2.0.1+cu118
|
||||
torchvision==0.15.2+cu118
|
||||
|
||||
# Jax.
|
||||
# Pin the version used in colab currently (works with tf cuda version).
|
||||
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
jax[cuda11_pip]==0.4.10
|
||||
|
||||
# Common deps.
|
||||
-r requirements-common.txt
|
@ -1,21 +1,13 @@
|
||||
# Tensorflow.
|
||||
tensorflow
|
||||
# TODO: Use Torch CPU
|
||||
# Remove after resolving Cuda version differences with TF
|
||||
|
||||
# Torch.
|
||||
# TODO: Use Torch CPU, remove after resolving Cuda version differences with TF
|
||||
torch>=2.0.1+cpu
|
||||
torchvision>=0.15.1
|
||||
|
||||
# Jax.
|
||||
jax[cpu]
|
||||
namex
|
||||
black>=22
|
||||
flake8
|
||||
isort
|
||||
pytest
|
||||
pandas
|
||||
absl-py
|
||||
requests
|
||||
h5py
|
||||
protobuf
|
||||
google
|
||||
tensorboard-plugin-profile
|
||||
rich
|
||||
build
|
||||
dm-tree
|
||||
|
||||
# Common deps.
|
||||
-r requirements-common.txt
|
||||
|
Loading…
Reference in New Issue
Block a user