keras/keras_core/initializers/__init__.py

115 lines
3.8 KiB
Python
Raw Normal View History

2023-04-21 17:00:32 +00:00
import inspect
from keras_core.api_export import keras_core_export
from keras_core.initializers.constant_initializers import Constant
2023-04-19 19:52:58 +00:00
from keras_core.initializers.constant_initializers import Ones
from keras_core.initializers.constant_initializers import Zeros
2023-04-09 19:21:45 +00:00
from keras_core.initializers.initializer import Initializer
2023-04-19 19:52:58 +00:00
from keras_core.initializers.random_initializers import GlorotNormal
from keras_core.initializers.random_initializers import GlorotUniform
from keras_core.initializers.random_initializers import HeNormal
from keras_core.initializers.random_initializers import HeUniform
from keras_core.initializers.random_initializers import LecunNormal
from keras_core.initializers.random_initializers import LecunUniform
2023-05-02 16:00:30 +00:00
from keras_core.initializers.random_initializers import OrthogonalInitializer
2023-04-19 19:52:58 +00:00
from keras_core.initializers.random_initializers import RandomNormal
from keras_core.initializers.random_initializers import RandomUniform
2023-04-23 03:20:56 +00:00
from keras_core.initializers.random_initializers import TruncatedNormal
2023-04-19 19:52:58 +00:00
from keras_core.initializers.random_initializers import VarianceScaling
2023-04-21 17:00:32 +00:00
from keras_core.saving import serialization_lib
from keras_core.utils.naming import to_snake_case
# Every initializer class exported by this package that can be resolved by
# name during deserialization.
ALL_OBJECTS = {
    Initializer,
    Constant,
    Ones,
    Zeros,
    GlorotNormal,
    GlorotUniform,
    HeNormal,
    HeUniform,
    LecunNormal,
    LecunUniform,
    RandomNormal,
    TruncatedNormal,
    RandomUniform,
    VarianceScaling,
    OrthogonalInitializer,
}

# Lookup table from string identifier to initializer class. Each class is
# reachable by its class name and by the snake_case form of that name; a few
# short aliases are added last.
ALL_OBJECTS_DICT = {
    **{cls.__name__: cls for cls in ALL_OBJECTS},
    **{to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS},
    # Aliases
    "uniform": RandomUniform,
    "normal": RandomNormal,
    "orthogonal": OrthogonalInitializer,
}
2023-04-21 17:00:32 +00:00
2023-04-19 19:52:58 +00:00
2023-04-21 17:00:32 +00:00
@keras_core_export("keras_core.initializers.serialize")
def serialize(initializer):
    """Returns the initializer configuration as a Python dict.

    Args:
        initializer: The initializer object to serialize.

    Returns:
        The result of `serialization_lib.serialize_keras_object` applied to
        `initializer`.
    """
    serialized = serialization_lib.serialize_keras_object(initializer)
    return serialized
2023-04-19 19:52:58 +00:00
2023-04-21 17:00:32 +00:00
@keras_core_export("keras_core.initializers.deserialize")
def deserialize(config, custom_objects=None):
    """Returns a Keras initializer object via its configuration.

    Args:
        config: Initializer configuration to deserialize.
        custom_objects: Optional custom objects forwarded to
            `serialization_lib.deserialize_keras_object`. Defaults to `None`.

    Returns:
        The object produced by `serialization_lib.deserialize_keras_object`,
        with names looked up in `ALL_OBJECTS_DICT`.
    """
    # Built-in initializers are looked up in this module's registry;
    # `custom_objects` is passed through untouched.
    lookup_kwargs = {
        "module_objects": ALL_OBJECTS_DICT,
        "custom_objects": custom_objects,
    }
    return serialization_lib.deserialize_keras_object(config, **lookup_kwargs)
@keras_core_export("keras_core.initializers.get")
def get(identifier):
    """Retrieves a Keras initializer object via an identifier.

    The `identifier` may be the string name of an initializer class
    (case-sensitive). Strings are resolved through `deserialize`, which looks
    names up in `ALL_OBJECTS_DICT`.

    >>> identifier = 'Ones'
    >>> keras_core.initializers.get(identifier)
    <...keras_core.initializers.constant_initializers.Ones...>

    You can also specify the `config` of the initializer by passing a dict
    containing `class_name` and `config` as the identifier. Note that the
    `class_name` must map to an `Initializer` class.

    >>> cfg = {'class_name': 'Ones', 'config': {}}
    >>> keras_core.initializers.get(cfg)
    <...keras_core.initializers.constant_initializers.Ones...>

    If the `identifier` is a class, this function returns a new instance of
    the class created by calling its constructor with no arguments. Any other
    callable is returned unchanged.

    Args:
        identifier: String, dict, class or callable identifying the
            initializer, or `None`.

    Returns:
        Initializer instance based on the input identifier, or `None` if
        `identifier` is `None`.

    Raises:
        ValueError: If the identifier does not resolve to a callable object.
    """
    if identifier is None:
        return None

    # Resolve into a separate name so the error message below reports what
    # the caller actually passed, not an intermediate deserialized object.
    if isinstance(identifier, dict):
        obj = deserialize(identifier)
    elif isinstance(identifier, str):
        config = {"class_name": str(identifier), "config": {}}
        obj = deserialize(config)
    else:
        obj = identifier

    if callable(obj):
        # A class (rather than an instance or function) is instantiated with
        # its default constructor arguments.
        if inspect.isclass(obj):
            obj = obj()
        return obj

    raise ValueError(
        f"Could not interpret initializer object identifier: {identifier}"
    )