Progress on serialization

This commit is contained in:
Francois Chollet 2023-04-24 21:21:57 -07:00
parent 2c5863e8fc
commit 3788a99582
4 changed files with 104 additions and 6 deletions

@@ -1,11 +1,75 @@
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.losses.loss import Loss from keras_core.losses.loss import Loss
from keras_core.losses.losses import CategoricalHinge
from keras_core.losses.losses import Hinge
from keras_core.losses.losses import LossFunctionWrapper from keras_core.losses.losses import LossFunctionWrapper
from keras_core.losses.losses import MeanAbsoluteError
from keras_core.losses.losses import MeanAbsolutePercentageError
from keras_core.losses.losses import MeanSquaredError from keras_core.losses.losses import MeanSquaredError
from keras_core.losses.losses import MeanSquaredLogarithmicError
from keras_core.losses.losses import SquaredHinge
from keras_core.losses.losses import categorical_hinge
from keras_core.losses.losses import hinge
from keras_core.losses.losses import mean_absolute_error
from keras_core.losses.losses import mean_absolute_percentage_error
from keras_core.losses.losses import mean_squared_error
from keras_core.losses.losses import mean_squared_logarithmic_error
from keras_core.losses.losses import squared_hinge
from keras_core.saving import serialization_lib
ALL_OBJECTS = {
Loss,
LossFunctionWrapper,
MeanSquaredError,
MeanAbsoluteError,
MeanAbsolutePercentageError,
MeanSquaredLogarithmicError,
Hinge,
SquaredHinge,
CategoricalHinge,
mean_squared_error,
mean_absolute_error,
mean_absolute_percentage_error,
mean_squared_logarithmic_error,
hinge,
squared_hinge,
categorical_hinge,
}
ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
def deserialize(obj): @keras_core_export("keras_core.losses.serialize")
raise NotImplementedError def serialize(loss):
"""Serializes loss function or `Loss` instance.
Args:
loss: A Keras `Loss` instance or a loss function.
Returns:
Loss configuration dictionary.
"""
return serialization_lib.serialize_keras_object(loss)
@keras_core_export("keras_core.losses.deserialize")
def deserialize(name, custom_objects=None):
"""Deserializes a serialized loss class/function instance.
Args:
name: Loss configuration.
custom_objects: Optional dictionary mapping names (strings) to custom
objects (classes and functions) to be considered during
deserialization.
Returns:
A Keras `Loss` instance or a loss function.
"""
return serialization_lib.deserialize_keras_object(
name,
module_objects=ALL_OBJECTS_DICT,
custom_objects=custom_objects,
)
@keras_core_export("keras_core.losses.get") @keras_core_export("keras_core.losses.get")

@@ -45,7 +45,7 @@ class Loss:
raise NotImplementedError raise NotImplementedError
def get_config(self): def get_config(self):
return {"name": self.name} return {"name": self.name, "reduction": self.reduction}
@classmethod @classmethod
def from_config(cls, config): def from_config(cls, config):

