Implement masking.
This commit is contained in:
parent
da55728a35
commit
e5e3c9f833
@ -36,9 +36,6 @@ from keras_core.operations.operation import Operation
|
||||
from keras_core.utils import summary_utils
|
||||
from keras_core.utils.tracking import Tracker
|
||||
|
||||
# TODO: cache all call signature processing.
|
||||
# See layer_utils.CallFunctionSpec() in Keras.
|
||||
|
||||
|
||||
@keras_core_export(["keras_core.Layer", "keras_core.layers.Layer"])
|
||||
class Layer(Operation):
|
||||
@ -277,8 +274,8 @@ class Layer(Operation):
|
||||
self._supports_masking = value
|
||||
|
||||
@utils.default
|
||||
def compute_mask(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
def compute_mask(self, inputs, previous_mask):
|
||||
return previous_mask
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self._check_super_called()
|
||||
@ -290,7 +287,7 @@ class Layer(Operation):
|
||||
if isinstance(x, np.ndarray) or backend.is_tensor(x):
|
||||
# TODO: only cast when float!
|
||||
return backend.convert_to_tensor(x, dtype=self.compute_dtype)
|
||||
# TODO: cast KerasTensor too
|
||||
# TODO: cast KerasTensor too?
|
||||
return x
|
||||
|
||||
args = nest.map_structure(maybe_convert, args)
|
||||
@ -306,27 +303,28 @@ class Layer(Operation):
|
||||
f"(of type {type(arg)})"
|
||||
)
|
||||
|
||||
call_spec = CallSpec(self.call, args, kwargs)
|
||||
|
||||
# 4. Check input spec for 1st positional arg.
|
||||
# TODO: consider extending this to all args and kwargs.
|
||||
self._assert_input_compatibility(*args)
|
||||
self._assert_input_compatibility(call_spec.first_arg)
|
||||
######################################
|
||||
|
||||
###############
|
||||
# Call build. #
|
||||
self._maybe_build(*args, **kwargs)
|
||||
# 5. Call build. #
|
||||
self._maybe_build(call_spec)
|
||||
###############
|
||||
|
||||
# Maintains info about the `Layer.call` stack.
|
||||
call_context = self._get_call_context()
|
||||
|
||||
# Infer training value
|
||||
# 6. Infer training value
|
||||
# Training phase for `Layer.call` is set via (in order of priority):
|
||||
# (1) The `training` argument passed to this `Layer.call`, if not None
|
||||
# (2) The training argument of an outer `Layer.call`.
|
||||
# (4) Any non-None default value for `training` in the call signature
|
||||
# (5) False (treating the layer as if it's in inference)
|
||||
arguments_dict = get_arguments_dict(self.call, *args, **kwargs)
|
||||
training = arguments_dict.get("training", None)
|
||||
training = call_spec.arguments_dict.get("training", None)
|
||||
if training is None:
|
||||
training = call_context.training
|
||||
if training is None:
|
||||
@ -337,10 +335,34 @@ class Layer(Operation):
|
||||
if self._call_has_training_arg():
|
||||
kwargs["training"] = training
|
||||
|
||||
# TODO: Populate mask argument(s)
|
||||
# if self._mask_has_training_arg():
|
||||
# 7. Populate mask argument(s)
|
||||
if self.supports_masking:
|
||||
if len(call_spec.tensor_arguments_dict) == 1:
|
||||
if (
|
||||
"mask" in call_spec.argument_names
|
||||
and call_spec.arguments_dict["mask"] is None
|
||||
):
|
||||
arg_name = list(call_spec.tensor_arguments_dict.keys())[0]
|
||||
only_tensor_arg = call_spec.tensor_arguments_dict[arg_name]
|
||||
mask = nest.map_structure(
|
||||
lambda x: getattr(x, "_keras_mask", None),
|
||||
only_tensor_arg,
|
||||
)
|
||||
kwargs["mask"] = mask
|
||||
elif len(call_spec.tensor_arguments_dict) > 1:
|
||||
for k, v in call_spec.tensor_arguments_dict.items():
|
||||
expected_mask_arg_name = f"{k}_mask"
|
||||
if expected_mask_arg_name in call_spec.argument_names:
|
||||
if (
|
||||
call_spec.arguments_dict[expected_mask_arg_name]
|
||||
is None
|
||||
):
|
||||
mask = nest.map_structure(
|
||||
lambda x: getattr(x, "_keras_mask", None), v
|
||||
)
|
||||
kwargs[expected_mask_arg_name] = mask
|
||||
|
||||
# Call the layer.
|
||||
# 8. Call the layer.
|
||||
try:
|
||||
with backend.name_scope(self.name):
|
||||
if self.autocast and self.compute_dtype != self.variable_dtype:
|
||||
@ -356,8 +378,16 @@ class Layer(Operation):
|
||||
if backend.is_tensor(output):
|
||||
self.add_loss(self.activity_regularizer(output))
|
||||
|
||||
# TODO: Set masks on outputs
|
||||
# self._set_mask_metadata(inputs, outputs, previous_mask)
|
||||
if self.supports_masking:
|
||||
# Set masks on outputs,
|
||||
# provided only the first positional input arg and its mask.
|
||||
# TODO: consider extending this to all args and kwargs.
|
||||
previous_mask = getattr(
|
||||
call_spec.first_arg, "_keras_mask", None
|
||||
)
|
||||
self._set_mask_metadata(
|
||||
call_spec.first_arg, outputs, previous_mask
|
||||
)
|
||||
finally:
|
||||
# Destroy call context if we created it
|
||||
self._maybe_reset_call_context()
|
||||
@ -422,8 +452,8 @@ class Layer(Operation):
|
||||
return super().compute_output_spec(*args, **kwargs)
|
||||
else:
|
||||
# Use compute_output_shape() to return the right output spec
|
||||
arguments_dict = get_arguments_dict(self.call, *args, **kwargs)
|
||||
shapes_dict = get_shapes_dict(arguments_dict)
|
||||
call_spec = CallSpec(self.call, args, kwargs)
|
||||
shapes_dict = get_shapes_dict(call_spec)
|
||||
if len(shapes_dict) == 1:
|
||||
# Single arg: pass it positionally
|
||||
input_shape = tuple(shapes_dict.values())[0]
|
||||
@ -553,10 +583,9 @@ class Layer(Operation):
|
||||
)
|
||||
return summary_utils.count_params(self.weights)
|
||||
|
||||
def _maybe_build(self, *args, **kwargs):
|
||||
def _maybe_build(self, call_spec):
|
||||
if not self.built:
|
||||
arguments_dict = get_arguments_dict(self.call, *args, **kwargs)
|
||||
shapes_dict = get_shapes_dict(arguments_dict)
|
||||
shapes_dict = get_shapes_dict(call_spec)
|
||||
self._build_shapes_dict = shapes_dict
|
||||
failure = False
|
||||
if len(shapes_dict) == 1:
|
||||
@ -578,14 +607,32 @@ class Layer(Operation):
|
||||
# and check that build() expects the right args.
|
||||
check_build_signature(self.build, shapes_dict)
|
||||
with backend.name_scope(self.name):
|
||||
if utils.is_default(
|
||||
self.build
|
||||
) and might_have_unbuilt_state(self):
|
||||
status = self._build_by_run_for_kwargs(shapes_dict)
|
||||
if not status:
|
||||
failure = True
|
||||
if utils.is_default(self.build):
|
||||
if might_have_unbuilt_state(self):
|
||||
status = self._build_by_run_for_kwargs(shapes_dict)
|
||||
if not status:
|
||||
failure = True
|
||||
else:
|
||||
self.build(**shapes_dict)
|
||||
run_build = True
|
||||
build_args = set(
|
||||
inspect.signature(self.build).parameters.keys()
|
||||
)
|
||||
for key in shapes_dict.keys():
|
||||
if key not in build_args:
|
||||
run_build = False
|
||||
if run_build:
|
||||
self.build(**shapes_dict)
|
||||
else:
|
||||
raise ValueError(
|
||||
"In a layer with multiple tensor arguments "
|
||||
"in call(), the build() method should accept "
|
||||
"corresponding `*_shape` arguments, e.g. "
|
||||
"if the call signature is `def call(self, x1, x2)` "
|
||||
"then the build signature should be "
|
||||
"`def build(self, x1_shape, x2_shape)`. "
|
||||
"Keras will not build this layer automatically "
|
||||
"since it does not conform to this."
|
||||
)
|
||||
if failure:
|
||||
raise ValueError(
|
||||
f"Layer '{self.name}' looks like it has "
|
||||
@ -601,7 +648,7 @@ class Layer(Operation):
|
||||
|
||||
# Check input spec again (after build, since self.input_spec
|
||||
# may have been updated
|
||||
self._assert_input_compatibility(*args)
|
||||
self._assert_input_compatibility(call_spec.first_arg)
|
||||
|
||||
def _build_by_run_for_single_pos_arg(self, input_shape):
|
||||
# Case: all inputs are in the first arg (possibly nested).
|
||||
@ -651,12 +698,16 @@ class Layer(Operation):
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
# TODO: improve
|
||||
return f"<{self.__class__.__name__} name={self.name}>"
|
||||
return (
|
||||
f"<{self.__class__.__name__} "
|
||||
f"name={self.name}, built={self.built}>"
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
# TODO: improve
|
||||
return f"<{self.__class__.__name__} name={self.name}>"
|
||||
return (
|
||||
f"<{self.__class__.__name__} "
|
||||
f"name={self.name}, built={self.built}>"
|
||||
)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
# Track Variables, Layers, Metrics
|
||||
@ -672,10 +723,10 @@ class Layer(Operation):
|
||||
"Go add it!"
|
||||
)
|
||||
|
||||
def _assert_input_compatibility(self, *args):
|
||||
if args and self.input_spec:
|
||||
def _assert_input_compatibility(self, arg_0):
|
||||
if self.input_spec:
|
||||
input_spec.assert_input_compatibility(
|
||||
self.input_spec, args[0], layer_name=self.name
|
||||
self.input_spec, arg_0, layer_name=self.name
|
||||
)
|
||||
|
||||
def _call_has_training_arg(self):
|
||||
@ -732,12 +783,6 @@ class Layer(Operation):
|
||||
return layers
|
||||
|
||||
def _set_mask_metadata(self, inputs, outputs, previous_mask):
|
||||
# Many `Layer`s don't need to call `compute_mask`.
|
||||
# This method is optimized to do as little work as needed for the common
|
||||
# case.
|
||||
if not self._supports_masking:
|
||||
return
|
||||
|
||||
flat_outputs = nest.flatten(outputs)
|
||||
|
||||
mask_already_computed = all(
|
||||
@ -752,10 +797,55 @@ class Layer(Operation):
|
||||
|
||||
flat_masks = nest.flatten(output_masks)
|
||||
for tensor, mask in zip(flat_outputs, flat_masks):
|
||||
tensor._keras_mask = mask
|
||||
if getattr(tensor, "_keras_mask", None) is None:
|
||||
tensor._keras_mask = mask
|
||||
|
||||
|
||||
def get_arguments_dict(fn, *args, **kwargs):
|
||||
def is_backend_tensor_or_symbolic(x):
|
||||
return backend.is_tensor(x) or isinstance(x, backend.KerasTensor)
|
||||
|
||||
|
||||
class CallSpec:
|
||||
def __init__(self, call_fn, args, kwargs):
|
||||
sig = inspect.signature(call_fn)
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
arg_dict = {}
|
||||
arg_names = []
|
||||
tensor_arg_dict = {}
|
||||
tensor_args = []
|
||||
tensor_arg_names = []
|
||||
nested_tensor_arg_names = []
|
||||
for name, value in bound_args.arguments.items():
|
||||
arg_dict[name] = value
|
||||
arg_names.append(name)
|
||||
if is_backend_tensor_or_symbolic(value):
|
||||
tensor_args.append(value)
|
||||
tensor_arg_names.append(name)
|
||||
tensor_arg_dict[name] = value
|
||||
elif nest.is_nested(value):
|
||||
flat_values = nest.flatten(value)
|
||||
if all(is_backend_tensor_or_symbolic(x) for x in flat_values):
|
||||
tensor_args.append(value)
|
||||
tensor_arg_names.append(name)
|
||||
tensor_arg_dict[name] = value
|
||||
nested_tensor_arg_names.append(name)
|
||||
elif any(is_backend_tensor_or_symbolic(x) for x in flat_values):
|
||||
raise ValueError(
|
||||
"In a nested call() argument, "
|
||||
"you cannot mix tensors and non-tensors. "
|
||||
"Received invalid mixed argument: "
|
||||
f"{name}={value}"
|
||||
)
|
||||
self.arguments_dict = arg_dict
|
||||
self.argument_names = arg_names
|
||||
self.tensor_arguments_dict = tensor_arg_dict
|
||||
self.tensor_arguments_names = tensor_arg_names
|
||||
self.nested_tensor_argument_names = nested_tensor_arg_names
|
||||
self.first_arg = arg_dict[arg_names[0]]
|
||||
|
||||
|
||||
def get_arguments_dict(fn, args, kwargs):
|
||||
"""Return a dict mapping argument names to their values."""
|
||||
sig = inspect.signature(fn)
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
@ -765,37 +855,24 @@ def get_arguments_dict(fn, *args, **kwargs):
|
||||
return arg_dict
|
||||
|
||||
|
||||
def get_shapes_dict(arguments_dict):
|
||||
def get_shapes_dict(call_spec):
|
||||
"""Convert the call() arguments dict into a dict of input shape arguments.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
>>> get_shapes_dict(
|
||||
{"input_a": KerasTensor(shape=(2, 3)), "training": False})
|
||||
>>> get_shapes_dict(call_spec)
|
||||
{"input_a_shape": (2, 3)}
|
||||
```
|
||||
"""
|
||||
shapes_dict = {}
|
||||
for k, v in arguments_dict.items():
|
||||
if isinstance(v, KerasTensor) or backend.is_tensor(v):
|
||||
shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape)
|
||||
elif nest.is_nested(v):
|
||||
flat = nest.flatten(v)
|
||||
if any(
|
||||
isinstance(x, KerasTensor) or backend.is_tensor(x) for x in flat
|
||||
):
|
||||
if not all(
|
||||
isinstance(x, KerasTensor) or backend.is_tensor(x)
|
||||
for x in flat
|
||||
):
|
||||
raise ValueError(
|
||||
"You cannot mix tensors and non-tensors in a nested "
|
||||
f"argument. Invalid argument: {k}={v}"
|
||||
)
|
||||
for k, v in call_spec.tensor_arguments_dict.items():
|
||||
if k in call_spec.nested_tensor_argument_names:
|
||||
shapes_dict[f"{k}_shape"] = nest.map_structure(
|
||||
lambda x: backend.standardize_shape(x.shape), v
|
||||
)
|
||||
else:
|
||||
shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape)
|
||||
return shapes_dict
|
||||
|
||||
|
||||
|
@ -262,3 +262,73 @@ class LayerTest(testing.TestCase):
|
||||
self.assertEqual(layer.variable_dtype, "float32")
|
||||
self.assertEqual(y.dtype.name, "float16")
|
||||
self.assertEqual(layer.kernel.dtype, "float32")
|
||||
|
||||
def test_masking(self):
|
||||
class BasicMaskedLayer(layers.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supports_masking = True
|
||||
|
||||
def call(self, x, mask=None):
|
||||
assert mask is not None
|
||||
return x
|
||||
|
||||
layer = BasicMaskedLayer()
|
||||
x = backend.numpy.ones((4, 4))
|
||||
x._keras_mask = backend.numpy.ones((4,))
|
||||
layer(x)
|
||||
|
||||
class NestedInputMaskedLayer(layers.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supports_masking = True
|
||||
|
||||
def call(self, x, mask=None):
|
||||
assert isinstance(x, list)
|
||||
assert len(x) == 2
|
||||
assert isinstance(mask, list)
|
||||
assert len(mask) == 2
|
||||
return x
|
||||
|
||||
layer = NestedInputMaskedLayer()
|
||||
x1 = backend.numpy.ones((4, 4))
|
||||
x1._keras_mask = backend.numpy.ones((4,))
|
||||
x2 = backend.numpy.ones((4, 4))
|
||||
x2._keras_mask = backend.numpy.ones((4,))
|
||||
layer([x1, x2])
|
||||
|
||||
class PositionalInputsMaskedLayer(layers.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supports_masking = True
|
||||
|
||||
def call(self, x1, x2, x1_mask=None, x2_mask=None):
|
||||
assert x1_mask is not None
|
||||
assert x2_mask is not None
|
||||
return x1 + x2
|
||||
|
||||
layer = PositionalInputsMaskedLayer()
|
||||
layer(x1, x2)
|
||||
layer(x1=x1, x2=x2)
|
||||
|
||||
class PositionalNestedInputsMaskedLayer(layers.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supports_masking = True
|
||||
|
||||
def call(self, x1, x2, x1_mask=None, x2_mask=None):
|
||||
assert isinstance(x1, tuple)
|
||||
assert x1_mask is not None
|
||||
assert x2_mask is not None
|
||||
assert isinstance(x1_mask, tuple)
|
||||
return x1[0] + x1[1] + x2
|
||||
|
||||
layer = PositionalNestedInputsMaskedLayer()
|
||||
x1_1 = backend.numpy.ones((4, 4))
|
||||
x1_1._keras_mask = backend.numpy.ones((4,))
|
||||
x1_2 = backend.numpy.ones((4, 4))
|
||||
x1_2._keras_mask = backend.numpy.ones((4,))
|
||||
x2 = backend.numpy.ones((4, 4))
|
||||
x2._keras_mask = backend.numpy.ones((4,))
|
||||
layer((x1_1, x1_2), x2)
|
||||
layer(x1=(x1_1, x1_2), x2=x2)
|
||||
|
Loading…
Reference in New Issue
Block a user