keras/keras_core/activations/__init__.py

98 lines
3.3 KiB
Python
Raw Normal View History

2023-04-22 21:46:39 +00:00
import types
2023-04-21 22:01:17 +00:00
2023-04-23 01:03:15 +00:00
from keras_core.activations.activations import elu
from keras_core.activations.activations import exponential
from keras_core.activations.activations import gelu
from keras_core.activations.activations import hard_sigmoid
2023-04-22 21:46:39 +00:00
from keras_core.activations.activations import leaky_relu
2023-04-23 01:03:15 +00:00
from keras_core.activations.activations import linear
from keras_core.activations.activations import log_softmax
from keras_core.activations.activations import mish
from keras_core.activations.activations import relu
2023-04-22 21:46:39 +00:00
from keras_core.activations.activations import relu6
from keras_core.activations.activations import selu
2023-04-23 01:03:15 +00:00
from keras_core.activations.activations import sigmoid
from keras_core.activations.activations import silu
from keras_core.activations.activations import softmax
2023-04-22 21:46:39 +00:00
from keras_core.activations.activations import softplus
from keras_core.activations.activations import softsign
from keras_core.activations.activations import tanh
2023-04-23 01:03:15 +00:00
from keras_core.api_export import keras_core_export
from keras_core.saving import object_registration
from keras_core.saving import serialization_lib
2023-04-21 22:01:17 +00:00
2023-04-22 21:46:39 +00:00
# The built-in activation functions exported by this package, held as the
# function objects themselves (a set, so membership is by identity/hash).
ALL_OBJECTS = {
    relu,
    leaky_relu,
    relu6,
    softmax,
    elu,
    selu,
    softplus,
    softsign,
    silu,
    gelu,
    tanh,
    sigmoid,
    exponential,
    hard_sigmoid,
    linear,
    mish,
    log_softmax,
}

# Lookup table from each builtin's `__name__` to the function object; this is
# what `deserialize` passes as `module_objects`. A comprehension is used so
# the iteration variable does not leak into module globals.
ALL_OBJECTS_DICT = {
    activation_fn.__name__: activation_fn for activation_fn in ALL_OBJECTS
}
2023-04-12 21:27:30 +00:00
2023-04-22 21:46:39 +00:00
@keras_core_export("keras_core.activations.serialize")
def serialize(activation):
    """Serialize an activation function or activation object.

    Args:
        activation: The activation to serialize. It may be one of the
            built-in activation functions, a custom plain Python function,
            or a non-function callable object.

    Returns:
        - For a non-function object: the full config dict returned by
          `serialization_lib.serialize_keras_object`.
        - For a plain function that is *exactly* one of the built-in
          activations: its bare name (a string).
        - For any other plain function whose config is a string name: the
          full config dict, with `"config"` replaced by the function's
          registered name.
        - Otherwise (a function whose config is not a string): the
          `"config"` entry of the serialized dict, as-is.

    Raises:
        ValueError: If the serialized dict has no `"config"` entry.
    """
    fn_config = serialization_lib.serialize_keras_object(activation)
    if "config" not in fn_config:
        raise ValueError(
            f"Unknown activation function '{activation}' cannot be "
            "serialized due to invalid function name. Make sure to use "
            "an activation name that matches the references defined in "
            "activations.py or use "
            "`@keras_core.saving.register_keras_serializable()` "
            "to register any custom activations. "
            f"config={fn_config}"
        )
    if not isinstance(activation, types.FunctionType):
        # Case for additional custom activations represented by objects
        return fn_config
    name = fn_config["config"]
    # Only the exact built-in function object may be reduced to a bare name.
    # `deserialize` looks builtins up through ALL_OBJECTS_DICT (passed as
    # `module_objects`). A custom function that merely shares a builtin's name,
    # or that shares a name with another module-level global such as `get`,
    # must keep its full config, or it would not round-trip correctly.
    if (
        isinstance(name, str)
        and ALL_OBJECTS_DICT.get(name) is not activation
    ):
        # Case for custom activation functions from external modules
        fn_config["config"] = object_registration.get_registered_name(
            activation
        )
        return fn_config
    # Case for keras_core.activations builtins (simply return name)
    return fn_config["config"]
2023-04-22 21:46:39 +00:00
@keras_core_export("keras_core.activations.deserialize")
def deserialize(config, custom_objects=None):
    """Rebuild an activation function from its serialized form.

    Args:
        config: The serialized activation (a bare name or a config dict),
            handed straight to `serialization_lib.deserialize_keras_object`.
        custom_objects: Optional custom objects, forwarded unchanged to the
            deserializer.

    Returns:
        The result of `serialization_lib.deserialize_keras_object`, called
        with `ALL_OBJECTS_DICT` (built-in activations keyed by name) as its
        `module_objects`.
    """
    builtin_activations = ALL_OBJECTS_DICT
    return serialization_lib.deserialize_keras_object(
        config,
        custom_objects=custom_objects,
        module_objects=builtin_activations,
    )
@keras_core_export("keras_core.activations.get")
def get(identifier):
    """Resolve an identifier to a Keras activation function.

    Args:
        identifier: `None` (resolved to `linear`), a string or dict that is
            passed to `deserialize`, or any callable (returned unchanged).

    Returns:
        The resolved activation.

    Raises:
        TypeError: If `identifier` is neither `None`, a string, a dict, nor
            callable.
    """
    if identifier is None:
        return linear
    # Strings and dicts are checked before `callable` so that they are always
    # routed through the deserializer.
    if isinstance(identifier, (str, dict)):
        return deserialize(identifier)
    if not callable(identifier):
        raise TypeError(
            f"Could not interpret activation function identifier: {identifier}"
        )
    return identifier