This commit is contained in:
Francois Chollet 2023-06-06 16:32:55 -07:00
parent 68178e9cb2
commit c7e8cd5f6e

@ -1,9 +1,6 @@
import sys
from keras_core import backend as backend_module
from keras_core.backend import jax as jax_backend
from keras_core.backend import tensorflow as tf_backend
from keras_core.backend import torch as torch_backend
def in_tf_graph():
@ -41,8 +38,14 @@ class DynamicBackend:
def __getattr__(self, name):
if self._backend == "tensorflow":
from keras_core.backend import tensorflow as tf_backend
return getattr(tf_backend, name)
if self._backend == "jax":
from keras_core.backend import jax as jax_backend
return getattr(jax_backend, name)
if self._backend == "torch":
from keras_core.backend import torch as torch_backend
return getattr(torch_backend, name)