diff --git a/keras_core/losses/__init__.py b/keras_core/losses/__init__.py index ae5aa512c..c8709c7fe 100644 --- a/keras_core/losses/__init__.py +++ b/keras_core/losses/__init__.py @@ -1,11 +1,75 @@ from keras_core.api_export import keras_core_export from keras_core.losses.loss import Loss +from keras_core.losses.losses import CategoricalHinge +from keras_core.losses.losses import Hinge from keras_core.losses.losses import LossFunctionWrapper +from keras_core.losses.losses import MeanAbsoluteError +from keras_core.losses.losses import MeanAbsolutePercentageError from keras_core.losses.losses import MeanSquaredError +from keras_core.losses.losses import MeanSquaredLogarithmicError +from keras_core.losses.losses import SquaredHinge +from keras_core.losses.losses import categorical_hinge +from keras_core.losses.losses import hinge +from keras_core.losses.losses import mean_absolute_error +from keras_core.losses.losses import mean_absolute_percentage_error +from keras_core.losses.losses import mean_squared_error +from keras_core.losses.losses import mean_squared_logarithmic_error +from keras_core.losses.losses import squared_hinge +from keras_core.saving import serialization_lib + +ALL_OBJECTS = { + Loss, + LossFunctionWrapper, + MeanSquaredError, + MeanAbsoluteError, + MeanAbsolutePercentageError, + MeanSquaredLogarithmicError, + Hinge, + SquaredHinge, + CategoricalHinge, + mean_squared_error, + mean_absolute_error, + mean_absolute_percentage_error, + mean_squared_logarithmic_error, + hinge, + squared_hinge, + categorical_hinge, +} + +ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS} -def deserialize(obj): - raise NotImplementedError +@keras_core_export("keras_core.losses.serialize") +def serialize(loss): + """Serializes loss function or `Loss` instance. + + Args: + loss: A Keras `Loss` instance or a loss function. + + Returns: + Loss configuration dictionary. + """ + return serialization_lib.serialize_keras_object(loss) + + +@keras_core_export("keras_core.losses.deserialize") +def deserialize(name, custom_objects=None): + """Deserializes a serialized loss class/function instance. + + Args: + name: Loss configuration. + custom_objects: Optional dictionary mapping names (strings) to custom + objects (classes and functions) to be considered during + deserialization. + + Returns: + A Keras `Loss` instance or a loss function. + """ + return serialization_lib.deserialize_keras_object( + name, + module_objects=ALL_OBJECTS_DICT, + custom_objects=custom_objects, + ) @keras_core_export("keras_core.losses.get") diff --git a/keras_core/losses/loss.py b/keras_core/losses/loss.py index 9938dab91..5a6f86cbf 100644 --- a/keras_core/losses/loss.py +++ b/keras_core/losses/loss.py @@ -45,7 +45,7 @@ class Loss: raise NotImplementedError def get_config(self): - return {"name": self.name} + return {"name": self.name, "reduction": self.reduction} @classmethod def from_config(cls, config): diff --git a/keras_core/losses/losses.py b/keras_core/losses/losses.py index ed4ea84d1..e91b24765 100644 --- a/keras_core/losses/losses.py +++ b/keras_core/losses/losses.py @@ -3,6 +3,7 @@ from keras_core import operations as ops from keras_core.api_export import keras_core_export from keras_core.losses.loss import Loss from keras_core.losses.loss import squeeze_to_same_rank +from keras_core.saving import serialization_lib class LossFunctionWrapper(Loss): @@ -18,11 +19,16 @@ class LossFunctionWrapper(Loss): return self.fn(y_true, y_pred, **self._fn_kwargs) def get_config(self): - raise NotImplementedError + base_config = super().get_config() + config = {"fn": serialization_lib.serialize_keras_object(self.fn)} + config.update(serialization_lib.serialize_keras_object(self._fn_kwargs)) + return {**base_config, **config} @classmethod - def from_config(clf, config): - raise NotImplementedError + def from_config(cls, config): + if "fn" in config: + config = serialization_lib.deserialize_keras_object(config) + return cls(**config) @keras_core_export("keras_core.losses.MeanSquaredError") @@ -47,6 +53,9 @@ class MeanSquaredError(LossFunctionWrapper): ): super().__init__(mean_squared_error, reduction=reduction, name=name) + def get_config(self): + return Loss.get_config(self) + @keras_core_export("keras_core.losses.MeanAbsoluteError") class MeanAbsoluteError(LossFunctionWrapper): @@ -70,6 +79,9 @@ class MeanAbsoluteError(LossFunctionWrapper): ): super().__init__(mean_absolute_error, reduction=reduction, name=name) + def get_config(self): + return Loss.get_config(self) + @keras_core_export("keras_core.losses.MeanAbsolutePercentageError") class MeanAbsolutePercentageError(LossFunctionWrapper): @@ -97,6 +109,9 @@ class MeanAbsolutePercentageError(LossFunctionWrapper): mean_absolute_percentage_error, reduction=reduction, name=name ) + def get_config(self): + return Loss.get_config(self) + @keras_core_export("keras_core.losses.MeanSquaredLogarithmicError") class MeanSquaredLogarithmicError(LossFunctionWrapper): @@ -124,6 +139,9 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper): mean_squared_logarithmic_error, reduction=reduction, name=name ) + def get_config(self): + return Loss.get_config(self) + @keras_core_export("keras_core.losses.Hinge") class Hinge(LossFunctionWrapper): @@ -148,6 +166,9 @@ class Hinge(LossFunctionWrapper): def __init__(self, reduction="sum_over_batch_size", name="hinge"): super().__init__(hinge, reduction=reduction, name=name) + def get_config(self): + return Loss.get_config(self) + @keras_core_export("keras_core.losses.SquaredHinge") class SquaredHinge(LossFunctionWrapper): @@ -172,6 +193,9 @@ class SquaredHinge(LossFunctionWrapper): def __init__(self, reduction="sum_over_batch_size", name="squared_hinge"): super().__init__(squared_hinge, reduction=reduction, name=name) + def get_config(self): + return Loss.get_config(self) + @keras_core_export("keras_core.losses.CategoricalHinge") class CategoricalHinge(LossFunctionWrapper): @@ -198,6 +222,9 @@ class CategoricalHinge(LossFunctionWrapper): ): super().__init__(categorical_hinge, reduction=reduction, name=name) + def get_config(self): + return Loss.get_config(self) + def convert_binary_labels_to_hinge(y_true): """Converts binary labels into -1/1 for hinge loss/metric calculation.""" diff --git a/keras_core/testing/test_case.py b/keras_core/testing/test_case.py index f60939ab7..80db96175 100644 --- a/keras_core/testing/test_case.py +++ b/keras_core/testing/test_case.py @@ -1,4 +1,6 @@ import json +import os +import tempfile import unittest import numpy as np @@ -8,6 +10,11 @@ from tensorflow import nest class TestCase(unittest.TestCase): maxDiff = None + def get_temp_dir(self): + temp_dir = tempfile.mkdtemp() + self.addCleanup(lambda: os.rmdir(temp_dir)) + return temp_dir + def assertAllClose(self, x1, x2, atol=1e-7, rtol=1e-7): np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol)