2023-06-06 20:58:26 +00:00
|
|
|
import sys
|
|
|
|
|
|
|
|
from keras_core import backend as backend_module
|
|
|
|
|
|
|
|
|
|
|
|
def in_tf_graph():
    """Report whether TensorFlow is currently building a graph.

    TensorFlow is only consulted when the `tensorflow` module is already
    present in `sys.modules`; otherwise this short-circuits to `False`
    without touching TensorFlow at all.

    Returns:
        `True` when TensorFlow is loaded and not executing eagerly,
        `False` otherwise.
    """
    if "tensorflow" not in sys.modules:
        return False

    from keras_core.utils.module_utils import tensorflow as tf

    eager = tf.executing_eagerly()
    return not eager
|
|
|
|
|
|
|
|
|
|
|
|
class DynamicBackend:
    """A class that can be used to switch from one backend to another.

    Attribute lookups that are not found on the instance itself are
    forwarded to the module of the currently selected backend, so backend
    ops can be called directly on this object.

    Usage:

    ```python
    backend = DynamicBackend("tensorflow")
    y = backend.square(tf.constant(...))
    backend.set_backend("jax")
    y = backend.square(jax.numpy.array(...))
    ```

    Args:
        backend: Initial backend to use (string). When `None` (or empty),
            the value of `keras_core.backend.backend()` is used.

    Raises:
        ValueError: If `backend` is not one of the available backends.
    """

    # Backend names this class knows how to dispatch to.
    _AVAILABLE_BACKENDS = ("tensorflow", "jax", "torch", "numpy")

    def __init__(self, backend=None):
        self._backend = self._validate_backend(
            backend or backend_module.backend()
        )

    @classmethod
    def _validate_backend(cls, backend):
        """Return `backend` unchanged, or raise if it is not supported."""
        if backend not in cls._AVAILABLE_BACKENDS:
            raise ValueError(
                f"Available backends are {cls._AVAILABLE_BACKENDS}. "
                f"Received: backend={backend}"
            )
        return backend

    def set_backend(self, backend):
        """Route subsequent attribute lookups to `backend`.

        Args:
            backend: Name of the backend to switch to (string).

        Raises:
            ValueError: If `backend` is not one of the available backends.
        """
        self._backend = self._validate_backend(backend)

    def reset(self):
        """Switch back to the backend reported by `keras_core.backend`."""
        self._backend = backend_module.backend()

    def __getattr__(self, name):
        """Forward `name` to the module of the selected backend.

        Python only calls this when normal attribute lookup fails.

        Raises:
            AttributeError: For dunder names, when `_backend` has not been
                set on this instance, or when the backend module has no
                attribute `name`.
            ValueError: If `_backend` holds an unsupported value (only
                possible if it was assigned directly, bypassing validation).
            NotImplementedError: If the NumPy backend is selected while a
                different backend is configured globally.
        """
        # `copy`, `pickle` and similar protocols probe for optional dunder
        # hooks (`__deepcopy__`, `__getnewargs_ex__`, `__setstate__`, ...).
        # Those are not backend ops, so answer them here rather than
        # importing a backend module just to report that they are missing.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )
        # Read `_backend` from the instance dict: on an instance whose
        # `__init__` never ran (e.g. during copy/unpickling), `self._backend`
        # would call `__getattr__("_backend")` again and recurse forever.
        try:
            backend = self.__dict__["_backend"]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            ) from None
        # Fail loudly instead of implicitly returning `None` for every op.
        backend = self._validate_backend(backend)

        if backend == "tensorflow":
            from keras_core.backend import tensorflow as tf_backend

            return getattr(tf_backend, name)
        if backend == "jax":
            from keras_core.backend import jax as jax_backend

            return getattr(jax_backend, name)
        if backend == "torch":
            from keras_core.backend import torch as torch_backend

            return getattr(torch_backend, name)

        # Remaining case: "numpy" (guaranteed by the validation above).
        # TODO (ariG23498): `from keras_core.backend import numpy as
        # numpy_backend` does not work, so the NumPy ops are only reachable
        # through `keras_core.backend` itself, which exposes the globally
        # configured backend. Using it while another backend is configured
        # would silently run that backend's ops instead, so refuse.
        if backend_module.backend() == "numpy":
            return getattr(backend_module, name)
        raise NotImplementedError(
            "The 'numpy' backend can only be used by DynamicBackend when it "
            "is also the globally configured Keras backend. Current global "
            f"backend: {backend_module.backend()}"
        )
|