@@ -3,6 +3,7 @@ from keras_core import operations as ops
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.losses.loss import Loss from keras_core.losses.loss import Loss
from keras_core.losses.loss import squeeze_to_same_rank from keras_core.losses.loss import squeeze_to_same_rank
from keras_core.saving import serialization_lib
class LossFunctionWrapper(Loss): class LossFunctionWrapper(Loss):
@@ -18,11 +19,16 @@ class LossFunctionWrapper(Loss):
return self.fn(y_true, y_pred, **self._fn_kwargs) return self.fn(y_true, y_pred, **self._fn_kwargs)
def get_config(self): def get_config(self):
raise NotImplementedError base_config = super().get_config()
config = {"fn": serialization_lib.serialize_keras_object(self.fn)}
config.update(serialization_lib.serialize_keras_object(self._fn_kwargs))
return {**base_config, **config}
@classmethod @classmethod
def from_config(clf, config): def from_config(cls, config):
raise NotImplementedError if "fn" in config:
config = serialization_lib.deserialize_keras_object(config)
return cls(**config)
@keras_core_export("keras_core.losses.MeanSquaredError") @keras_core_export("keras_core.losses.MeanSquaredError")
@@ -47,6 +53,9 @@ class MeanSquaredError(LossFunctionWrapper):
): ):
super().__init__(mean_squared_error, reduction=reduction, name=name) super().__init__(mean_squared_error, reduction=reduction, name=name)
def get_config(self):
return Loss.get_config(self)
@keras_core_export("keras_core.losses.MeanAbsoluteError") @keras_core_export("keras_core.losses.MeanAbsoluteError")
class MeanAbsoluteError(LossFunctionWrapper): class MeanAbsoluteError(LossFunctionWrapper):
@@ -70,6 +79,9 @@ class MeanAbsoluteError(LossFunctionWrapper):
): ):
super().__init__(mean_absolute_error, reduction=reduction, name=name) super().__init__(mean_absolute_error, reduction=reduction, name=name)
def get_config(self):
return Loss.get_config(self)
@keras_core_export("keras_core.losses.MeanAbsolutePercentageError") @keras_core_export("keras_core.losses.MeanAbsolutePercentageError")
class MeanAbsolutePercentageError(LossFunctionWrapper): class MeanAbsolutePercentageError(LossFunctionWrapper):
@@ -97,6 +109,9 @@ class MeanAbsolutePercentageError(LossFunctionWrapper):
mean_absolute_percentage_error, reduction=reduction, name=name mean_absolute_percentage_error, reduction=reduction, name=name
) )
def get_config(self):
return Loss.get_config(self)
@keras_core_export("keras_core.losses.MeanSquaredLogarithmicError") @keras_core_export("keras_core.losses.MeanSquaredLogarithmicError")
class MeanSquaredLogarithmicError(LossFunctionWrapper): class MeanSquaredLogarithmicError(LossFunctionWrapper):
@@ -124,6 +139,9 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper):
mean_squared_logarithmic_error, reduction=reduction, name=name mean_squared_logarithmic_error, reduction=reduction, name=name
) )
def get_config(self):
return Loss.get_config(self)
@keras_core_export("keras_core.losses.Hinge") @keras_core_export("keras_core.losses.Hinge")
class Hinge(LossFunctionWrapper): class Hinge(LossFunctionWrapper):
@@ -148,6 +166,9 @@ class Hinge(LossFunctionWrapper):
def __init__(self, reduction="sum_over_batch_size", name="hinge"): def __init__(self, reduction="sum_over_batch_size", name="hinge"):
super().__init__(hinge, reduction=reduction, name=name) super().__init__(hinge, reduction=reduction, name=name)
def get_config(self):
return Loss.get_config(self)
@keras_core_export("keras_core.losses.SquaredHinge") @keras_core_export("keras_core.losses.SquaredHinge")
class SquaredHinge(LossFunctionWrapper): class SquaredHinge(LossFunctionWrapper):
@@ -172,6 +193,9 @@ class SquaredHinge(LossFunctionWrapper):
def __init__(self, reduction="sum_over_batch_size", name="squared_hinge"): def __init__(self, reduction="sum_over_batch_size", name="squared_hinge"):
super().__init__(squared_hinge, reduction=reduction, name=name) super().__init__(squared_hinge, reduction=reduction, name=name)
def get_config(self):
return Loss.get_config(self)
@keras_core_export("keras_core.losses.CategoricalHinge") @keras_core_export("keras_core.losses.CategoricalHinge")
class CategoricalHinge(LossFunctionWrapper): class CategoricalHinge(LossFunctionWrapper):
@@ -198,6 +222,9 @@ class CategoricalHinge(LossFunctionWrapper):
): ):
super().__init__(categorical_hinge, reduction=reduction, name=name) super().__init__(categorical_hinge, reduction=reduction, name=name)
def get_config(self):
return Loss.get_config(self)
def convert_binary_labels_to_hinge(y_true): def convert_binary_labels_to_hinge(y_true):
"""Converts binary labels into -1/1 for hinge loss/metric calculation.""" """Converts binary labels into -1/1 for hinge loss/metric calculation."""

@@ -1,4 +1,6 @@
import json import json
import os
import tempfile
import unittest import unittest
import numpy as np import numpy as np
@@ -8,6 +10,11 @@ from tensorflow import nest
class TestCase(unittest.TestCase): class TestCase(unittest.TestCase):
maxDiff = None maxDiff = None
def get_temp_dir(self):
temp_dir = tempfile.mkdtemp()
self.addCleanup(lambda: os.rmdir(temp_dir))
return temp_dir
def assertAllClose(self, x1, x2, atol=1e-7, rtol=1e-7): def assertAllClose(self, x1, x2, atol=1e-7, rtol=1e-7):
np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol) np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol)