keras/keras_core/losses/__init__.py

126 lines
4.0 KiB
Python
Raw Normal View History

2023-04-17 21:55:17 +00:00
from keras_core.api_export import keras_core_export
2023-04-16 19:21:29 +00:00
from keras_core.losses.loss import Loss
2023-04-25 04:21:57 +00:00
from keras_core.losses.losses import CategoricalHinge
from keras_core.losses.losses import Hinge
2023-04-16 19:21:29 +00:00
from keras_core.losses.losses import LossFunctionWrapper
2023-04-25 04:21:57 +00:00
from keras_core.losses.losses import MeanAbsoluteError
from keras_core.losses.losses import MeanAbsolutePercentageError
2023-04-17 21:55:17 +00:00
from keras_core.losses.losses import MeanSquaredError
2023-04-25 04:21:57 +00:00
from keras_core.losses.losses import MeanSquaredLogarithmicError
from keras_core.losses.losses import SquaredHinge
from keras_core.losses.losses import categorical_hinge
from keras_core.losses.losses import hinge
from keras_core.losses.losses import mean_absolute_error
from keras_core.losses.losses import mean_absolute_percentage_error
from keras_core.losses.losses import mean_squared_error
from keras_core.losses.losses import mean_squared_logarithmic_error
from keras_core.losses.losses import squared_hinge
from keras_core.saving import serialization_lib
2023-04-17 21:55:17 +00:00
2023-04-25 04:21:57 +00:00
# Every loss class and loss function this module exposes for lookup by name.
# `ALL_OBJECTS_DICT` is derived from this set, keyed on each object's
# `__name__`.
ALL_OBJECTS = {
    # Base class and the adapter that wraps plain loss functions.
    Loss,
    LossFunctionWrapper,
    # `Loss` subclasses.
    CategoricalHinge,
    Hinge,
    MeanAbsoluteError,
    MeanAbsolutePercentageError,
    MeanSquaredError,
    MeanSquaredLogarithmicError,
    SquaredHinge,
    # Functional forms of the losses above.
    categorical_hinge,
    hinge,
    mean_absolute_error,
    mean_absolute_percentage_error,
    mean_squared_error,
    mean_squared_logarithmic_error,
    squared_hinge,
}
2023-04-17 21:55:17 +00:00
2023-04-25 04:21:57 +00:00
# Name -> object registry consulted by `deserialize` (and therefore `get`)
# when resolving string identifiers. Each entry in `ALL_OBJECTS` is keyed on
# its own `__name__`, so both "MeanSquaredError" and "mean_squared_error"
# resolve.
ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
# Short-hand aliases accepted as loss identifiers (e.g. `loss="mse"`). Every
# imported regression loss gets both a lower- and an upper-case alias so that
# "mape"/"msle" resolve just like "mae"/"mse" already do.
ALL_OBJECTS_DICT.update(
    {
        "mae": mean_absolute_error,
        "MAE": mean_absolute_error,
        "mape": mean_absolute_percentage_error,
        "MAPE": mean_absolute_percentage_error,
        "mse": mean_squared_error,
        "MSE": mean_squared_error,
        "msle": mean_squared_logarithmic_error,
        "MSLE": mean_squared_logarithmic_error,
    }
)
2023-04-25 04:21:57 +00:00
@keras_core_export("keras_core.losses.serialize")
def serialize(loss):
    """Converts a loss into its serializable configuration.

    This is a thin wrapper around
    `serialization_lib.serialize_keras_object`.

    Args:
        loss: A Keras `Loss` instance or a loss function.

    Returns:
        Loss configuration dictionary.
    """
    loss_config = serialization_lib.serialize_keras_object(loss)
    return loss_config
@keras_core_export("keras_core.losses.deserialize")
def deserialize(name, custom_objects=None):
    """Rebuilds a loss class/function instance from its serialized form.

    Names are looked up in this module's `ALL_OBJECTS_DICT` (loss classes,
    loss functions and their short-hand aliases), with `custom_objects`
    passed through to `serialization_lib.deserialize_keras_object` as well.

    Args:
        name: Loss configuration (e.g. as produced by `serialize`), or the
            string name of a loss.
        custom_objects: Optional dictionary mapping names (strings) to custom
            objects (classes and functions) to be considered during
            deserialization.

    Returns:
        A Keras `Loss` instance or a loss function.
    """
    lookup_tables = {
        "module_objects": ALL_OBJECTS_DICT,
        "custom_objects": custom_objects,
    }
    return serialization_lib.deserialize_keras_object(name, **lookup_tables)
2023-04-17 21:55:17 +00:00
@keras_core_export("keras_core.losses.get")
def get(identifier):
    """Retrieves a Keras loss as a `function`/`Loss` class instance.

    The `identifier` may be the string name of a loss function or `Loss`
    class (including short-hand aliases such as `"mse"`). Strings and config
    dictionaries are resolved through `deserialize`, i.e. against this
    module's `ALL_OBJECTS_DICT`.

    >>> loss = losses.get("mean_squared_error")
    >>> type(loss)
    <class 'function'>
    >>> loss = losses.get("MeanSquaredError")
    >>> type(loss)
    <class '...MeanSquaredError'>

    You can also specify the `config` of the loss by passing a dict
    containing `class_name` and `config` as the identifier. Note that the
    `class_name` must map to a `Loss` class.

    >>> identifier = {"class_name": "MeanSquaredError",
    ...               "config": {"name": "my_mse"}}
    >>> loss = losses.get(identifier)
    >>> type(loss)
    <class '...MeanSquaredError'>

    Args:
        identifier: A loss identifier. One of `None`, the string name of a
            loss function/class, a loss configuration dictionary, a loss
            function, or a `Loss` class or instance.

    Returns:
        A Keras loss as a `function`/`Loss` class instance, or `None` if
        `identifier` is `None`.

    Raises:
        ValueError: If `identifier` is neither `None`, a string, a dict, nor
            a callable.
    """
    if identifier is None:
        return None
    if isinstance(identifier, (str, dict)):
        obj = deserialize(identifier)
    elif callable(identifier):
        obj = identifier
    else:
        raise ValueError(
            f"Could not interpret loss function identifier: {identifier}"
        )
    # `ALL_OBJECTS_DICT` maps class names to the classes themselves, so a bare
    # class name may resolve to a class rather than an instance; a `Loss`
    # subclass may also be passed in directly. Instantiate it with default
    # arguments so callers always receive a usable function or `Loss`
    # instance, as documented above.
    if isinstance(obj, type):
        obj = obj()
    return obj