Add loss collection in stateless call as well as stateless call tests
This commit is contained in:
parent
3dea0fc5b6
commit
c457e2a605
@ -28,7 +28,8 @@ def draw_seed(seed):
|
|||||||
from keras_core.backend import convert_to_tensor
|
from keras_core.backend import convert_to_tensor
|
||||||
|
|
||||||
if isinstance(seed, SeedGenerator):
|
if isinstance(seed, SeedGenerator):
|
||||||
new_seed_value = seed.state.value
|
# Use * 1 to create a copy
|
||||||
|
new_seed_value = seed.state.value * 1
|
||||||
seed.state.assign(
|
seed.state.assign(
|
||||||
seed.state + convert_to_tensor([0, 1], dtype="uint32")
|
seed.state + convert_to_tensor([0, 1], dtype="uint32")
|
||||||
)
|
)
|
||||||
|
@ -415,15 +415,58 @@ class Layer(Operation):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def stateless_call(
|
def stateless_call(
|
||||||
self, trainable_variables, non_trainable_variables, *args, **kwargs
|
self,
|
||||||
|
trainable_variables,
|
||||||
|
non_trainable_variables,
|
||||||
|
*args,
|
||||||
|
return_losses=False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
# TODO: also handle losses
|
"""Call the layer without any side effects.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trainable_variables: List of trainable variables of the model.
|
||||||
|
non_trainable_variables: List of non-trainable variables of the model.
|
||||||
|
*args: Positional arguments to be passed to `call()`.
|
||||||
|
return_losses: If `True`, `stateless_call()` will return the list of
|
||||||
|
losses created during `call()` as part of its return values.
|
||||||
|
**kwargs: Keyword arguments to be passed to `call()`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple. By default, returns `(outputs, non_trainable_variables)`.
|
||||||
|
If `return_losses = True`, then returns
|
||||||
|
`(outputs, non_trainable_variables, losses)`.
|
||||||
|
|
||||||
|
Note: `non_trainable_variables` include not only non-trainable weights
|
||||||
|
such as `BatchNormalization` statistics, but also RNG seed state
|
||||||
|
(if there are any random operations that are part of the layer, such as dropout),
|
||||||
|
and `Metric` state (if there are any metrics attached to the layer).
|
||||||
|
These are all elements of state of the layer.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
model = ...
|
||||||
|
data = ...
|
||||||
|
trainable_variables = model.trainable_variables
|
||||||
|
non_trainable_variables = model.non_trainable_variables
|
||||||
|
# Call the model with zero side effects
|
||||||
|
outputs, non_trainable_variables = model.stateless_call(
|
||||||
|
trainable_variables,
|
||||||
|
non_trainable_variables,
|
||||||
|
data,
|
||||||
|
)
|
||||||
|
# Attach the updated state to the model
|
||||||
|
# (until you do this, the model is still in its pre-call state).
|
||||||
|
for ref_var, value in zip(model.non_trainable_variables, non_trainable_variables):
|
||||||
|
ref_var.assign(value)
|
||||||
|
```
|
||||||
|
"""
|
||||||
self._check_super_called()
|
self._check_super_called()
|
||||||
|
|
||||||
if not self.built:
|
if not self.built:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"To call stateless_call, {self.__class__.__name__} must be "
|
f"To call stateless_call, {self.__class__.__name__} must be "
|
||||||
"built (i.e. its variables must have been already created). "
|
"built (i.e. its variables must have been already created). "
|
||||||
"You can build it by calling it on some data."
|
"You can build it by calling it on some data."
|
||||||
)
|
)
|
||||||
@ -452,7 +495,9 @@ class Layer(Operation):
|
|||||||
mapping = list(trainable_mapping) + list(non_trainable_mapping)
|
mapping = list(trainable_mapping) + list(non_trainable_mapping)
|
||||||
|
|
||||||
# Call in stateless scope
|
# Call in stateless scope
|
||||||
with backend.StatelessScope(state_mapping=mapping) as scope:
|
with backend.StatelessScope(
|
||||||
|
state_mapping=mapping, collect_losses=return_losses
|
||||||
|
) as scope:
|
||||||
outputs = self.call(*args, **kwargs)
|
outputs = self.call(*args, **kwargs)
|
||||||
|
|
||||||
# Gather updated non-trainable variables
|
# Gather updated non-trainable variables
|
||||||
@ -463,6 +508,9 @@ class Layer(Operation):
|
|||||||
non_trainable_variables.append(new_v)
|
non_trainable_variables.append(new_v)
|
||||||
else:
|
else:
|
||||||
non_trainable_variables.append(v)
|
non_trainable_variables.append(v)
|
||||||
|
|
||||||
|
if return_losses:
|
||||||
|
return outputs, non_trainable_variables, scope.losses[:]
|
||||||
return outputs, non_trainable_variables
|
return outputs, non_trainable_variables
|
||||||
|
|
||||||
def compute_output_spec(self, *args, **kwargs):
|
def compute_output_spec(self, *args, **kwargs):
|
||||||
|
@ -332,3 +332,71 @@ class LayerTest(testing.TestCase):
|
|||||||
x2._keras_mask = backend.numpy.ones((4,))
|
x2._keras_mask = backend.numpy.ones((4,))
|
||||||
layer((x1_1, x1_2), x2)
|
layer((x1_1, x1_2), x2)
|
||||||
layer(x1=(x1_1, x1_2), x2=x2)
|
layer(x1=(x1_1, x1_2), x2=x2)
|
||||||
|
|
||||||
|
def test_stateless_call(self):
|
||||||
|
class TestLayer(layers.Layer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._seed_generator = backend.random.SeedGenerator(1337)
|
||||||
|
self.ntw = self.add_weight(
|
||||||
|
shape=(),
|
||||||
|
initializer="zeros",
|
||||||
|
trainable=False,
|
||||||
|
)
|
||||||
|
self.tw = self.add_weight(
|
||||||
|
shape=(),
|
||||||
|
initializer="zeros",
|
||||||
|
trainable=True,
|
||||||
|
)
|
||||||
|
self.built = True
|
||||||
|
|
||||||
|
def call(self, x):
|
||||||
|
x = backend.convert_to_tensor(x, dtype="float32")
|
||||||
|
self.add_loss(ops.sum(x))
|
||||||
|
self.ntw.assign(ops.sum(x))
|
||||||
|
x = x + backend.random.normal(
|
||||||
|
shape=(), seed=self._seed_generator
|
||||||
|
)
|
||||||
|
return x + self.tw + self.ntw
|
||||||
|
|
||||||
|
data = np.random.random((3, 4))
|
||||||
|
layer = TestLayer()
|
||||||
|
out = layer(data)
|
||||||
|
layer1 = TestLayer()
|
||||||
|
out1 = layer1(data)
|
||||||
|
# Check that the layer is in fact deterministic
|
||||||
|
self.assertAllClose(out, out1)
|
||||||
|
|
||||||
|
# Test stateless_call correctness
|
||||||
|
layer2 = TestLayer()
|
||||||
|
trainable_variables = layer2.trainable_variables
|
||||||
|
non_trainable_variables = layer2.non_trainable_variables
|
||||||
|
out2, non_trainable_variables = layer2.stateless_call(
|
||||||
|
trainable_variables, non_trainable_variables, data
|
||||||
|
)
|
||||||
|
self.assertAllClose(out1, out2)
|
||||||
|
self.assertEqual(
|
||||||
|
len(layer1.non_trainable_variables), len(non_trainable_variables)
|
||||||
|
)
|
||||||
|
for ref_v, v in zip(
|
||||||
|
layer1.non_trainable_variables, non_trainable_variables
|
||||||
|
):
|
||||||
|
self.assertAllClose(ref_v, v)
|
||||||
|
|
||||||
|
# Test with loss collection
|
||||||
|
layer3 = TestLayer()
|
||||||
|
trainable_variables = layer3.trainable_variables
|
||||||
|
non_trainable_variables = layer3.non_trainable_variables
|
||||||
|
out3, non_trainable_variables, losses = layer3.stateless_call(
|
||||||
|
trainable_variables,
|
||||||
|
non_trainable_variables,
|
||||||
|
data,
|
||||||
|
return_losses=True,
|
||||||
|
)
|
||||||
|
self.assertAllClose(out1, out3)
|
||||||
|
for ref_v, v in zip(
|
||||||
|
layer1.non_trainable_variables, non_trainable_variables
|
||||||
|
):
|
||||||
|
self.assertAllClose(ref_v, v)
|
||||||
|
for ref_loss, loss in zip(layer1.losses, losses):
|
||||||
|
self.assertAllClose(ref_loss, loss)
|
||||||
|
Loading…
Reference in New Issue
Block a user