diff --git a/keras_core/backend/torch/__init__.py b/keras_core/backend/torch/__init__.py index c7449978e..8bc9bddca 100644 --- a/keras_core/backend/torch/__init__.py +++ b/keras_core/backend/torch/__init__.py @@ -1,14 +1,17 @@ """Torch backend APIs. -Torch has a different logic of device management compared to TF and JAX. In -short variables/tensors are not by default created on GPU, and GPU cannot -directly communicate with CPU. Therefore, we are doing the following to automate -device management for Torch backend, if GPU is available: +# Note on device placement + +Torch has a different device placement style compared to TF and JAX. +In short, variables/tensors are not created on GPU by default, +and the GPU cannot directly communicate with the CPU. +To bring Torch behavior in line with TF and JAX's automated device placement, +we are doing the following to automate device placement if a GPU is available: - Variables are created on GPU. - Input data will be placed on GPU at the first `keras_core.layers.Layer` call. - Tensor creation happens on GPU, e.g., `zeros()` will create a tensor on GPU. -- `convert_to_numpy` will bring the tensor to CPU and convert to numpy array. +- `convert_to_numpy` will bring the tensor to CPU before converting it to NumPy. 
""" from keras_core.backend.torch import core diff --git a/keras_core/operations/core.py b/keras_core/operations/core.py index e8d8efdd1..bc66306d4 100644 --- a/keras_core/operations/core.py +++ b/keras_core/operations/core.py @@ -4,6 +4,11 @@ scatter_update slice slice_update while_loop +stop_gradient +shape +cast +convert_to_tensor +convert_to_numpy """ from keras_core import backend diff --git a/keras_core/trainers/trainer.py b/keras_core/trainers/trainer.py index 75447963e..5f8c29a0e 100644 --- a/keras_core/trainers/trainer.py +++ b/keras_core/trainers/trainer.py @@ -43,7 +43,7 @@ class Trainer: else: self._compile_metrics = None if jit_compile == "auto": - if model_supports_jit(self): + if not run_eagerly and model_supports_jit(self): jit_compile = True else: jit_compile = False