2023-04-17 21:55:17 +00:00
|
|
|
from keras_core.api_export import keras_core_export
|
2023-04-16 19:21:29 +00:00
|
|
|
from keras_core.losses.loss import Loss
|
2023-04-25 04:21:57 +00:00
|
|
|
from keras_core.losses.losses import CategoricalHinge
|
|
|
|
from keras_core.losses.losses import Hinge
|
2023-04-16 19:21:29 +00:00
|
|
|
from keras_core.losses.losses import LossFunctionWrapper
|
2023-04-25 04:21:57 +00:00
|
|
|
from keras_core.losses.losses import MeanAbsoluteError
|
|
|
|
from keras_core.losses.losses import MeanAbsolutePercentageError
|
2023-04-17 21:55:17 +00:00
|
|
|
from keras_core.losses.losses import MeanSquaredError
|
2023-04-25 04:21:57 +00:00
|
|
|
from keras_core.losses.losses import MeanSquaredLogarithmicError
|
|
|
|
from keras_core.losses.losses import SquaredHinge
|
|
|
|
from keras_core.losses.losses import categorical_hinge
|
|
|
|
from keras_core.losses.losses import hinge
|
|
|
|
from keras_core.losses.losses import mean_absolute_error
|
|
|
|
from keras_core.losses.losses import mean_absolute_percentage_error
|
|
|
|
from keras_core.losses.losses import mean_squared_error
|
|
|
|
from keras_core.losses.losses import mean_squared_logarithmic_error
|
|
|
|
from keras_core.losses.losses import squared_hinge
|
|
|
|
from keras_core.saving import serialization_lib
|
2023-04-17 21:55:17 +00:00
|
|
|
|
2023-04-25 04:21:57 +00:00
|
|
|
# Every built-in loss object that can be looked up by name. This is a set, so
# iteration order is unspecified; consumers key entries by `__name__`.
ALL_OBJECTS = {
    # Base `Loss` class and `LossFunctionWrapper`.
    Loss,
    LossFunctionWrapper,
    # Loss classes (CamelCase names).
    CategoricalHinge,
    Hinge,
    MeanAbsoluteError,
    MeanAbsolutePercentageError,
    MeanSquaredError,
    MeanSquaredLogarithmicError,
    SquaredHinge,
    # Loss functions (snake_case names).
    categorical_hinge,
    hinge,
    mean_absolute_error,
    mean_absolute_percentage_error,
    mean_squared_error,
    mean_squared_logarithmic_error,
    squared_hinge,
}
|
2023-04-17 21:55:17 +00:00
|
|
|
|
2023-04-25 04:21:57 +00:00
|
|
|
# Map each registered object's `__name__` (e.g. "MeanSquaredError" or
# "mean_squared_error") to the object itself. `deserialize` passes this table
# as `module_objects`, so string identifiers resolve against it.
ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
# Short-hand aliases for the functional losses, accepted in both lower- and
# upper-case spellings. Every registered metric-style loss gets an alias,
# not just MAE/MSE.
ALL_OBJECTS_DICT.update(
    {
        "mae": mean_absolute_error,
        "MAE": mean_absolute_error,
        "mape": mean_absolute_percentage_error,
        "MAPE": mean_absolute_percentage_error,
        "mse": mean_squared_error,
        "MSE": mean_squared_error,
        "msle": mean_squared_logarithmic_error,
        "MSLE": mean_squared_logarithmic_error,
    }
)
|
2023-04-25 04:21:57 +00:00
|
|
|
|
|
|
|
|
|
|
|
@keras_core_export("keras_core.losses.serialize")
def serialize(loss):
    """Converts a loss into its serializable configuration.

    Args:
        loss: A Keras `Loss` instance or a loss function.

    Returns:
        Loss configuration dictionary, as produced by
        `serialization_lib.serialize_keras_object`.
    """
    config = serialization_lib.serialize_keras_object(loss)
    return config
|
|
|
|
|
|
|
|
|
|
|
|
@keras_core_export("keras_core.losses.deserialize")
def deserialize(name, custom_objects=None):
    """Rebuilds a loss class/function instance from its serialized form.

    Built-in names are resolved against `ALL_OBJECTS_DICT` (registered class
    and function names plus short-hand aliases); anything in
    `custom_objects` is also considered.

    Args:
        name: Loss configuration. Within this module, `get` passes either a
            string name or a configuration dictionary.
        custom_objects: Optional dictionary mapping names (strings) to custom
            objects (classes and functions) to be considered during
            deserialization.

    Returns:
        A Keras `Loss` instance or a loss function.
    """
    lookup_table = ALL_OBJECTS_DICT
    return serialization_lib.deserialize_keras_object(
        name,
        custom_objects=custom_objects,
        module_objects=lookup_table,
    )
|
2023-04-22 06:16:51 +00:00
|
|
|
|
|
|
|
|
2023-04-17 21:55:17 +00:00
|
|
|
@keras_core_export("keras_core.losses.get")
def get(identifier):
    """Retrieves a Keras loss as a `function`/`Loss` class instance.

    The `identifier` may be the string name of a loss function or `Loss`
    class.

    >>> loss = losses.get("mean_squared_error")
    >>> type(loss)
    <class 'function'>
    >>> loss = losses.get("MeanSquaredError")
    >>> type(loss)
    <class '...MeanSquaredError'>

    You can also specify `config` of the loss to this function by passing a
    dict containing `class_name` and `config` as an identifier. Also note
    that the `class_name` must map to a `Loss` class.

    >>> identifier = {"class_name": "MeanSquaredError",
    ...               "config": {"name": "mse_loss"}}
    >>> loss = losses.get(identifier)
    >>> type(loss)
    <class '...MeanSquaredError'>

    Args:
        identifier: A loss identifier. One of None or string name of a loss
            function/class or loss configuration dictionary or a loss
            function or a loss class instance.

    Returns:
        A Keras loss as a `function`/ `Loss` class instance, or `None` if
        `identifier` is `None`.

    Raises:
        ValueError: If `identifier` cannot be interpreted as a loss, including
            when a string/dict identifier does not resolve to a callable.
    """
    if identifier is None:
        return None
    if isinstance(identifier, (str, dict)):
        obj = deserialize(identifier)
        # NOTE(review): depending on `serialization_lib`, an unrecognized
        # name may come back unresolved instead of raising. Only accept a
        # usable (callable) loss so callers get a clear error right here.
        if callable(obj):
            return obj
    elif callable(identifier):
        # Already a loss function or `Loss` instance; pass it through.
        return identifier
    raise ValueError(
        f"Could not interpret loss function identifier: {identifier}"
    )
|