Nits.
This commit is contained in:
parent
e2e7f0d061
commit
57d8ac26f1
@ -1,14 +1,17 @@
|
||||
"""Torch backend APIs.
|
||||
|
||||
Torch has a different logic of device management compared to TF and JAX. In
|
||||
short variables/tensors are not by default created on GPU, and GPU cannot
|
||||
directly communicate with CPU. Therefore, we are doing the following to automate
|
||||
device management for Torch backend, if GPU is available:
|
||||
# Note on device placement
|
||||
|
||||
Torch has a different device placement style compared to TF and JAX.
|
||||
In short, variables/tensors are not created on GPU by default,
|
||||
and the GPU cannot directly communicate with the CPU.
|
||||
To bring Torch behavior in line with TF and JAX automated device placement,
|
||||
we are doing the following to automate device placement if a GPU is available:
|
||||
|
||||
- Variables are created on GPU.
|
||||
- Input data will be placed on GPU at the first `keras_core.layers.Layer` call.
|
||||
- Tensor creation happens on GPU, e.g., `zeros()` will create a tensor on GPU.
|
||||
- `convert_to_numpy` will bring the tensor to CPU and convert to numpy array.
|
||||
- `convert_to_numpy` will bring the tensor to CPU before converting it to NumPy.
|
||||
"""
|
||||
|
||||
from keras_core.backend.torch import core
|
||||
|
@ -4,6 +4,11 @@ scatter_update
|
||||
slice
|
||||
slice_update
|
||||
while_loop
|
||||
stop_gradient
|
||||
shape
|
||||
cast
|
||||
convert_to_tensor
|
||||
convert_to_numpy
|
||||
"""
|
||||
|
||||
from keras_core import backend
|
||||
|
@ -43,7 +43,7 @@ class Trainer:
|
||||
else:
|
||||
self._compile_metrics = None
|
||||
if jit_compile == "auto":
|
||||
if model_supports_jit(self):
|
||||
if not run_eagerly and model_supports_jit(self):
|
||||
jit_compile = True
|
||||
else:
|
||||
jit_compile = False
|
||||
|
Loading…
Reference in New Issue
Block a user