Add tf.keras defaults to add_weight.
parent 1731c59c2f
commit e3c68cd1d4
@@ -445,8 +445,8 @@ class Layer(BackendLayer, Operation):
 
     def add_weight(
         self,
-        shape,
-        initializer,
+        shape=None,
+        initializer=None,
         dtype=None,
         trainable=True,
         regularizer=None,
@@ -458,12 +458,19 @@ class Layer(BackendLayer, Operation):
         Args:
             shape: Shape tuple for the variable.
                 Must be fully-defined (no `None` entries).
+                Defaults to `()` (scalar) if unspecified.
             initializer: Initializer object to use to
                 populate the initial variable value,
                 or string name of a built-in initializer
-                (e.g. `"random_normal"`).
+                (e.g. `"random_normal"`). If unspecified,
+                defaults to `"glorot_uniform"`
+                for floating-point variables and to `"zeros"`
+                for all other types (e.g. int, bool).
             dtype: Dtype of the variable to create,
-                e.g. `"float32"`.
+                e.g. `"float32"`. If unspecified,
+                defaults to the layer's
+                variable dtype (which itself defaults to
+                `"float32"` if unspecified).
             trainable: Boolean, whether the variable should
                 be trainable via backprop or whether its
                 updates are managed manually.
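
With these defaults documented, a bare `add_weight()` call now yields a scalar `float32` variable initialized with `"glorot_uniform"`, matching the tf.keras behavior the commit title refers to. A minimal sketch of the resulting user-facing API (the `MyScale` layer is a hypothetical example, assuming the Keras 3 `keras.layers.Layer` base class):

    from keras import layers

    class MyScale(layers.Layer):  # hypothetical example layer
        def build(self, input_shape):
            # No arguments: shape (), dtype "float32", "glorot_uniform" init.
            self.scale = self.add_weight()
            # Non-float dtype: initializer falls back to "zeros".
            self.step = self.add_weight(dtype="int32", trainable=False)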
@@ -474,12 +481,23 @@ class Layer(BackendLayer, Operation):
                 for debugging purposes.
         """
         self._check_super_called()
+        if shape is None:
+            shape = ()
+        if dtype is not None:
+            dtype = backend.standardize_dtype(dtype)
+        else:
+            dtype = self.variable_dtype
+        if initializer is None:
+            if "float" in dtype:
+                initializer = "glorot_uniform"
+            else:
+                initializer = "zeros"
         initializer = initializers.get(initializer)
         with backend.name_scope(self.name, caller=self):
             variable = backend.Variable(
                 initializer=initializer,
                 shape=shape,
-                dtype=dtype or self.variable_dtype,
+                dtype=dtype,
                 trainable=trainable,
                 name=name,
             )
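
The resolution order in the new code is: shape first, then dtype (standardized or inherited from the layer), then initializer chosen from the resolved dtype. A sketch of the equivalent logic as a pure function, for readability (`resolve_defaults` is a hypothetical name, and the layer's variable dtype is passed in explicitly):

    def resolve_defaults(shape=None, initializer=None, dtype=None,
                         variable_dtype="float32"):
        # Unset shape becomes a scalar.
        shape = () if shape is None else shape
        # Unset dtype inherits the layer's variable dtype.
        dtype = dtype if dtype is not None else variable_dtype
        # Initializer choice depends on the *resolved* dtype, so
        # add_weight(dtype="int32") gets "zeros" even though the
        # layer itself is float32.
        if initializer is None:
            initializer = "glorot_uniform" if "float" in dtype else "zeros"
        return shape, initializer, dtype

    assert resolve_defaults() == ((), "glorot_uniform", "float32")
    assert resolve_defaults(dtype="int32") == ((), "zeros", "int32")

Note that the substring check `"float" in dtype` also matches `"float16"` and `"bfloat16"`, so every floating-point variant gets `"glorot_uniform"`.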
@@ -813,3 +813,35 @@ class LayerTest(testing.TestCase):
 
         layer = MyLayer()
         self.assertEqual(len(layer.weights), 1)
+
+    def test_add_weight_defaults(self):
+        class MyLayer(layers.Layer):
+            def __init__(self):
+                super().__init__()
+                self.w1 = self.add_weight()
+                self.w2 = self.add_weight(dtype="int32")
+                self.w3 = self.add_weight(dtype="bool")
+                self.w4 = self.add_weight(dtype="int32", shape=(2, 2))
+                self.w5 = self.add_weight(initializer="ones", shape=(2, 2))
+
+        layer = MyLayer()
+        self.assertEqual(layer.w1.shape, ())
+        self.assertEqual(layer.w1.dtype, "float32")
+
+        self.assertEqual(layer.w2.shape, ())
+        self.assertEqual(layer.w2.dtype, "int32")
+        self.assertAllClose(backend.convert_to_numpy(layer.w2), 0)
+
+        self.assertEqual(layer.w3.shape, ())
+        self.assertEqual(layer.w3.dtype, "bool")
+        self.assertAllClose(backend.convert_to_numpy(layer.w3), False)
+
+        self.assertEqual(layer.w4.shape, (2, 2))
+        self.assertEqual(layer.w4.dtype, "int32")
+        self.assertAllClose(
+            backend.convert_to_numpy(layer.w4), np.zeros((2, 2))
+        )
+
+        self.assertEqual(layer.w5.shape, (2, 2))
+        self.assertEqual(layer.w5.dtype, "float32")
+        self.assertAllClose(backend.convert_to_numpy(layer.w5), np.ones((2, 2)))
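
As a quick interactive check of the same behavior outside the test suite (a sketch; it assumes the base `keras.layers.Layer` can be instantiated directly, as the test above does with a trivial subclass):

    import keras

    layer = keras.layers.Layer()
    w = layer.add_weight()              # scalar, float32, glorot_uniform
    b = layer.add_weight(dtype="bool")  # scalar, bool, zeros -> False
    print(w.shape, w.dtype)             # () float32
    print(keras.ops.convert_to_numpy(b))  # False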