From ea98b399961fc85442ef297de5b66420cdfe6d82 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Tue, 6 Jun 2023 21:12:46 -0700 Subject: [PATCH] Store trainable and dtype on a layer persistently (#283) This is something tf.keras will do. I'm not totally sure about dtype, we could only save it if it diverges from the global policy. But when it is explicitly set on the layer, it is probably important to persist. --- keras_core/layers/layer.py | 10 ++++++++++ keras_core/layers/regularization/dropout.py | 6 ++---- keras_core/layers/regularization/gaussian_dropout.py | 4 ++-- keras_core/layers/regularization/gaussian_noise.py | 4 ++-- keras_core/saving/serialization_lib_test.py | 11 ++++++++++- 5 files changed, 26 insertions(+), 9 deletions(-) diff --git a/keras_core/layers/layer.py b/keras_core/layers/layer.py index 776bbe7ce..ff2a3c897 100644 --- a/keras_core/layers/layer.py +++ b/keras_core/layers/layer.py @@ -33,6 +33,7 @@ from keras_core.backend.common import global_state from keras_core.layers import input_spec from keras_core.metrics.metric import Metric from keras_core.operations.operation import Operation +from keras_core.utils import python_utils from keras_core.utils import summary_utils from keras_core.utils import traceback_utils from keras_core.utils import tracking @@ -1121,6 +1122,15 @@ class Layer(BackendLayer, Operation): # It's a C type. pass + @python_utils.default + def get_config(self): + base_config = super().get_config() + config = { + "trainable": self.trainable, + "dtype": self.dtype_policy.name, + } + return {**base_config, **config} + def is_backend_tensor_or_symbolic(x): return backend.is_tensor(x) or isinstance(x, backend.KerasTensor) diff --git a/keras_core/layers/regularization/dropout.py b/keras_core/layers/regularization/dropout.py index 6862b83b5..88e8fc8d0 100644 --- a/keras_core/layers/regularization/dropout.py +++ b/keras_core/layers/regularization/dropout.py @@ -38,10 +38,8 @@ class Dropout(Layer): training mode (adding dropout) or in inference mode (doing nothing). """ - def __init__( - self, rate, noise_shape=None, seed=None, name=None, dtype=None - ): - super().__init__(name=name, dtype=dtype) + def __init__(self, rate, noise_shape=None, seed=None, **kwargs): + super().__init__(**kwargs) if not 0 <= rate <= 1: raise ValueError( f"Invalid value received for argument " diff --git a/keras_core/layers/regularization/gaussian_dropout.py b/keras_core/layers/regularization/gaussian_dropout.py index bee22320e..66d682517 100644 --- a/keras_core/layers/regularization/gaussian_dropout.py +++ b/keras_core/layers/regularization/gaussian_dropout.py @@ -24,8 +24,8 @@ class GaussianDropout(layers.Layer): training mode (adding dropout) or in inference mode (doing nothing). """ - def __init__(self, rate, seed=None, name=None, dtype=None): - super().__init__(name=name, dtype=dtype) + def __init__(self, rate, seed=None, **kwargs): + super().__init__(**kwargs) if not 0 <= rate <= 1: raise ValueError( f"Invalid value received for argument " diff --git a/keras_core/layers/regularization/gaussian_noise.py b/keras_core/layers/regularization/gaussian_noise.py index 66466cd0e..35fca69c5 100644 --- a/keras_core/layers/regularization/gaussian_noise.py +++ b/keras_core/layers/regularization/gaussian_noise.py @@ -25,8 +25,8 @@ class GaussianNoise(layers.Layer): training mode (adding noise) or in inference mode (doing nothing). """ - def __init__(self, stddev, seed=None, name=None, dtype=None): - super().__init__(name=name, dtype=dtype) + def __init__(self, stddev, seed=None, **kwargs): + super().__init__(**kwargs) if not 0 <= stddev <= 1: raise ValueError( f"Invalid value received for argument " diff --git a/keras_core/saving/serialization_lib_test.py b/keras_core/saving/serialization_lib_test.py index 02fe105b9..fc45caf93 100644 --- a/keras_core/saving/serialization_lib_test.py +++ b/keras_core/saving/serialization_lib_test.py @@ -87,8 +87,17 @@ class SerializationLibTest(testing.TestCase): self.assertEqual(serialized, reserialized) def test_builtin_layers(self): - serialized, _, reserialized = self.roundtrip(keras_core.layers.Dense(3)) + layer = keras_core.layers.Dense( + 3, + name="foo", + trainable=False, + dtype="float16", + ) + serialized, restored, reserialized = self.roundtrip(layer) self.assertEqual(serialized, reserialized) + self.assertEqual(layer.name, restored.name) + self.assertEqual(layer.trainable, restored.trainable) + self.assertEqual(layer.compute_dtype, restored.compute_dtype) def test_tensors_and_shapes(self): x = ops.random.normal((2, 2), dtype="float64")