Allow `training` in any layer call instead of erroring out. (#291)

* Added testing

* Re-add traceback

* Format

* Format

* Update layer.py
Gabriel Rasskin, 2023-06-08 13:15:47 -04:00 (committed by Francois Chollet)
parent 90baabee5d
commit d2d910b2a7
2 changed files with 26 additions and 1 deletion
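
For context, the behavior this commit enables, as a minimal sketch (the `Scale` layer is hypothetical; `layers` and `ops` refer to the keras-core modules used in the tests below):

    class Scale(layers.Layer):
        # Note: `call()` does not declare a `training` parameter.
        def call(self, inputs):
            return inputs * 2.0

    layer = Scale()
    x = ops.ones((2, 4))
    # Previously this raised a TypeError during argument binding because
    # `call()` does not accept `training`; now the unused kwarg is dropped.
    y = layer(x, training=True)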

@@ -1139,7 +1139,17 @@ def is_backend_tensor_or_symbolic(x):
 class CallSpec:
     def __init__(self, call_fn, args, kwargs):
         sig = inspect.signature(call_fn)
-        bound_args = sig.bind(*args, **kwargs)
+        # `training` and `mask` are special kwargs that are always
+        # available on a layer. If the user passes `training` to a call
+        # whose signature does not declare it, drop it so the remaining
+        # arguments can still be bound; the layer ignores `training` in
+        # that case anyway.
+        # TODO: If needed, apply the same workaround for `mask`.
+        if "training" in kwargs and "training" not in sig.parameters:
+            kwargs.pop("training")
+            bound_args = sig.bind(*args, **kwargs)
+        else:
+            bound_args = sig.bind(*args, **kwargs)
         self.user_arguments_dict = {
             k: v for k, v in bound_args.arguments.items()
         }
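
The guard above is needed because `inspect.Signature.bind` raises a `TypeError` for any keyword argument the callable does not accept. A standalone sketch of the mechanism (plain Python, no Keras involved):

    import inspect

    def call(inputs):  # a `call` signature with no `training` parameter
        return inputs

    sig = inspect.signature(call)
    kwargs = {"training": True}
    # sig.bind("x", **kwargs) would raise:
    #   TypeError: got an unexpected keyword argument 'training'
    if "training" in kwargs and "training" not in sig.parameters:
        kwargs.pop("training")
    bound = sig.bind("x", **kwargs)
    print(bound.arguments)  # -> {'inputs': 'x'}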

@@ -627,6 +627,21 @@ class LayerTest(testing.TestCase):
         MatchingArguments()(foo, bar)
 
+    def test_training_arg_not_specified(self):
+        class NoTrainingSpecified(layers.Layer):
+            def __init__(self):
+                super().__init__()
+
+            def build(self, input_shape):
+                self.activation = layers.Activation("linear")
+
+            def call(self, inputs):
+                return self.activation(inputs)
+
+        layer = NoTrainingSpecified()
+        inputs = ops.random.uniform(shape=(1, 100, 100, 3))
+        layer(inputs, training=True)
+
     def test_tracker_locking(self):
         class BadLayer(layers.Layer):
             def call(self, x):