Implement masking.

Francois Chollet 2023-04-28 13:49:38 -07:00
parent da55728a35
commit e5e3c9f833
2 changed files with 211 additions and 64 deletions

@@ -36,9 +36,6 @@ from keras_core.operations.operation import Operation
from keras_core.utils import summary_utils
from keras_core.utils.tracking import Tracker
# TODO: cache all call signature processing.
# See layer_utils.CallFunctionSpec() in Keras.
@keras_core_export(["keras_core.Layer", "keras_core.layers.Layer"])
class Layer(Operation):
@@ -277,8 +274,8 @@ class Layer(Operation):
self._supports_masking = value
@utils.default
def compute_mask(self, *args, **kwargs):
raise NotImplementedError
def compute_mask(self, inputs, previous_mask):
return previous_mask
def __call__(self, *args, **kwargs):
self._check_super_called()
@@ -290,7 +287,7 @@ class Layer(Operation):
if isinstance(x, np.ndarray) or backend.is_tensor(x):
# TODO: only cast when float!
return backend.convert_to_tensor(x, dtype=self.compute_dtype)
# TODO: cast KerasTensor too
# TODO: cast KerasTensor too?
return x
args = nest.map_structure(maybe_convert, args)
@@ -306,27 +303,28 @@ class Layer(Operation):
f"(of type {type(arg)})"
)
call_spec = CallSpec(self.call, args, kwargs)
# 4. Check input spec for 1st positional arg.
# TODO: consider extending this to all args and kwargs.
self._assert_input_compatibility(*args)
self._assert_input_compatibility(call_spec.first_arg)
######################################
###############
# Call build. #
self._maybe_build(*args, **kwargs)
# 5. Call build. #
self._maybe_build(call_spec)
###############
# Maintains info about the `Layer.call` stack.
call_context = self._get_call_context()
# Infer training value
# 6. Infer training value
# Training phase for `Layer.call` is set via (in order of priority):
# (1) The `training` argument passed to this `Layer.call`, if not None
# (2) The training argument of an outer `Layer.call`.
# (3) Any non-None default value for `training` in the call signature
# (4) False (treating the layer as if it were in inference)
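# For example, an explicit `layer(x, training=True)` wins over an outer
# layer's `training=False`, which in turn wins over the signature default.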
arguments_dict = get_arguments_dict(self.call, *args, **kwargs)
training = arguments_dict.get("training", None)
training = call_spec.arguments_dict.get("training", None)
if training is None:
training = call_context.training
if training is None:
@@ -337,10 +335,34 @@ class Layer(Operation):
if self._call_has_training_arg():
kwargs["training"] = training
# TODO: Populate mask argument(s)
# if self._mask_has_training_arg():
# 7. Populate mask argument(s)
if self.supports_masking:
if len(call_spec.tensor_arguments_dict) == 1:
if (
"mask" in call_spec.argument_names
and call_spec.arguments_dict["mask"] is None
):
arg_name = list(call_spec.tensor_arguments_dict.keys())[0]
only_tensor_arg = call_spec.tensor_arguments_dict[arg_name]
mask = nest.map_structure(
lambda x: getattr(x, "_keras_mask", None),
only_tensor_arg,
)
kwargs["mask"] = mask
elif len(call_spec.tensor_arguments_dict) > 1:
for k, v in call_spec.tensor_arguments_dict.items():
expected_mask_arg_name = f"{k}_mask"
if expected_mask_arg_name in call_spec.argument_names:
if (
call_spec.arguments_dict[expected_mask_arg_name]
is None
):
mask = nest.map_structure(
lambda x: getattr(x, "_keras_mask", None), v
)
kwargs[expected_mask_arg_name] = mask
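# A minimal sketch of what this enables, for a layer that opts in via
# `self.supports_masking = True`:
#
#     def call(self, x, mask=None):
#         ...  # `mask` is auto-filled from `x._keras_mask` when not passed
#
# With multiple tensor args, per-input masks are matched by name,
# e.g. `x1` -> `x1_mask`, `x2` -> `x2_mask`.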
# Call the layer.
# 8. Call the layer.
try:
with backend.name_scope(self.name):
if self.autocast and self.compute_dtype != self.variable_dtype:
@@ -356,8 +378,16 @@ class Layer(Operation):
if backend.is_tensor(output):
self.add_loss(self.activity_regularizer(output))
# TODO: Set masks on outputs
# self._set_mask_metadata(inputs, outputs, previous_mask)
if self.supports_masking:
# Set masks on outputs, using only the first positional
# input arg and its mask.
# TODO: consider extending this to all args and kwargs.
previous_mask = getattr(
call_spec.first_arg, "_keras_mask", None
)
self._set_mask_metadata(
call_spec.first_arg, outputs, previous_mask
)
finally:
# Destroy call context if we created it
self._maybe_reset_call_context()
@@ -422,8 +452,8 @@ class Layer(Operation):
return super().compute_output_spec(*args, **kwargs)
else:
# Use compute_output_shape() to return the right output spec
arguments_dict = get_arguments_dict(self.call, *args, **kwargs)
shapes_dict = get_shapes_dict(arguments_dict)
call_spec = CallSpec(self.call, args, kwargs)
shapes_dict = get_shapes_dict(call_spec)
if len(shapes_dict) == 1:
# Single arg: pass it positionally
input_shape = tuple(shapes_dict.values())[0]
@@ -553,10 +583,9 @@ class Layer(Operation):
)
return summary_utils.count_params(self.weights)
def _maybe_build(self, *args, **kwargs):
def _maybe_build(self, call_spec):
if not self.built:
arguments_dict = get_arguments_dict(self.call, *args, **kwargs)
shapes_dict = get_shapes_dict(arguments_dict)
shapes_dict = get_shapes_dict(call_spec)
self._build_shapes_dict = shapes_dict
failure = False
if len(shapes_dict) == 1:
@@ -578,14 +607,32 @@ class Layer(Operation):
# and check that build() expects the right args.
check_build_signature(self.build, shapes_dict)
with backend.name_scope(self.name):
if utils.is_default(
self.build
) and might_have_unbuilt_state(self):
status = self._build_by_run_for_kwargs(shapes_dict)
if not status:
failure = True
if utils.is_default(self.build):
if might_have_unbuilt_state(self):
status = self._build_by_run_for_kwargs(shapes_dict)
if not status:
failure = True
else:
self.build(**shapes_dict)
run_build = True
build_args = set(
inspect.signature(self.build).parameters.keys()
)
for key in shapes_dict.keys():
if key not in build_args:
run_build = False
if run_build:
self.build(**shapes_dict)
else:
raise ValueError(
"In a layer with multiple tensor arguments "
"in call(), the build() method should accept "
"corresponding `*_shape` arguments, e.g. "
"if the call signature is `def call(self, x1, x2)` "
"then the build signature should be "
"`def build(self, x1_shape, x2_shape)`. "
"Keras will not build this layer automatically "
"since it does not conform to this."
)
if failure:
raise ValueError(
f"Layer '{self.name}' looks like it has "
@@ -601,7 +648,7 @@ class Layer(Operation):
# Check input spec again (after build, since self.input_spec
# may have been updated)
self._assert_input_compatibility(*args)
self._assert_input_compatibility(call_spec.first_arg)
def _build_by_run_for_single_pos_arg(self, input_shape):
# Case: all inputs are in the first arg (possibly nested).
@@ -651,12 +698,16 @@ class Layer(Operation):
return False
def __repr__(self):
# TODO: improve
return f"<{self.__class__.__name__} name={self.name}>"
return (
f"<{self.__class__.__name__} "
f"name={self.name}, built={self.built}>"
)
def __str__(self):
# TODO: improve
return f"<{self.__class__.__name__} name={self.name}>"
return (
f"<{self.__class__.__name__} "
f"name={self.name}, built={self.built}>"
)
def __setattr__(self, name, value):
# Track Variables, Layers, Metrics
@@ -672,10 +723,10 @@ class Layer(Operation):
"Go add it!"
)
def _assert_input_compatibility(self, *args):
if args and self.input_spec:
def _assert_input_compatibility(self, arg_0):
if self.input_spec:
input_spec.assert_input_compatibility(
self.input_spec, args[0], layer_name=self.name
self.input_spec, arg_0, layer_name=self.name
)
def _call_has_training_arg(self):
@@ -732,12 +783,6 @@ class Layer(Operation):
return layers
def _set_mask_metadata(self, inputs, outputs, previous_mask):
# Many `Layer`s don't need to call `compute_mask`.
# This method is optimized to do as little work as needed for the common
# case.
if not self._supports_masking:
return
flat_outputs = nest.flatten(outputs)
mask_already_computed = all(
@@ -752,10 +797,55 @@ class Layer(Operation):
flat_masks = nest.flatten(output_masks)
for tensor, mask in zip(flat_outputs, flat_masks):
tensor._keras_mask = mask
if getattr(tensor, "_keras_mask", None) is None:
tensor._keras_mask = mask
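# Only tensors without an existing `_keras_mask` get one, so masks set
# explicitly inside call() take precedence over computed ones.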
def get_arguments_dict(fn, *args, **kwargs):
def is_backend_tensor_or_symbolic(x):
return backend.is_tensor(x) or isinstance(x, backend.KerasTensor)
class CallSpec:
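"""Bundles a call() signature with its bound arguments.

Exposes the bound arguments by name, the subset of arguments that are
tensors (or nested structures of tensors), and the first positional arg.
"""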
def __init__(self, call_fn, args, kwargs):
sig = inspect.signature(call_fn)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
arg_dict = {}
arg_names = []
tensor_arg_dict = {}
tensor_args = []
tensor_arg_names = []
nested_tensor_arg_names = []
for name, value in bound_args.arguments.items():
arg_dict[name] = value
arg_names.append(name)
if is_backend_tensor_or_symbolic(value):
tensor_args.append(value)
tensor_arg_names.append(name)
tensor_arg_dict[name] = value
elif nest.is_nested(value):
flat_values = nest.flatten(value)
if all(is_backend_tensor_or_symbolic(x) for x in flat_values):
tensor_args.append(value)
tensor_arg_names.append(name)
tensor_arg_dict[name] = value
nested_tensor_arg_names.append(name)
elif any(is_backend_tensor_or_symbolic(x) for x in flat_values):
raise ValueError(
"In a nested call() argument, "
"you cannot mix tensors and non-tensors. "
"Received invalid mixed argument: "
f"{name}={value}"
)
self.arguments_dict = arg_dict
self.argument_names = arg_names
self.tensor_arguments_dict = tensor_arg_dict
self.tensor_arguments_names = tensor_arg_names
self.nested_tensor_argument_names = nested_tensor_arg_names
self.first_arg = arg_dict[arg_names[0]]
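# Usage sketch: for `def call(self, x, training=None)` invoked as
# `layer(tensor)`:
#     spec = CallSpec(layer.call, (tensor,), {})
#     spec.argument_names          # ["x", "training"]
#     spec.tensor_arguments_dict   # {"x": tensor}
#     spec.first_arg               # tensor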
def get_arguments_dict(fn, args, kwargs):
"""Return a dict mapping argument names to their values."""
sig = inspect.signature(fn)
bound_args = sig.bind(*args, **kwargs)
@@ -765,37 +855,24 @@ def get_arguments_dict(fn, *args, **kwargs):
return arg_dict
def get_shapes_dict(arguments_dict):
def get_shapes_dict(call_spec):
"""Convert the call() arguments dict into a dict of input shape arguments.
Example:
```
>>> get_shapes_dict(
{"input_a": KerasTensor(shape=(2, 3)), "training": False})
>>> get_shapes_dict(call_spec)
{"input_a_shape": (2, 3)}
```
"""
shapes_dict = {}
for k, v in arguments_dict.items():
if isinstance(v, KerasTensor) or backend.is_tensor(v):
shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape)
elif nest.is_nested(v):
flat = nest.flatten(v)
if any(
isinstance(x, KerasTensor) or backend.is_tensor(x) for x in flat
):
if not all(
isinstance(x, KerasTensor) or backend.is_tensor(x)
for x in flat
):
raise ValueError(
"You cannot mix tensors and non-tensors in a nested "
f"argument. Invalid argument: {k}={v}"
)
for k, v in call_spec.tensor_arguments_dict.items():
if k in call_spec.nested_tensor_argument_names:
shapes_dict[f"{k}_shape"] = nest.map_structure(
lambda x: backend.standardize_shape(x.shape), v
)
else:
shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape)
return shapes_dict
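# Sketch: a nested tensor argument `xs=[t1, t2]` with shapes (2, 3) and
# (2, 5) yields {"xs_shape": [(2, 3), (2, 5)]}; non-tensor arguments
# contribute nothing.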

@@ -262,3 +262,73 @@ class LayerTest(testing.TestCase):
self.assertEqual(layer.variable_dtype, "float32")
self.assertEqual(y.dtype.name, "float16")
self.assertEqual(layer.kernel.dtype, "float32")
def test_masking(self):
class BasicMaskedLayer(layers.Layer):
def __init__(self):
super().__init__()
self.supports_masking = True
def call(self, x, mask=None):
assert mask is not None
return x
layer = BasicMaskedLayer()
x = backend.numpy.ones((4, 4))
x._keras_mask = backend.numpy.ones((4,))
layer(x)
class NestedInputMaskedLayer(layers.Layer):
def __init__(self):
super().__init__()
self.supports_masking = True
def call(self, x, mask=None):
assert isinstance(x, list)
assert len(x) == 2
assert isinstance(mask, list)
assert len(mask) == 2
return x
layer = NestedInputMaskedLayer()
x1 = backend.numpy.ones((4, 4))
x1._keras_mask = backend.numpy.ones((4,))
x2 = backend.numpy.ones((4, 4))
x2._keras_mask = backend.numpy.ones((4,))
layer([x1, x2])
class PositionalInputsMaskedLayer(layers.Layer):
def __init__(self):
super().__init__()
self.supports_masking = True
def call(self, x1, x2, x1_mask=None, x2_mask=None):
assert x1_mask is not None
assert x2_mask is not None
return x1 + x2
layer = PositionalInputsMaskedLayer()
layer(x1, x2)
layer(x1=x1, x2=x2)
class PositionalNestedInputsMaskedLayer(layers.Layer):
def __init__(self):
super().__init__()
self.supports_masking = True
def call(self, x1, x2, x1_mask=None, x2_mask=None):
assert isinstance(x1, tuple)
assert x1_mask is not None
assert x2_mask is not None
assert isinstance(x1_mask, tuple)
return x1[0] + x1[1] + x2
layer = PositionalNestedInputsMaskedLayer()
x1_1 = backend.numpy.ones((4, 4))
x1_1._keras_mask = backend.numpy.ones((4,))
x1_2 = backend.numpy.ones((4, 4))
x1_2._keras_mask = backend.numpy.ones((4,))
x2 = backend.numpy.ones((4, 4))
x2._keras_mask = backend.numpy.ones((4,))
layer((x1_1, x1_2), x2)
layer(x1=(x1_1, x1_2), x2=x2)
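For reference, a minimal end-to-end sketch of the new behavior, reusing
`BasicMaskedLayer` from the test above (hypothetical snippet; `backend.numpy`
names as used in these tests):

    layer = BasicMaskedLayer()
    x = backend.numpy.ones((4, 4))
    x._keras_mask = backend.numpy.ones((4,))  # mask attached upstream
    y = layer(x)  # __call__ auto-fills `mask` from `x._keras_mask`
    # _set_mask_metadata then propagates the pass-through mask to any
    # output that does not already carry a `_keras_mask`.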