Store trainable and dtype on a layer persistently (#283)
This matches tf.keras behavior. I'm not totally sure about dtype: we could save it only when it diverges from the global policy. But when it is explicitly set on the layer, it is probably important to persist.
This commit is contained in:
parent
ab030e3f18
commit
ea98b39996
@ -33,6 +33,7 @@ from keras_core.backend.common import global_state
|
||||
from keras_core.layers import input_spec
|
||||
from keras_core.metrics.metric import Metric
|
||||
from keras_core.operations.operation import Operation
|
||||
from keras_core.utils import python_utils
|
||||
from keras_core.utils import summary_utils
|
||||
from keras_core.utils import traceback_utils
|
||||
from keras_core.utils import tracking
|
||||
@ -1121,6 +1122,15 @@ class Layer(BackendLayer, Operation):
|
||||
# It's a C type.
|
||||
pass
|
||||
|
||||
@python_utils.default
def get_config(self):
    """Return the serializable config for this layer.

    Extends the parent config with the layer's `trainable` flag and the
    name of its dtype policy so both persist across save/load round-trips.
    """
    # Start from whatever the parent class serializes, then layer the
    # persistent attributes on top (they win on key collision, exactly
    # like the `{**base, **extra}` merge).
    merged = dict(super().get_config())
    merged["trainable"] = self.trainable
    merged["dtype"] = self.dtype_policy.name
    return merged
|
||||
|
||||
|
||||
def is_backend_tensor_or_symbolic(x):
    """Return True if `x` is a backend-native tensor or a symbolic `KerasTensor`."""
    # Check the concrete (eager) tensor case first; fall back to the
    # symbolic graph-tensor case.
    if backend.is_tensor(x):
        return True
    return isinstance(x, backend.KerasTensor)
|
||||
|
@ -38,10 +38,8 @@ class Dropout(Layer):
|
||||
training mode (adding dropout) or in inference mode (doing nothing).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, rate, noise_shape=None, seed=None, name=None, dtype=None
|
||||
):
|
||||
super().__init__(name=name, dtype=dtype)
|
||||
def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not 0 <= rate <= 1:
|
||||
raise ValueError(
|
||||
f"Invalid value received for argument "
|
||||
|
@ -24,8 +24,8 @@ class GaussianDropout(layers.Layer):
|
||||
training mode (adding dropout) or in inference mode (doing nothing).
|
||||
"""
|
||||
|
||||
def __init__(self, rate, seed=None, name=None, dtype=None):
|
||||
super().__init__(name=name, dtype=dtype)
|
||||
def __init__(self, rate, seed=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not 0 <= rate <= 1:
|
||||
raise ValueError(
|
||||
f"Invalid value received for argument "
|
||||
|
@ -25,8 +25,8 @@ class GaussianNoise(layers.Layer):
|
||||
training mode (adding noise) or in inference mode (doing nothing).
|
||||
"""
|
||||
|
||||
def __init__(self, stddev, seed=None, name=None, dtype=None):
|
||||
super().__init__(name=name, dtype=dtype)
|
||||
def __init__(self, stddev, seed=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if not 0 <= stddev <= 1:
|
||||
raise ValueError(
|
||||
f"Invalid value received for argument "
|
||||
|
@ -87,8 +87,17 @@ class SerializationLibTest(testing.TestCase):
|
||||
self.assertEqual(serialized, reserialized)
|
||||
|
||||
def test_builtin_layers(self):
    """Round-trip a built-in Dense layer and verify that explicitly set
    attributes (name, trainable, dtype) survive serialization."""
    dense = keras_core.layers.Dense(
        3,
        name="foo",
        trainable=False,
        dtype="float16",
    )
    serialized, restored, reserialized = self.roundtrip(dense)
    # Serializing the restored layer must reproduce the original config.
    self.assertEqual(serialized, reserialized)
    # The explicitly configured attributes must be preserved.
    self.assertEqual(dense.name, restored.name)
    self.assertEqual(dense.trainable, restored.trainable)
    self.assertEqual(dense.compute_dtype, restored.compute_dtype)
|
||||
|
||||
def test_tensors_and_shapes(self):
|
||||
x = ops.random.normal((2, 2), dtype="float64")
|
||||
|
Loading…
Reference in New Issue
Block a user