94 lines
2.8 KiB
Python
94 lines
2.8 KiB
Python
from keras_core.api_export import keras_core_export
|
|
from keras_core.optimizers.adadelta import Adadelta
|
|
from keras_core.optimizers.adafactor import Adafactor
|
|
from keras_core.optimizers.adagrad import Adagrad
|
|
from keras_core.optimizers.adam import Adam
|
|
from keras_core.optimizers.adamax import Adamax
|
|
from keras_core.optimizers.adamw import AdamW
|
|
from keras_core.optimizers.ftrl import Ftrl
|
|
from keras_core.optimizers.nadam import Nadam
|
|
from keras_core.optimizers.optimizer import Optimizer
|
|
from keras_core.optimizers.rmsprop import RMSprop
|
|
from keras_core.optimizers.sgd import SGD
|
|
from keras_core.saving import serialization_lib
|
|
|
|
# Registry of the built-in optimizer classes exposed by this package.
ALL_OBJECTS = {
    Optimizer,
    Adam,
    SGD,
    RMSprop,
    Adadelta,
    AdamW,
    Adagrad,
    Adamax,
    Adafactor,
    Nadam,
    Ftrl,
}

# Lookup table keyed by the lower-cased class name (e.g. "rmsprop"), which
# is what allows built-in optimizers to be resolved case-insensitively.
ALL_OBJECTS_DICT = {
    optimizer_cls.__name__.lower(): optimizer_cls
    for optimizer_cls in ALL_OBJECTS
}
|
|
|
|
|
|
@keras_core_export("keras_core.optimizers.serialize")
def serialize(optimizer):
    """Serializes an optimizer into its configuration.

    This is a thin wrapper that delegates to
    `serialization_lib.serialize_keras_object`.

    Args:
        optimizer: An `Optimizer` instance to serialize.

    Returns:
        Python dict which contains the configuration of the optimizer.
    """
    optimizer_config = serialization_lib.serialize_keras_object(optimizer)
    return optimizer_config
|
|
|
|
|
|
@keras_core_export("keras_core.optimizers.deserialize")
def deserialize(config, custom_objects=None):
    """Returns a Keras optimizer object via its configuration.

    The `"class_name"` entry is matched case-insensitively against the
    built-in optimizers registered in `ALL_OBJECTS_DICT`. The `config`
    dictionary passed by the caller is never modified.

    Args:
        config: Optimizer configuration dictionary. Must contain a
            `"class_name"` entry holding a string.
        custom_objects: Optional dictionary mapping names (strings) to custom
            objects (classes and functions) to be considered during
            deserialization.

    Returns:
        A Keras Optimizer instance.
    """
    # Make deserialization case-insensitive for built-in optimizers.
    lowered_name = config["class_name"].lower()
    if lowered_name in ALL_OBJECTS_DICT:
        # Normalize the name on a shallow copy: writing it back into
        # `config` would silently alter the dictionary owned by the caller.
        config = dict(config)
        config["class_name"] = lowered_name

    return serialization_lib.deserialize_keras_object(
        config,
        module_objects=ALL_OBJECTS_DICT,
        custom_objects=custom_objects,
    )
|
|
|
|
|
|
@keras_core_export("keras_core.optimizers.get")
def get(identifier):
    """Resolves an identifier into a Keras Optimizer instance.

    Args:
        identifier: Optimizer identifier, one of:
            - `None`: `None` is returned.
            - String: name of an optimizer; it is deserialized with an
                empty config.
            - Dictionary: configuration dictionary, passed to
                `deserialize()`.
            - Keras Optimizer instance (it will be returned unchanged).

    Returns:
        A Keras Optimizer instance, or `None` if `identifier` is `None`.

    Raises:
        ValueError: If the resolved object is not an `Optimizer` instance.
    """
    if identifier is None:
        return None

    # Anything that is neither a dict nor a string is taken as-is and
    # validated below.
    resolved = identifier
    if isinstance(identifier, dict):
        resolved = deserialize(identifier)
    elif isinstance(identifier, str):
        resolved = deserialize({"class_name": identifier, "config": {}})

    if not isinstance(resolved, Optimizer):
        raise ValueError(
            f"Could not interpret optimizer identifier: {identifier}"
        )
    return resolved
|