Allow training
in any layer call instead of erroring out. (#291)
* Added testing * Readd traceback * Format * Format * Update layer.py
This commit is contained in:
parent
90baabee5d
commit
d2d910b2a7
@ -1139,7 +1139,17 @@ def is_backend_tensor_or_symbolic(x):
|
||||
class CallSpec:
|
||||
def __init__(self, call_fn, args, kwargs):
|
||||
sig = inspect.signature(call_fn)
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
|
||||
# `training` and `mask` are special kwargs that are always available in
|
||||
# a layer, if user specifies them in their call without adding to spec,
|
||||
# we remove them to be able to bind variables. User is not using
|
||||
# `training` anyway so we can ignore.
|
||||
# TODO: If necessary use workaround for `mask`
|
||||
if "training" in kwargs and "training" not in sig.parameters:
|
||||
kwargs.pop("training")
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
else:
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
self.user_arguments_dict = {
|
||||
k: v for k, v in bound_args.arguments.items()
|
||||
}
|
||||
|
@ -627,6 +627,21 @@ class LayerTest(testing.TestCase):
|
||||
|
||||
MatchingArguments()(foo, bar)
|
||||
|
||||
def test_training_arg_not_specified(self):
|
||||
class NoTrainingSpecified(layers.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def build(self, input_shape):
|
||||
self.activation = layers.Activation("linear")
|
||||
|
||||
def call(self, inputs):
|
||||
return self.activation(inputs)
|
||||
|
||||
layer = NoTrainingSpecified()
|
||||
inputs = ops.random.uniform(shape=(1, 100, 100, 3))
|
||||
layer(inputs, training=True)
|
||||
|
||||
def test_tracker_locking(self):
|
||||
class BadLayer(layers.Layer):
|
||||
def call(self, x):
|
||||
|
Loading…
Reference in New Issue
Block a user