Progress on serialization
This commit is contained in:
parent
2c5863e8fc
commit
3788a99582
@ -1,11 +1,75 @@
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.losses.loss import Loss
|
||||
from keras_core.losses.losses import CategoricalHinge
|
||||
from keras_core.losses.losses import Hinge
|
||||
from keras_core.losses.losses import LossFunctionWrapper
|
||||
from keras_core.losses.losses import MeanAbsoluteError
|
||||
from keras_core.losses.losses import MeanAbsolutePercentageError
|
||||
from keras_core.losses.losses import MeanSquaredError
|
||||
from keras_core.losses.losses import MeanSquaredLogarithmicError
|
||||
from keras_core.losses.losses import SquaredHinge
|
||||
from keras_core.losses.losses import categorical_hinge
|
||||
from keras_core.losses.losses import hinge
|
||||
from keras_core.losses.losses import mean_absolute_error
|
||||
from keras_core.losses.losses import mean_absolute_percentage_error
|
||||
from keras_core.losses.losses import mean_squared_error
|
||||
from keras_core.losses.losses import mean_squared_logarithmic_error
|
||||
from keras_core.losses.losses import squared_hinge
|
||||
from keras_core.saving import serialization_lib
|
||||
|
||||
ALL_OBJECTS = {
|
||||
Loss,
|
||||
LossFunctionWrapper,
|
||||
MeanSquaredError,
|
||||
MeanAbsoluteError,
|
||||
MeanAbsolutePercentageError,
|
||||
MeanSquaredLogarithmicError,
|
||||
Hinge,
|
||||
SquaredHinge,
|
||||
CategoricalHinge,
|
||||
mean_squared_error,
|
||||
mean_absolute_error,
|
||||
mean_absolute_percentage_error,
|
||||
mean_squared_logarithmic_error,
|
||||
hinge,
|
||||
squared_hinge,
|
||||
categorical_hinge,
|
||||
}
|
||||
|
||||
ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
|
||||
|
||||
|
||||
def deserialize(obj):
|
||||
raise NotImplementedError
|
||||
@keras_core_export("keras_core.losses.serialize")
|
||||
def serialize(loss):
|
||||
"""Serializes loss function or `Loss` instance.
|
||||
|
||||
Args:
|
||||
loss: A Keras `Loss` instance or a loss function.
|
||||
|
||||
Returns:
|
||||
Loss configuration dictionary.
|
||||
"""
|
||||
return serialization_lib.serialize_keras_object(loss)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.losses.deserialize")
|
||||
def deserialize(name, custom_objects=None):
|
||||
"""Deserializes a serialized loss class/function instance.
|
||||
|
||||
Args:
|
||||
name: Loss configuration.
|
||||
custom_objects: Optional dictionary mapping names (strings) to custom
|
||||
objects (classes and functions) to be considered during
|
||||
deserialization.
|
||||
|
||||
Returns:
|
||||
A Keras `Loss` instance or a loss function.
|
||||
"""
|
||||
return serialization_lib.deserialize_keras_object(
|
||||
name,
|
||||
module_objects=ALL_OBJECTS_DICT,
|
||||
custom_objects=custom_objects,
|
||||
)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.losses.get")
|
||||
|
@ -45,7 +45,7 @@ class Loss:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_config(self):
|
||||
return {"name": self.name}
|
||||
return {"name": self.name, "reduction": self.reduction}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
|
@ -3,6 +3,7 @@ from keras_core import operations as ops
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.losses.loss import Loss
|
||||
from keras_core.losses.loss import squeeze_to_same_rank
|
||||
from keras_core.saving import serialization_lib
|
||||
|
||||
|
||||
class LossFunctionWrapper(Loss):
|
||||
@ -18,11 +19,16 @@ class LossFunctionWrapper(Loss):
|
||||
return self.fn(y_true, y_pred, **self._fn_kwargs)
|
||||
|
||||
def get_config(self):
|
||||
raise NotImplementedError
|
||||
base_config = super().get_config()
|
||||
config = {"fn": serialization_lib.serialize_keras_object(self.fn)}
|
||||
config.update(serialization_lib.serialize_keras_object(self._fn_kwargs))
|
||||
return {**base_config, **config}
|
||||
|
||||
@classmethod
|
||||
def from_config(clf, config):
|
||||
raise NotImplementedError
|
||||
def from_config(cls, config):
|
||||
if "fn" in config:
|
||||
config = serialization_lib.deserialize_keras_object(config)
|
||||
return cls(**config)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.losses.MeanSquaredError")
|
||||
@ -47,6 +53,9 @@ class MeanSquaredError(LossFunctionWrapper):
|
||||
):
|
||||
super().__init__(mean_squared_error, reduction=reduction, name=name)
|
||||
|
||||
def get_config(self):
|
||||
return Loss.get_config(self)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.losses.MeanAbsoluteError")
|
||||
class MeanAbsoluteError(LossFunctionWrapper):
|
||||
@ -70,6 +79,9 @@ class MeanAbsoluteError(LossFunctionWrapper):
|
||||
):
|
||||
super().__init__(mean_absolute_error, reduction=reduction, name=name)
|
||||
|
||||
def get_config(self):
|
||||
return Loss.get_config(self)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.losses.MeanAbsolutePercentageError")
|
||||
class MeanAbsolutePercentageError(LossFunctionWrapper):
|
||||
@ -97,6 +109,9 @@ class MeanAbsolutePercentageError(LossFunctionWrapper):
|
||||
mean_absolute_percentage_error, reduction=reduction, name=name
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
return Loss.get_config(self)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.losses.MeanSquaredLogarithmicError")
|
||||
class MeanSquaredLogarithmicError(LossFunctionWrapper):
|
||||
@ -124,6 +139,9 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper):
|
||||
mean_squared_logarithmic_error, reduction=reduction, name=name
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
return Loss.get_config(self)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.losses.Hinge")
|
||||
class Hinge(LossFunctionWrapper):
|
||||
@ -148,6 +166,9 @@ class Hinge(LossFunctionWrapper):
|
||||
def __init__(self, reduction="sum_over_batch_size", name="hinge"):
|
||||
super().__init__(hinge, reduction=reduction, name=name)
|
||||
|
||||
def get_config(self):
|
||||
return Loss.get_config(self)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.losses.SquaredHinge")
|
||||
class SquaredHinge(LossFunctionWrapper):
|
||||
@ -172,6 +193,9 @@ class SquaredHinge(LossFunctionWrapper):
|
||||
def __init__(self, reduction="sum_over_batch_size", name="squared_hinge"):
|
||||
super().__init__(squared_hinge, reduction=reduction, name=name)
|
||||
|
||||
def get_config(self):
|
||||
return Loss.get_config(self)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.losses.CategoricalHinge")
|
||||
class CategoricalHinge(LossFunctionWrapper):
|
||||
@ -198,6 +222,9 @@ class CategoricalHinge(LossFunctionWrapper):
|
||||
):
|
||||
super().__init__(categorical_hinge, reduction=reduction, name=name)
|
||||
|
||||
def get_config(self):
|
||||
return Loss.get_config(self)
|
||||
|
||||
|
||||
def convert_binary_labels_to_hinge(y_true):
|
||||
"""Converts binary labels into -1/1 for hinge loss/metric calculation."""
|
||||
|
@ -1,4 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@ -8,6 +10,11 @@ from tensorflow import nest
|
||||
class TestCase(unittest.TestCase):
|
||||
maxDiff = None
|
||||
|
||||
def get_temp_dir(self):
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: os.rmdir(temp_dir))
|
||||
return temp_dir
|
||||
|
||||
def assertAllClose(self, x1, x2, atol=1e-7, rtol=1e-7):
|
||||
np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user