Add loss collection in stateless call, as well as stateless call tests

This commit is contained in:
Francois Chollet 2023-05-04 14:52:00 -07:00
parent 3dea0fc5b6
commit c457e2a605
3 changed files with 122 additions and 5 deletions

@ -28,7 +28,8 @@ def draw_seed(seed):
from keras_core.backend import convert_to_tensor from keras_core.backend import convert_to_tensor
if isinstance(seed, SeedGenerator): if isinstance(seed, SeedGenerator):
new_seed_value = seed.state.value # Use * 1 to create a copy
new_seed_value = seed.state.value * 1
seed.state.assign( seed.state.assign(
seed.state + convert_to_tensor([0, 1], dtype="uint32") seed.state + convert_to_tensor([0, 1], dtype="uint32")
) )

@ -415,15 +415,58 @@ class Layer(Operation):
raise NotImplementedError raise NotImplementedError
def stateless_call( def stateless_call(
self, trainable_variables, non_trainable_variables, *args, **kwargs self,
trainable_variables,
non_trainable_variables,
*args,
return_losses=False,
**kwargs,
): ):
# TODO: also handle losses """Call the layer without any side effects.
Args:
trainable_variables: List of trainable variables of the model.
non_trainable_variables: List of non-trainable variables of the model.
*args: Positional arguments to be passed to `call()`.
return_losses: If `True`, `stateless_call()` will return the list of
losses created during `call()` as part of its return values.
**kwargs: Keyword arguments to be passed to `call()`.
Returns:
A tuple. By default, returns `(outputs, non_trainable_variables)`.
If `return_losses = True`, then returns
`(outputs, non_trainable_variables, losses)`.
Note: `non_trainable_variables` include not only non-trainable weights
such as `BatchNormalization` statistics, but also RNG seed state
(if there are any random operations that are part of the layer, such as dropout),
and `Metric` state (if there are any metrics attached to the layer).
These are all elements of state of the layer.
Example:
```python
model = ...
data = ...
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
# Call the model with zero side effects
outputs, non_trainable_variables = model.stateless_call(
trainable_variables,
non_trainable_variables,
data,
)
# Attach the updated state to the model
# (until you do this, the model is still in its pre-call state).
for ref_var, value in zip(model.non_trainable_variables, non_trainable_variables):
ref_var.assign(value)
```
"""
self._check_super_called() self._check_super_called()
if not self.built: if not self.built:
raise ValueError( raise ValueError(
"To call stateless_call, {self.__class__.__name__} must be " f"To call stateless_call, {self.__class__.__name__} must be "
"built (i.e. its variables must have been already created). " "built (i.e. its variables must have been already created). "
"You can build it by calling it on some data." "You can build it by calling it on some data."
) )
@ -452,7 +495,9 @@ class Layer(Operation):
mapping = list(trainable_mapping) + list(non_trainable_mapping) mapping = list(trainable_mapping) + list(non_trainable_mapping)
# Call in stateless scope # Call in stateless scope
with backend.StatelessScope(state_mapping=mapping) as scope: with backend.StatelessScope(
state_mapping=mapping, collect_losses=return_losses
) as scope:
outputs = self.call(*args, **kwargs) outputs = self.call(*args, **kwargs)
# Gather updated non-trainable variables # Gather updated non-trainable variables
@ -463,6 +508,9 @@ class Layer(Operation):
non_trainable_variables.append(new_v) non_trainable_variables.append(new_v)
else: else:
non_trainable_variables.append(v) non_trainable_variables.append(v)
if return_losses:
return outputs, non_trainable_variables, scope.losses[:]
return outputs, non_trainable_variables return outputs, non_trainable_variables
def compute_output_spec(self, *args, **kwargs): def compute_output_spec(self, *args, **kwargs):

@ -332,3 +332,71 @@ class LayerTest(testing.TestCase):
x2._keras_mask = backend.numpy.ones((4,)) x2._keras_mask = backend.numpy.ones((4,))
layer((x1_1, x1_2), x2) layer((x1_1, x1_2), x2)
layer(x1=(x1_1, x1_2), x2=x2) layer(x1=(x1_1, x1_2), x2=x2)
def test_stateless_call(self):
    class TestLayer(layers.Layer):
        # Exercises every kind of layer state at once: a trainable
        # weight, a non-trainable weight mutated inside call(), RNG
        # seed state, and a loss registered via add_loss().
        def __init__(self):
            super().__init__()
            self._seed_generator = backend.random.SeedGenerator(1337)
            self.ntw = self.add_weight(
                shape=(),
                initializer="zeros",
                trainable=False,
            )
            self.tw = self.add_weight(
                shape=(),
                initializer="zeros",
                trainable=True,
            )
            self.built = True

        def call(self, x):
            x = backend.convert_to_tensor(x, dtype="float32")
            self.add_loss(ops.sum(x))
            self.ntw.assign(ops.sum(x))
            x = x + backend.random.normal(
                shape=(), seed=self._seed_generator
            )
            return x + self.tw + self.ntw

    data = np.random.random((3, 4))

    # Sanity check: two freshly built layers produce identical outputs,
    # i.e. the layer (including its seeded RNG) is deterministic.
    out = TestLayer()(data)
    ref_layer = TestLayer()
    ref_out = ref_layer(data)
    self.assertAllClose(out, ref_out)

    # stateless_call must reproduce the stateful output and hand back
    # the updated non-trainable state instead of mutating the layer.
    layer_a = TestLayer()
    out_a, updated_state = layer_a.stateless_call(
        layer_a.trainable_variables,
        layer_a.non_trainable_variables,
        data,
    )
    self.assertAllClose(ref_out, out_a)
    self.assertEqual(
        len(ref_layer.non_trainable_variables), len(updated_state)
    )
    for expected, actual in zip(
        ref_layer.non_trainable_variables, updated_state
    ):
        self.assertAllClose(expected, actual)

    # With return_losses=True, the losses created during call() are
    # returned as a third element.
    layer_b = TestLayer()
    out_b, updated_state, collected_losses = layer_b.stateless_call(
        layer_b.trainable_variables,
        layer_b.non_trainable_variables,
        data,
        return_losses=True,
    )
    self.assertAllClose(ref_out, out_b)
    for expected, actual in zip(
        ref_layer.non_trainable_variables, updated_state
    ):
        self.assertAllClose(expected, actual)
    for expected, actual in zip(ref_layer.losses, collected_losses):
        self.assertAllClose(expected, actual)