keras/keras_core/backend/__init__.py
Neel Kovelamudi 2d40cb20b9 Adds CategoryEncoding layer, bincount op, and tests (#161)
* Adds unit normalization and tests

* Adds layer normalization and initial tests

* Fixes formatting in docstrings

* Fix type issues for JAX

* Fix nits

* Initial stash for group_normalization and spectral_normalization

* Adds spectral normalization and tests

* Adds group normalization and tests

* Formatting fixes

* Fix small nit in docstring

* Fix docstring and tests

* Adds RandomContrast and associated tests

* Remove arithmetic comment

* Adds RandomBrightness and tests

* Fix docstring and format

* Fix nits and add backend generator

* Inlines random_contrast helper

* Add bincount op

* Add CategoryEncoding layer and tests

* Fix formatting

* Fix JAX issues

* Fix JAX bincount

* Formatting and small fix

* Fix nits and docstrings

* Add args to bincount op test
2023-05-14 00:07:43 +00:00

34 lines
1.5 KiB
Python

import json
import os
from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.keras_tensor import any_symbolic_tensors
from keras_core.backend.common.keras_tensor import is_keras_tensor
from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.backend.common.stateless_scope import get_stateless_scope
from keras_core.backend.common.stateless_scope import in_stateless_scope
from keras_core.backend.common.variables import AutocastScope
from keras_core.backend.common.variables import get_autocast_scope
from keras_core.backend.common.variables import is_float_dtype
from keras_core.backend.common.variables import standardize_dtype
from keras_core.backend.common.variables import standardize_shape
from keras_core.backend.config import backend
from keras_core.backend.config import epsilon
from keras_core.backend.config import floatx
from keras_core.backend.config import image_data_format
from keras_core.backend.config import set_epsilon
from keras_core.backend.config import set_floatx
from keras_core.backend.config import set_image_data_format
from keras_core.backend.config import standardize_data_format
from keras_core.utils.io_utils import print_msg
# Import backend functions.
# `backend()` (from keras_core.backend.config) returns the configured backend
# name. The matching backend package is star-imported at module level so that
# its public functions become available directly from `keras_core.backend`.
if backend() == "tensorflow":
    print_msg("Using TensorFlow backend.")
    from keras_core.backend.tensorflow import *  # noqa: F403
elif backend() == "jax":
    print_msg("Using JAX backend.")
    from keras_core.backend.jax import *  # noqa: F403
else:
    # Only the backends handled above can be imported. Name the supported
    # options so a misconfigured backend is easy to diagnose.
    raise ValueError(
        f"Unable to import backend: {backend()}. "
        "Supported backends are: 'tensorflow', 'jax'."
    )