2023-05-15 18:31:08 +00:00
|
|
|
from keras_core.backend.config import backend
|
|
|
|
|
2023-05-15 20:56:52 +00:00
|
|
|
# Import-order guard: this check runs right after the backend name is known
# and before the keras_core.backend.common imports and the backend-specific
# star imports further down in this module.
if backend() == "torch":
    # When using the torch backend,
    # torch needs to be imported first, otherwise it will segfault
    # upon import.
    # The import is kept for that side effect (loading torch early); the bound
    # `torch` name is not otherwise referenced in the code visible here.
    import torch
|
2023-05-15 18:07:52 +00:00
|
|
|
|
2023-05-03 22:37:13 +00:00
|
|
|
from keras_core.backend.common.keras_tensor import KerasTensor
|
|
|
|
from keras_core.backend.common.keras_tensor import any_symbolic_tensors
|
|
|
|
from keras_core.backend.common.keras_tensor import is_keras_tensor
|
|
|
|
from keras_core.backend.common.stateless_scope import StatelessScope
|
|
|
|
from keras_core.backend.common.stateless_scope import get_stateless_scope
|
|
|
|
from keras_core.backend.common.stateless_scope import in_stateless_scope
|
2023-04-27 23:02:31 +00:00
|
|
|
from keras_core.backend.common.variables import AutocastScope
|
|
|
|
from keras_core.backend.common.variables import get_autocast_scope
|
2023-04-28 21:22:29 +00:00
|
|
|
from keras_core.backend.common.variables import is_float_dtype
|
2023-04-19 01:45:30 +00:00
|
|
|
from keras_core.backend.common.variables import standardize_dtype
|
|
|
|
from keras_core.backend.common.variables import standardize_shape
|
2023-04-09 19:21:45 +00:00
|
|
|
from keras_core.backend.config import epsilon
|
2023-04-12 18:31:58 +00:00
|
|
|
from keras_core.backend.config import floatx
|
2023-04-09 19:21:45 +00:00
|
|
|
from keras_core.backend.config import image_data_format
|
|
|
|
from keras_core.backend.config import set_epsilon
|
2023-04-12 18:31:58 +00:00
|
|
|
from keras_core.backend.config import set_floatx
|
2023-04-09 19:21:45 +00:00
|
|
|
from keras_core.backend.config import set_image_data_format
|
2023-05-09 18:22:21 +00:00
|
|
|
from keras_core.backend.config import standardize_data_format
|
2023-04-12 18:31:58 +00:00
|
|
|
from keras_core.utils.io_utils import print_msg
|
2023-04-09 19:21:45 +00:00
|
|
|
|
|
|
|
# Import backend functions.
# Exactly one backend implementation is star-imported into this module's
# namespace, chosen by the configured backend name returned by `backend()`.
# Any other name is rejected with a ValueError.
if backend() == "tensorflow":
    print_msg("Using TensorFlow backend")
    from keras_core.backend.tensorflow import *  # noqa: F403
elif backend() == "jax":
    print_msg("Using JAX backend.")
    from keras_core.backend.jax import *  # noqa: F403
elif backend() == "torch":
    print_msg("Using PyTorch backend.")
    from keras_core.backend.torch import *  # noqa: F403
else:
    # Keep the original message prefix (so existing matches still work) but
    # list the accepted values, so a mistyped backend setting is easy to fix.
    raise ValueError(
        f"Unable to import backend : {backend()}. "
        "Supported backends are: 'tensorflow', 'jax', 'torch'."
    )
|