Add tf.keras defaults to add_weight().

Francois Chollet 2023-10-21 02:40:42 -07:00
parent 1731c59c2f
commit e3c68cd1d4
2 changed files with 55 additions and 5 deletions
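
In practice, these defaults let `add_weight()` be called with no arguments, mirroring tf.keras behavior. A minimal usage sketch against the Keras 3 layer API (the `Example` layer and its weight names are illustrative, not part of this commit):

from keras import layers

class Example(layers.Layer):
    def build(self, input_shape):
        # No arguments: a scalar float32 variable, initialized with
        # "glorot_uniform" (the new default for floating-point dtypes).
        self.scale = self.add_weight()
        # Non-float dtypes fall back to the "zeros" initializer.
        self.counter = self.add_weight(dtype="int32", trainable=False)

    def call(self, x):
        return x * self.scale

Once built (e.g. after the first call), `layer.scale` has shape `()` and dtype `"float32"`, since the layer's variable dtype itself defaults to `"float32"`.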

@@ -445,8 +445,8 @@ class Layer(BackendLayer, Operation):
     def add_weight(
         self,
-        shape,
-        initializer,
+        shape=None,
+        initializer=None,
         dtype=None,
         trainable=True,
         regularizer=None,
@@ -458,12 +458,19 @@ class Layer(BackendLayer, Operation):
         Args:
             shape: Shape tuple for the variable.
                 Must be fully-defined (no `None` entries).
+                Defaults to `()` (scalar) if unspecified.
             initializer: Initializer object to use to
                 populate the initial variable value,
                 or string name of a built-in initializer
-                (e.g. `"random_normal"`).
+                (e.g. `"random_normal"`). If unspecified,
+                defaults to `"glorot_uniform"`
+                for floating-point variables and to `"zeros"`
+                for all other types (e.g. int, bool).
             dtype: Dtype of the variable to create,
-                e.g. `"float32"`.
+                e.g. `"float32"`. If unspecified,
+                defaults to the layer's
+                variable dtype (which itself defaults to
+                `"float32"` if unspecified).
             trainable: Boolean, whether the variable should
                 be trainable via backprop or whether its
                 updates are managed manually.
@@ -474,12 +481,23 @@ class Layer(BackendLayer, Operation):
                 for debugging purposes.
         """
         self._check_super_called()
+        if shape is None:
+            shape = ()
         if dtype is not None:
             dtype = backend.standardize_dtype(dtype)
+        else:
+            dtype = self.variable_dtype
+        if initializer is None:
+            if "float" in dtype:
+                initializer = "glorot_uniform"
+            else:
+                initializer = "zeros"
         initializer = initializers.get(initializer)
         with backend.name_scope(self.name, caller=self):
             variable = backend.Variable(
                 initializer=initializer,
                 shape=shape,
-                dtype=dtype or self.variable_dtype,
+                dtype=dtype,
                 trainable=trainable,
                 name=name,
             )
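
Note that the `"float" in dtype` test above is a substring check, so `"float16"`, `"bfloat16"`, and `"float64"` variables also get `"glorot_uniform"`, while all other dtypes (int, uint, bool) get `"zeros"`. A standalone sketch of that mapping (illustrative only, not part of the commit):

def default_initializer(dtype):
    # Mirrors the selection logic added to add_weight() above.
    return "glorot_uniform" if "float" in dtype else "zeros"

for dt in ("float32", "bfloat16", "float16", "int32", "bool"):
    print(dt, "->", default_initializer(dt))
# float32 -> glorot_uniform, bfloat16 -> glorot_uniform,
# float16 -> glorot_uniform, int32 -> zeros, bool -> zeros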

@@ -813,3 +813,35 @@ class LayerTest(testing.TestCase):
         layer = MyLayer()
         self.assertEqual(len(layer.weights), 1)
 
+    def test_add_weight_defaults(self):
+        class MyLayer(layers.Layer):
+            def __init__(self):
+                super().__init__()
+                self.w1 = self.add_weight()
+                self.w2 = self.add_weight(dtype="int32")
+                self.w3 = self.add_weight(dtype="bool")
+                self.w4 = self.add_weight(dtype="int32", shape=(2, 2))
+                self.w5 = self.add_weight(initializer="ones", shape=(2, 2))
+
+        layer = MyLayer()
+        self.assertEqual(layer.w1.shape, ())
+        self.assertEqual(layer.w1.dtype, "float32")
+        self.assertEqual(layer.w2.shape, ())
+        self.assertEqual(layer.w2.dtype, "int32")
+        self.assertAllClose(backend.convert_to_numpy(layer.w2), 0)
+        self.assertEqual(layer.w3.shape, ())
+        self.assertEqual(layer.w3.dtype, "bool")
+        self.assertAllClose(backend.convert_to_numpy(layer.w3), False)
+        self.assertEqual(layer.w4.shape, (2, 2))
+        self.assertEqual(layer.w4.dtype, "int32")
+        self.assertAllClose(
+            backend.convert_to_numpy(layer.w4), np.zeros((2, 2))
+        )
+        self.assertEqual(layer.w5.shape, (2, 2))
+        self.assertEqual(layer.w5.dtype, "float32")
+        self.assertAllClose(backend.convert_to_numpy(layer.w5), np.ones((2, 2)))