Add loss collection in stateless call, as well as stateless call tests
This commit is contained in:
parent
3dea0fc5b6
commit
c457e2a605
@ -28,7 +28,8 @@ def draw_seed(seed):
|
||||
from keras_core.backend import convert_to_tensor
|
||||
|
||||
if isinstance(seed, SeedGenerator):
|
||||
new_seed_value = seed.state.value
|
||||
# Use * 1 to create a copy
|
||||
new_seed_value = seed.state.value * 1
|
||||
seed.state.assign(
|
||||
seed.state + convert_to_tensor([0, 1], dtype="uint32")
|
||||
)
|
||||
|
@ -415,15 +415,58 @@ class Layer(Operation):
|
||||
raise NotImplementedError
|
||||
|
||||
def stateless_call(
|
||||
self, trainable_variables, non_trainable_variables, *args, **kwargs
|
||||
self,
|
||||
trainable_variables,
|
||||
non_trainable_variables,
|
||||
*args,
|
||||
return_losses=False,
|
||||
**kwargs,
|
||||
):
|
||||
# TODO: also handle losses
|
||||
"""Call the layer without any side effects.
|
||||
|
||||
Args:
|
||||
trainable_variables: List of trainable variables of the model.
|
||||
non_trainable_variables: List of non-trainable variables of the model.
|
||||
*args: Positional arguments to be passed to `call()`.
|
||||
return_losses: If `True`, `stateless_call()` will return the list of
|
||||
losses created during `call()` as part of its return values.
|
||||
**kwargs: Keyword arguments to be passed to `call()`.
|
||||
|
||||
Returns:
|
||||
A tuple. By default, returns `(outputs, non_trainable_variables)`.
|
||||
If `return_losses = True`, then returns
|
||||
`(outputs, non_trainable_variables, losses)`.
|
||||
|
||||
Note: `non_trainable_variables` include not only non-trainable weights
|
||||
such as `BatchNormalization` statistics, but also RNG seed state
|
||||
(if there are any random operations part of the layer, such as dropout),
|
||||
and `Metric` state (if there are any metrics attached to the layer).
|
||||
These are all elements of state of the layer.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
model = ...
|
||||
data = ...
|
||||
trainable_variables = model.trainable_variables
|
||||
non_trainable_variables = model.non_trainable_variables
|
||||
# Call the model with zero side effects
|
||||
outputs, non_trainable_variables = model.stateless_call(
|
||||
trainable_variables,
|
||||
non_trainable_variables,
|
||||
data,
|
||||
)
|
||||
# Attach the updated state to the model
|
||||
# (until you do this, the model is still in its pre-call state).
|
||||
for ref_var, value in zip(model.non_trainable_variables, non_trainable_variables):
|
||||
ref_var.assign(value)
|
||||
```
|
||||
"""
|
||||
self._check_super_called()
|
||||
|
||||
if not self.built:
|
||||
raise ValueError(
|
||||
"To call stateless_call, {self.__class__.__name__} must be "
|
||||
f"To call stateless_call, {self.__class__.__name__} must be "
|
||||
"built (i.e. its variables must have been already created). "
|
||||
"You can build it by calling it on some data."
|
||||
)
|
||||
@ -452,7 +495,9 @@ class Layer(Operation):
|
||||
mapping = list(trainable_mapping) + list(non_trainable_mapping)
|
||||
|
||||
# Call in stateless scope
|
||||
with backend.StatelessScope(state_mapping=mapping) as scope:
|
||||
with backend.StatelessScope(
|
||||
state_mapping=mapping, collect_losses=return_losses
|
||||
) as scope:
|
||||
outputs = self.call(*args, **kwargs)
|
||||
|
||||
# Gather updated non-trainable variables
|
||||
@ -463,6 +508,9 @@ class Layer(Operation):
|
||||
non_trainable_variables.append(new_v)
|
||||
else:
|
||||
non_trainable_variables.append(v)
|
||||
|
||||
if return_losses:
|
||||
return outputs, non_trainable_variables, scope.losses[:]
|
||||
return outputs, non_trainable_variables
|
||||
|
||||
def compute_output_spec(self, *args, **kwargs):
|
||||
|
@ -332,3 +332,71 @@ class LayerTest(testing.TestCase):
|
||||
x2._keras_mask = backend.numpy.ones((4,))
|
||||
layer((x1_1, x1_2), x2)
|
||||
layer(x1=(x1_1, x1_2), x2=x2)
|
||||
|
||||
def test_stateless_call(self):
    """End-to-end check of `Layer.stateless_call`: outputs, updated
    non-trainable state, and (optionally) collected losses must match
    an ordinary stateful `__call__` on an identical fresh layer."""

    class TestLayer(layers.Layer):
        # Minimal layer exercising all stateful elements that
        # `stateless_call` must handle: RNG seed state, a non-trainable
        # weight updated in `call()`, a trainable weight, and a loss
        # added via `add_loss()`.
        def __init__(self):
            super().__init__()
            self._seed_generator = backend.random.SeedGenerator(1337)
            self.ntw = self.add_weight(
                shape=(),
                initializer="zeros",
                trainable=False,
            )
            self.tw = self.add_weight(
                shape=(),
                initializer="zeros",
                trainable=True,
            )
            # Mark built so no weights are created lazily at call time.
            self.built = True

        def call(self, x):
            x = backend.convert_to_tensor(x, dtype="float32")
            # Side effects that stateless_call must capture without
            # mutating the layer: a loss and a non-trainable update.
            self.add_loss(ops.sum(x))
            self.ntw.assign(ops.sum(x))
            # Consumes RNG state; seeded so two fresh layers agree.
            x = x + backend.random.normal(
                shape=(), seed=self._seed_generator
            )
            return x + self.tw + self.ntw

    data = np.random.random((3, 4))
    layer = TestLayer()
    out = layer(data)
    layer1 = TestLayer()
    out1 = layer1(data)
    # Check that the layer is in fact deterministic
    self.assertAllClose(out, out1)

    # Test stateless_call correctness
    layer2 = TestLayer()
    trainable_variables = layer2.trainable_variables
    non_trainable_variables = layer2.non_trainable_variables
    out2, non_trainable_variables = layer2.stateless_call(
        trainable_variables, non_trainable_variables, data
    )
    # Output and the returned (not in-place) non-trainable values must
    # equal what a plain stateful call produced on layer1.
    self.assertAllClose(out1, out2)
    self.assertEqual(
        len(layer1.non_trainable_variables), len(non_trainable_variables)
    )
    for ref_v, v in zip(
        layer1.non_trainable_variables, non_trainable_variables
    ):
        self.assertAllClose(ref_v, v)

    # Test with loss collection
    layer3 = TestLayer()
    trainable_variables = layer3.trainable_variables
    non_trainable_variables = layer3.non_trainable_variables
    # With return_losses=True, a third return value carries the losses
    # created during call().
    out3, non_trainable_variables, losses = layer3.stateless_call(
        trainable_variables,
        non_trainable_variables,
        data,
        return_losses=True,
    )
    self.assertAllClose(out1, out3)
    for ref_v, v in zip(
        layer1.non_trainable_variables, non_trainable_variables
    ):
        self.assertAllClose(ref_v, v)
    # Collected losses must match those accumulated on the stateful layer.
    for ref_loss, loss in zip(layer1.losses, losses):
        self.assertAllClose(ref_loss, loss)
|
||||
|
Loading…
Reference in New Issue
Block a user