diff --git a/keras/layers/layer.py b/keras/layers/layer.py
index a7e23ef74..069fe3dd7 100644
--- a/keras/layers/layer.py
+++ b/keras/layers/layer.py
@@ -445,8 +445,8 @@ class Layer(BackendLayer, Operation):
 
     def add_weight(
         self,
-        shape,
-        initializer,
+        shape=None,
+        initializer=None,
         dtype=None,
         trainable=True,
         regularizer=None,
@@ -458,12 +458,19 @@ class Layer(BackendLayer, Operation):
         Args:
             shape: Shape tuple for the variable.
                 Must be fully-defined (no `None` entries).
+                Defaults to `()` (scalar) if unspecified.
             initializer: Initializer object to use to
                 populate the initial variable value,
                 or string name of a built-in initializer
-                (e.g. `"random_normal"`).
+                (e.g. `"random_normal"`). If unspecified,
+                defaults to `"glorot_uniform"`
+                for floating-point variables and to `"zeros"`
+                for all other types (e.g. int, bool).
             dtype: Dtype of the variable to create,
-                e.g. `"float32"`.
+                e.g. `"float32"`. If unspecified,
+                defaults to the layer's
+                variable dtype (which itself defaults to
+                `"float32"` if unspecified).
             trainable: Boolean, whether the variable should
                 be trainable via backprop or whether its
                 updates are managed manually.
@@ -474,12 +481,23 @@ class Layer(BackendLayer, Operation):
                 for debugging purposes.
         """
         self._check_super_called()
+        if shape is None:
+            shape = ()
+        if dtype is not None:
+            dtype = backend.standardize_dtype(dtype)
+        else:
+            dtype = self.variable_dtype
+        if initializer is None:
+            if "float" in dtype:
+                initializer = "glorot_uniform"
+            else:
+                initializer = "zeros"
         initializer = initializers.get(initializer)
         with backend.name_scope(self.name, caller=self):
             variable = backend.Variable(
                 initializer=initializer,
                 shape=shape,
-                dtype=dtype or self.variable_dtype,
+                dtype=dtype,
                 trainable=trainable,
                 name=name,
             )
diff --git a/keras/layers/layer_test.py b/keras/layers/layer_test.py
index 527758838..0459fd0ab 100644
--- a/keras/layers/layer_test.py
+++ b/keras/layers/layer_test.py
@@ -813,3 +813,35 @@ class LayerTest(testing.TestCase):
 
         layer = MyLayer()
         self.assertEqual(len(layer.weights), 1)
+
+    def test_add_weight_defaults(self):
+        class MyLayer(layers.Layer):
+            def __init__(self):
+                super().__init__()
+                self.w1 = self.add_weight()
+                self.w2 = self.add_weight(dtype="int32")
+                self.w3 = self.add_weight(dtype="bool")
+                self.w4 = self.add_weight(dtype="int32", shape=(2, 2))
+                self.w5 = self.add_weight(initializer="ones", shape=(2, 2))
+
+        layer = MyLayer()
+        self.assertEqual(layer.w1.shape, ())
+        self.assertEqual(layer.w1.dtype, "float32")
+
+        self.assertEqual(layer.w2.shape, ())
+        self.assertEqual(layer.w2.dtype, "int32")
+        self.assertAllClose(backend.convert_to_numpy(layer.w2), 0)
+
+        self.assertEqual(layer.w3.shape, ())
+        self.assertEqual(layer.w3.dtype, "bool")
+        self.assertAllClose(backend.convert_to_numpy(layer.w3), False)
+
+        self.assertEqual(layer.w4.shape, (2, 2))
+        self.assertEqual(layer.w4.dtype, "int32")
+        self.assertAllClose(
+            backend.convert_to_numpy(layer.w4), np.zeros((2, 2))
+        )
+
+        self.assertEqual(layer.w5.shape, (2, 2))
+        self.assertEqual(layer.w5.dtype, "float32")
+        self.assertAllClose(backend.convert_to_numpy(layer.w5), np.ones((2, 2)))
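
Below is a minimal usage sketch (not part of the diff) of what the new defaults enable: a custom layer can now create scalar, non-float state without spelling out a shape or a matching initializer. The CountingRelu layer name and its counter logic are hypothetical, chosen only to exercise add_weight(dtype="int32", trainable=False); it assumes Keras 3 with any installed backend.

import numpy as np
import keras


class CountingRelu(keras.layers.Layer):
    """Hypothetical layer that counts how many times it has been called."""

    def __init__(self):
        super().__init__()
        # With the new defaults: shape falls back to () (scalar) and, since
        # int32 is not a floating-point dtype, the initializer falls back
        # to "zeros".
        self.call_count = self.add_weight(dtype="int32", trainable=False)

    def call(self, x):
        # Bump the scalar counter on each forward pass, then apply ReLU.
        self.call_count.assign_add(1)
        return keras.ops.relu(x)


layer = CountingRelu()
layer(np.array([-1.0, 2.0]))
layer(np.array([3.0, -4.0]))
print(layer.call_count.dtype)  # "int32"
print(keras.ops.convert_to_numpy(layer.call_count))  # 2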