2d40cb20b9
* Adds unit normalization and tests
* Adds layer normalization and initial tests
* Fixes formatting in docstrings
* Fix type issues for JAX
* Fix nits
* Initial stash for group_normalization and spectral_normalization
* Adds spectral normalization and tests
* Adds group normalization and tests
* Formatting fixes
* Fix small nit in docstring
* Fix docstring and tests
* Adds RandomContrast and associated tests
* Remove arithmetic comment
* Adds RandomBrightness and tests
* Fix docstring and format
* Fix nits and add backend generator
* Inlines random_contrast helper
* Add bincount op
* Add CategoryEncoding layer and tests
* Fix formatting
* Fix JAX issues
* Fix JAX bincount
* Formatting and small fix
* Fix nits and docstrings
* Add args to bincount op test
34 lines
1.5 KiB
Python
34 lines
1.5 KiB
Python
import json
|
|
import os
|
|
|
|
from keras_core.backend.common.keras_tensor import KerasTensor
|
|
from keras_core.backend.common.keras_tensor import any_symbolic_tensors
|
|
from keras_core.backend.common.keras_tensor import is_keras_tensor
|
|
from keras_core.backend.common.stateless_scope import StatelessScope
|
|
from keras_core.backend.common.stateless_scope import get_stateless_scope
|
|
from keras_core.backend.common.stateless_scope import in_stateless_scope
|
|
from keras_core.backend.common.variables import AutocastScope
|
|
from keras_core.backend.common.variables import get_autocast_scope
|
|
from keras_core.backend.common.variables import is_float_dtype
|
|
from keras_core.backend.common.variables import standardize_dtype
|
|
from keras_core.backend.common.variables import standardize_shape
|
|
from keras_core.backend.config import backend
|
|
from keras_core.backend.config import epsilon
|
|
from keras_core.backend.config import floatx
|
|
from keras_core.backend.config import image_data_format
|
|
from keras_core.backend.config import set_epsilon
|
|
from keras_core.backend.config import set_floatx
|
|
from keras_core.backend.config import set_image_data_format
|
|
from keras_core.backend.config import standardize_data_format
|
|
from keras_core.utils.io_utils import print_msg
|
|
|
|
# Import backend functions.
#
# Resolve the configured backend name once (via `backend()` from
# keras_core.backend.config), then star-import the matching backend
# implementation so its public names are re-exported from this module.
# Any other backend name is rejected at import time.
_backend_name = backend()
if _backend_name == "tensorflow":
    print_msg("Using TensorFlow backend")
    from keras_core.backend.tensorflow import *  # noqa: F403
elif _backend_name == "jax":
    print_msg("Using JAX backend.")
    from keras_core.backend.jax import *  # noqa: F403
else:
    # Keep the original message prefix so existing matches on it still work,
    # but tell the user which backend names are actually supported here.
    raise ValueError(
        f"Unable to import backend : {_backend_name}. "
        "Supported backends are: 'tensorflow', 'jax'."
    )
|