This commit is contained in:
Francois Chollet 2023-06-20 15:16:03 -07:00
parent e2e7f0d061
commit 57d8ac26f1
3 changed files with 14 additions and 6 deletions

@@ -1,14 +1,17 @@
"""Torch backend APIs.
Torch has a different logic of device management compared to TF and JAX. In
short variables/tensors are not by default created on GPU, and GPU cannot
directly communicate with CPU. Therefore, we are doing the following to automate
device management for Torch backend, if GPU is available:
# Note on device placement
Torch has a different device placement style compared to TF and JAX.
In short, variables/tensors are not created on GPU by default,
and the GPU cannot directly communicate with the CPU.
To bring Torch behavior in line with TF and JAX automated device placement,
we are doing the following to automate device placement if a GPU is available:
- Variables are created on GPU.
- Input data will be placed on GPU at the first `keras_core.layers.Layer` call.
- Tensor creation happens on GPU, e.g., `zeros()` will create a tensor on GPU.
- `convert_to_numpy` will bring the tensor to CPU and convert to numpy array.
- `convert_to_numpy` will bring the tensor to CPU before converting it to NumPy.
"""
from keras_core.backend.torch import core

@@ -4,6 +4,11 @@ scatter_update
slice
slice_update
while_loop
stop_gradient
shape
cast
convert_to_tensor
convert_to_numpy
"""
from keras_core import backend

@@ -43,7 +43,7 @@ class Trainer:
else:
self._compile_metrics = None
if jit_compile == "auto":
if model_supports_jit(self):
if not run_eagerly and model_supports_jit(self):
jit_compile = True
else:
jit_compile = False