From e5e3c9f833ca1d4ef97bd3d65d3480df9331c584 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Fri, 28 Apr 2023 13:49:38 -0700 Subject: [PATCH] Implement masking. --- keras_core/layers/layer.py | 205 ++++++++++++++++++++++---------- keras_core/layers/layer_test.py | 70 +++++++++++ 2 files changed, 211 insertions(+), 64 deletions(-) diff --git a/keras_core/layers/layer.py b/keras_core/layers/layer.py index cab4c8369..44c72bd1d 100644 --- a/keras_core/layers/layer.py +++ b/keras_core/layers/layer.py @@ -36,9 +36,6 @@ from keras_core.operations.operation import Operation from keras_core.utils import summary_utils from keras_core.utils.tracking import Tracker -# TODO: cache all call signature processing. -# See layer_utils.CallFunctionSpec() in Keras. - @keras_core_export(["keras_core.Layer", "keras_core.layers.Layer"]) class Layer(Operation): @@ -277,8 +274,8 @@ class Layer(Operation): self._supports_masking = value @utils.default - def compute_mask(self, *args, **kwargs): - raise NotImplementedError + def compute_mask(self, inputs, previous_mask): + return previous_mask def __call__(self, *args, **kwargs): self._check_super_called() @@ -290,7 +287,7 @@ class Layer(Operation): if isinstance(x, np.ndarray) or backend.is_tensor(x): # TODO: only cast when float! return backend.convert_to_tensor(x, dtype=self.compute_dtype) - # TODO: cast KerasTensor too + # TODO: cast KerasTensor too? return x args = nest.map_structure(maybe_convert, args) @@ -306,27 +303,28 @@ class Layer(Operation): f"(of type {type(arg)})" ) + call_spec = CallSpec(self.call, args, kwargs) + # 4. Check input spec for 1st positional arg. # TODO: consider extending this to all args and kwargs. - self._assert_input_compatibility(*args) + self._assert_input_compatibility(call_spec.first_arg) ###################################### ############### - # Call build. # - self._maybe_build(*args, **kwargs) + # 5. Call build. # + self._maybe_build(call_spec) ############### # Maintains info about the `Layer.call` stack. call_context = self._get_call_context() - # Infer training value + # 6. Infer training value # Training phase for `Layer.call` is set via (in order of priority): # (1) The `training` argument passed to this `Layer.call`, if not None # (2) The training argument of an outer `Layer.call`. # (4) Any non-None default value for `training` in the call signature # (5) False (treating the layer as if it's in inference) - arguments_dict = get_arguments_dict(self.call, *args, **kwargs) - training = arguments_dict.get("training", None) + training = call_spec.arguments_dict.get("training", None) if training is None: training = call_context.training if training is None: @@ -337,10 +335,34 @@ class Layer(Operation): if self._call_has_training_arg(): kwargs["training"] = training - # TODO: Populate mask argument(s) - # if self._mask_has_training_arg(): + # 7. Populate mask argument(s) + if self.supports_masking: + if len(call_spec.tensor_arguments_dict) == 1: + if ( + "mask" in call_spec.argument_names + and call_spec.arguments_dict["mask"] is None + ): + arg_name = list(call_spec.tensor_arguments_dict.keys())[0] + only_tensor_arg = call_spec.tensor_arguments_dict[arg_name] + mask = nest.map_structure( + lambda x: getattr(x, "_keras_mask", None), + only_tensor_arg, + ) + kwargs["mask"] = mask + elif len(call_spec.tensor_arguments_dict) > 1: + for k, v in call_spec.tensor_arguments_dict.items(): + expected_mask_arg_name = f"{k}_mask" + if expected_mask_arg_name in call_spec.argument_names: + if ( + call_spec.arguments_dict[expected_mask_arg_name] + is None + ): + mask = nest.map_structure( + lambda x: getattr(x, "_keras_mask", None), v + ) + kwargs[expected_mask_arg_name] = mask - # Call the layer. + # 8. Call the layer. try: with backend.name_scope(self.name): if self.autocast and self.compute_dtype != self.variable_dtype: @@ -356,8 +378,16 @@ class Layer(Operation): if backend.is_tensor(output): self.add_loss(self.activity_regularizer(output)) - # TODO: Set masks on outputs - # self._set_mask_metadata(inputs, outputs, previous_mask) + if self.supports_masking: + # Set masks on outputs, + # provided only the first positional input arg and its mask. + # TODO: consider extending this to all args and kwargs. + previous_mask = getattr( + call_spec.first_arg, "_keras_mask", None + ) + self._set_mask_metadata( + call_spec.first_arg, outputs, previous_mask + ) finally: # Destroy call context if we created it self._maybe_reset_call_context() @@ -422,8 +452,8 @@ class Layer(Operation): return super().compute_output_spec(*args, **kwargs) else: # Use compute_output_shape() to return the right output spec - arguments_dict = get_arguments_dict(self.call, *args, **kwargs) - shapes_dict = get_shapes_dict(arguments_dict) + call_spec = CallSpec(self.call, args, kwargs) + shapes_dict = get_shapes_dict(call_spec) if len(shapes_dict) == 1: # Single arg: pass it positionally input_shape = tuple(shapes_dict.values())[0] @@ -553,10 +583,9 @@ class Layer(Operation): ) return summary_utils.count_params(self.weights) - def _maybe_build(self, *args, **kwargs): + def _maybe_build(self, call_spec): if not self.built: - arguments_dict = get_arguments_dict(self.call, *args, **kwargs) - shapes_dict = get_shapes_dict(arguments_dict) + shapes_dict = get_shapes_dict(call_spec) self._build_shapes_dict = shapes_dict failure = False if len(shapes_dict) == 1: @@ -578,14 +607,32 @@ class Layer(Operation): # and check that build() expects the right args. check_build_signature(self.build, shapes_dict) with backend.name_scope(self.name): - if utils.is_default( - self.build - ) and might_have_unbuilt_state(self): - status = self._build_by_run_for_kwargs(shapes_dict) - if not status: - failure = True + if utils.is_default(self.build): + if might_have_unbuilt_state(self): + status = self._build_by_run_for_kwargs(shapes_dict) + if not status: + failure = True else: - self.build(**shapes_dict) + run_build = True + build_args = set( + inspect.signature(self.build).parameters.keys() + ) + for key in shapes_dict.keys(): + if key not in build_args: + run_build = False + if run_build: + self.build(**shapes_dict) + else: + raise ValueError( + "In a layer with multiple tensor arguments " + "in call(), the build() method should accept " + "corresponding `*_shape` arguments, e.g. " + "if the call signature is `def call(self, x1, x2)` " + "then the build signature should be " + "`def build(self, x1_shape, x2_shape)`. " + "Keras will not build this layer automatically " + "since it does not conform to this." + ) if failure: raise ValueError( f"Layer '{self.name}' looks like it has " @@ -601,7 +648,7 @@ class Layer(Operation): # Check input spec again (after build, since self.input_spec # may have been updated - self._assert_input_compatibility(*args) + self._assert_input_compatibility(call_spec.first_arg) def _build_by_run_for_single_pos_arg(self, input_shape): # Case: all inputs are in the first arg (possibly nested). @@ -651,12 +698,16 @@ class Layer(Operation): return False def __repr__(self): - # TODO: improve - return f"<{self.__class__.__name__} name={self.name}>" + return ( + f"<{self.__class__.__name__} " + f"name={self.name}, built={self.built}>" + ) def __str__(self): - # TODO: improve - return f"<{self.__class__.__name__} name={self.name}>" + return ( + f"<{self.__class__.__name__} " + f"name={self.name}, built={self.built}>" + ) def __setattr__(self, name, value): # Track Variables, Layers, Metrics @@ -672,10 +723,10 @@ class Layer(Operation): "Go add it!" ) - def _assert_input_compatibility(self, *args): - if args and self.input_spec: + def _assert_input_compatibility(self, arg_0): + if self.input_spec: input_spec.assert_input_compatibility( - self.input_spec, args[0], layer_name=self.name + self.input_spec, arg_0, layer_name=self.name ) def _call_has_training_arg(self): @@ -732,12 +783,6 @@ class Layer(Operation): return layers def _set_mask_metadata(self, inputs, outputs, previous_mask): - # Many `Layer`s don't need to call `compute_mask`. - # This method is optimized to do as little work as needed for the common - # case. - if not self._supports_masking: - return - flat_outputs = nest.flatten(outputs) mask_already_computed = all( @@ -752,10 +797,55 @@ class Layer(Operation): flat_masks = nest.flatten(output_masks) for tensor, mask in zip(flat_outputs, flat_masks): - tensor._keras_mask = mask + if getattr(tensor, "_keras_mask", None) is None: + tensor._keras_mask = mask -def get_arguments_dict(fn, *args, **kwargs): +def is_backend_tensor_or_symbolic(x): + return backend.is_tensor(x) or isinstance(x, backend.KerasTensor) + + +class CallSpec: + def __init__(self, call_fn, args, kwargs): + sig = inspect.signature(call_fn) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + arg_dict = {} + arg_names = [] + tensor_arg_dict = {} + tensor_args = [] + tensor_arg_names = [] + nested_tensor_arg_names = [] + for name, value in bound_args.arguments.items(): + arg_dict[name] = value + arg_names.append(name) + if is_backend_tensor_or_symbolic(value): + tensor_args.append(value) + tensor_arg_names.append(name) + tensor_arg_dict[name] = value + elif nest.is_nested(value): + flat_values = nest.flatten(value) + if all(is_backend_tensor_or_symbolic(x) for x in flat_values): + tensor_args.append(value) + tensor_arg_names.append(name) + tensor_arg_dict[name] = value + nested_tensor_arg_names.append(name) + elif any(is_backend_tensor_or_symbolic(x) for x in flat_values): + raise ValueError( + "In a nested call() argument, " + "you cannot mix tensors and non-tensors. " + "Received invalid mixed argument: " + f"{name}={value}" + ) + self.arguments_dict = arg_dict + self.argument_names = arg_names + self.tensor_arguments_dict = tensor_arg_dict + self.tensor_arguments_names = tensor_arg_names + self.nested_tensor_argument_names = nested_tensor_arg_names + self.first_arg = arg_dict[arg_names[0]] + + +def get_arguments_dict(fn, args, kwargs): """Return a dict mapping argument names to their values.""" sig = inspect.signature(fn) bound_args = sig.bind(*args, **kwargs) @@ -765,37 +855,24 @@ def get_arguments_dict(fn, *args, **kwargs): return arg_dict -def get_shapes_dict(arguments_dict): +def get_shapes_dict(call_spec): """Convert the call() arguments dict into a dict of input shape arguments. Example: ``` - >>> get_shapes_dict( - {"input_a": KerasTensor(shape=(2, 3)), "training": False}) + >>> get_shapes_dict(call_spec) {"input_a_shape": (2, 3)} ``` """ shapes_dict = {} - for k, v in arguments_dict.items(): - if isinstance(v, KerasTensor) or backend.is_tensor(v): - shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape) - elif nest.is_nested(v): - flat = nest.flatten(v) - if any( - isinstance(x, KerasTensor) or backend.is_tensor(x) for x in flat - ): - if not all( - isinstance(x, KerasTensor) or backend.is_tensor(x) - for x in flat - ): - raise ValueError( - "You cannot mix tensors and non-tensors in a nested " - f"argument. Invalid argument: {k}={v}" - ) + for k, v in call_spec.tensor_arguments_dict.items(): + if k in call_spec.nested_tensor_argument_names: shapes_dict[f"{k}_shape"] = nest.map_structure( lambda x: backend.standardize_shape(x.shape), v ) + else: + shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape) return shapes_dict diff --git a/keras_core/layers/layer_test.py b/keras_core/layers/layer_test.py index 23002fe11..124f70c3d 100644 --- a/keras_core/layers/layer_test.py +++ b/keras_core/layers/layer_test.py @@ -262,3 +262,73 @@ class LayerTest(testing.TestCase): self.assertEqual(layer.variable_dtype, "float32") self.assertEqual(y.dtype.name, "float16") self.assertEqual(layer.kernel.dtype, "float32") + + def test_masking(self): + class BasicMaskedLayer(layers.Layer): + def __init__(self): + super().__init__() + self.supports_masking = True + + def call(self, x, mask=None): + assert mask is not None + return x + + layer = BasicMaskedLayer() + x = backend.numpy.ones((4, 4)) + x._keras_mask = backend.numpy.ones((4,)) + layer(x) + + class NestedInputMaskedLayer(layers.Layer): + def __init__(self): + super().__init__() + self.supports_masking = True + + def call(self, x, mask=None): + assert isinstance(x, list) + assert len(x) == 2 + assert isinstance(mask, list) + assert len(mask) == 2 + return x + + layer = NestedInputMaskedLayer() + x1 = backend.numpy.ones((4, 4)) + x1._keras_mask = backend.numpy.ones((4,)) + x2 = backend.numpy.ones((4, 4)) + x2._keras_mask = backend.numpy.ones((4,)) + layer([x1, x2]) + + class PositionalInputsMaskedLayer(layers.Layer): + def __init__(self): + super().__init__() + self.supports_masking = True + + def call(self, x1, x2, x1_mask=None, x2_mask=None): + assert x1_mask is not None + assert x2_mask is not None + return x1 + x2 + + layer = PositionalInputsMaskedLayer() + layer(x1, x2) + layer(x1=x1, x2=x2) + + class PositionalNestedInputsMaskedLayer(layers.Layer): + def __init__(self): + super().__init__() + self.supports_masking = True + + def call(self, x1, x2, x1_mask=None, x2_mask=None): + assert isinstance(x1, tuple) + assert x1_mask is not None + assert x2_mask is not None + assert isinstance(x1_mask, tuple) + return x1[0] + x1[1] + x2 + + layer = PositionalNestedInputsMaskedLayer() + x1_1 = backend.numpy.ones((4, 4)) + x1_1._keras_mask = backend.numpy.ones((4,)) + x1_2 = backend.numpy.ones((4, 4)) + x1_2._keras_mask = backend.numpy.ones((4,)) + x2 = backend.numpy.ones((4, 4)) + x2._keras_mask = backend.numpy.ones((4,)) + layer((x1_1, x1_2), x2) + layer(x1=(x1_1, x1_2), x2=x2)