keras/keras_core/backend/__init__.py
Chen Qian f6df67f2d2 Add numpy module in jax/ and tensorflow/ (#13)
* Add jax/numpy and tensorflow/numpy

* refactor code

* more

* even better
2023-04-18 18:45:30 -07:00

30 lines
1.2 KiB
Python

import json
import os
from keras_core.backend.common.variables import standardize_dtype
from keras_core.backend.common.variables import standardize_shape
from keras_core.backend.config import backend
from keras_core.backend.config import epsilon
from keras_core.backend.config import floatx
from keras_core.backend.config import image_data_format
from keras_core.backend.config import set_epsilon
from keras_core.backend.config import set_floatx
from keras_core.backend.config import set_image_data_format
from keras_core.backend.keras_tensor import KerasTensor
from keras_core.backend.keras_tensor import any_symbolic_tensors
from keras_core.backend.keras_tensor import is_keras_tensor
from keras_core.backend.stateless_scope import StatelessScope
from keras_core.backend.stateless_scope import get_stateless_scope
from keras_core.backend.stateless_scope import in_stateless_scope
from keras_core.utils.io_utils import print_msg
# Import backend functions.
# The active backend name is read once from
# `keras_core.backend.config.backend()`, so the value that selects the branch
# is exactly the value reported in the log message or the error below.
_backend_name = backend()
if _backend_name == "tensorflow":
    print_msg("Using TensorFlow backend")
    # Star-import the TensorFlow implementations into this package namespace
    # so they are reachable as `keras_core.backend.<name>`.
    from keras_core.backend.tensorflow import *
elif _backend_name == "jax":
    print_msg("Using JAX backend.")
    # Star-import the JAX implementations into this package namespace.
    from keras_core.backend.jax import *
else:
    # Keep the original message prefix (callers/tests may match on it) and
    # tell the user which values are accepted.
    raise ValueError(
        f"Unable to import backend : {_backend_name}. "
        "Supported backends are: 'tensorflow', 'jax'."
    )