Store trainable and dtype on a layer persistently (#283)

This is something tf.keras will do.

I'm not totally sure about dtype, we could only save it if it diverges
from the global policy. But when it is explicitly set on the layer, it
is probably important to persist.
This commit is contained in:
Matt Watson 2023-06-06 21:12:46 -07:00 committed by Francois Chollet
parent ab030e3f18
commit ea98b39996
5 changed files with 26 additions and 9 deletions

@ -33,6 +33,7 @@ from keras_core.backend.common import global_state
from keras_core.layers import input_spec
from keras_core.metrics.metric import Metric
from keras_core.operations.operation import Operation
from keras_core.utils import python_utils
from keras_core.utils import summary_utils
from keras_core.utils import traceback_utils
from keras_core.utils import tracking
@ -1121,6 +1122,15 @@ class Layer(BackendLayer, Operation):
# It's a C type.
pass
@python_utils.default
def get_config(self):
base_config = super().get_config()
config = {
"trainable": self.trainable,
"dtype": self.dtype_policy.name,
}
return {**base_config, **config}
def is_backend_tensor_or_symbolic(x):
return backend.is_tensor(x) or isinstance(x, backend.KerasTensor)

@ -38,10 +38,8 @@ class Dropout(Layer):
training mode (adding dropout) or in inference mode (doing nothing).
"""
def __init__(
self, rate, noise_shape=None, seed=None, name=None, dtype=None
):
super().__init__(name=name, dtype=dtype)
def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
super().__init__(**kwargs)
if not 0 <= rate <= 1:
raise ValueError(
f"Invalid value received for argument "

@ -24,8 +24,8 @@ class GaussianDropout(layers.Layer):
training mode (adding dropout) or in inference mode (doing nothing).
"""
def __init__(self, rate, seed=None, name=None, dtype=None):
super().__init__(name=name, dtype=dtype)
def __init__(self, rate, seed=None, **kwargs):
super().__init__(**kwargs)
if not 0 <= rate <= 1:
raise ValueError(
f"Invalid value received for argument "

@ -25,8 +25,8 @@ class GaussianNoise(layers.Layer):
training mode (adding noise) or in inference mode (doing nothing).
"""
def __init__(self, stddev, seed=None, name=None, dtype=None):
super().__init__(name=name, dtype=dtype)
def __init__(self, stddev, seed=None, **kwargs):
super().__init__(**kwargs)
if not 0 <= stddev <= 1:
raise ValueError(
f"Invalid value received for argument "

@ -87,8 +87,17 @@ class SerializationLibTest(testing.TestCase):
self.assertEqual(serialized, reserialized)
def test_builtin_layers(self):
serialized, _, reserialized = self.roundtrip(keras_core.layers.Dense(3))
layer = keras_core.layers.Dense(
3,
name="foo",
trainable=False,
dtype="float16",
)
serialized, restored, reserialized = self.roundtrip(layer)
self.assertEqual(serialized, reserialized)
self.assertEqual(layer.name, restored.name)
self.assertEqual(layer.trainable, restored.trainable)
self.assertEqual(layer.compute_dtype, restored.compute_dtype)
def test_tensors_and_shapes(self):
x = ops.random.normal((2, 2), dtype="float64")