import warnings

from tensorflow import nest

from keras_core import backend
from keras_core import operations as ops
from keras_core.layers.layer import Layer
from keras_core.models.model import Model
from keras_core.operations.function import Function
from keras_core.utils import tracking


class Functional(Function, Model):
    """A `Function` that is also a `Model`.

    Compared to `Function`, this class:

    - Adds support for extra call arguments: `training` and `mask`.
    - Adds support for argument standardization:
        - list/dict duality
        - upranking
    - Overrides `.layers`.
    - Symbolic `add_loss` (not implemented yet).
    """

    @tracking.no_automatic_dependency_tracking
    def __init__(self, inputs, outputs, name=None, **kwargs):
        """Validate dict inputs and initialize the parent classes.

        Args:
            inputs: Input tensor(s). When a dict is given, every value must
                be a `backend.KerasTensor` whose `name` equals its key.
            outputs: Output tensor(s), forwarded to the parent constructor.
            name: Optional name, forwarded to the parent constructor.
            **kwargs: Forwarded to the parent constructor, except for the
                `skip_init` flag which is consumed here.

        Raises:
            ValueError: If `inputs` is a dict containing a value that is not
                a `KerasTensor`, or a key that differs from the name of the
                tensor it maps to.
        """
        # The Model class has some logic to swap the class in its `__new__`
        # method, which leads to `__init__` being invoked twice. `skip_init`
        # lets one of the two invocations return early to avoid any side
        # effects.
        skip_init = kwargs.pop("skip_init", False)
        if skip_init:
            return
        if isinstance(inputs, dict):
            for k, v in inputs.items():
                if not isinstance(v, backend.KerasTensor):
                    raise ValueError(
                        "When providing an input dict, all values in the dict "
                        f"must be KerasTensors. Received: inputs={inputs} including "
                        f"invalid value {v} of type {type(v)}"
                    )
                if k != v.name:
                    raise ValueError(
                        "When providing an input dict, all keys in the dict "
                        "must match the names of the corresponding tensors. "
                        f"Received key '{k}' mapping to value {v} which has name '{v.name}'. "
                        f"Change the tensor name to '{k}' (via `Input(..., name='{k}')`)"
                    )
        super().__init__(inputs, outputs, name=name, **kwargs)
        self._layers = self.layers
        self.built = True

    @property
    def layers(self):
        """List of the operations in `self._operations` that are `Layer`s."""
        return [op for op in self._operations if isinstance(op, Layer)]

    def call(self, inputs, training=None, mask=None):
        """Run the inputs through the graph.

        Args:
            inputs: Input data; standardized with `_standardize_inputs`.
            training: Value forwarded as the `training` keyword argument to
                every operation that reports having such an argument (see
                `operation_fn`).
            mask: Optional mask(s), flattened the same way as the inputs and
                attached to each input as `_keras_mask`.

        Returns:
            The result of `_run_through_graph`.
        """
        # Add support for training, masking
        inputs = self._standardize_inputs(inputs)
        if mask is None:
            masks = [None] * len(inputs)
        else:
            masks = self._flatten_to_reference_inputs(mask)
        # Use a distinct loop name so the `mask` argument is not shadowed.
        for x, x_mask in zip(inputs, masks):
            x._keras_mask = x_mask
        return self._run_through_graph(
            inputs, operation_fn=lambda op: operation_fn(op, training=training)
        )

    def compute_output_spec(self, inputs, training=None, mask=None):
        """Delegate to the parent `compute_output_spec`.

        `training` and `mask` are accepted but not forwarded.
        """
        # From Function
        return super().compute_output_spec(inputs)

    def _assert_input_compatibility(self, *args):
        """Use the implementation found after `Model` in the MRO."""
        return super(Model, self)._assert_input_compatibility(*args)

    def _flatten_to_reference_inputs(self, inputs):
        """Flatten `inputs` into a list.

        Dict inputs are ordered by the reference input names; anything else
        is flattened with `nest.flatten`.

        Args:
            inputs: A dict keyed by input name, or any other structure.

        Returns:
            A flat list of the input values.

        Raises:
            KeyError: If `inputs` is a dict that lacks one of the reference
                input names.
        """
        if isinstance(inputs, dict):
            ref_inputs = self._inputs_struct
            if not nest.is_nested(ref_inputs):
                # Single, non-nested input: wrap it so it can be iterated
                # over below.
                ref_inputs = [ref_inputs]
            if isinstance(ref_inputs, dict):
                # In the case that the graph is constructed with dict input
                # tensors, we will use the original dict keys to map with
                # the keys in the input data. Note that model.inputs is
                # using nest.flatten to process the input tensors, which
                # means the dict input tensors are ordered by their keys.
                ref_input_names = sorted(ref_inputs.keys())
            else:
                ref_input_names = [
                    inp._keras_history.operation.name for inp in ref_inputs
                ]
            # Raise a warning if the input data has more entries than there
            # are input tensors.
            if len(inputs) > len(ref_input_names):
                warnings.warn(
                    "Input dict contained keys {} which did not match any "
                    "model input. They will be ignored by the model.".format(
                        [n for n in inputs.keys() if n not in ref_input_names]
                    ),
                    stacklevel=2,
                )
            # Flatten in the order `Input`s were passed during Model
            # construction.
            return [inputs[n] for n in ref_input_names]
        # Otherwise both self.inputs and tensors will already be in same
        # order.
        return nest.flatten(inputs)

    def _adjust_input_rank(self, flat_inputs):
        """Squeeze or expand a trailing axis so ranks match `self._inputs`.

        For each input paired with a reference shape:

        - equal rank: the input is kept as is;
        - one extra axis and the input's last dimension is 1: the last axis
          is squeezed;
        - one axis short and the reference's last dimension is 1: a last
          axis is added.

        Args:
            flat_inputs: Flat list of inputs, paired positionally with
                `self._inputs`.

        Returns:
            The list of rank-adjusted inputs.

        Raises:
            ValueError: If an input's rank cannot be reconciled with its
                reference shape by one of the rules above.
        """
        flat_ref_shapes = [x.shape for x in self._inputs]
        adjusted = []
        for x, ref_shape in zip(flat_inputs, flat_ref_shapes):
            x_rank = len(x.shape)
            ref_rank = len(ref_shape)
            if x_rank == ref_rank:
                adjusted.append(x)
                continue
            if x_rank == ref_rank + 1:
                if x.shape[-1] == 1:
                    adjusted.append(ops.squeeze(x, axis=-1))
                    continue
            if x_rank == ref_rank - 1:
                if ref_shape[-1] == 1:
                    adjusted.append(ops.expand_dims(x, axis=-1))
                    continue
            raise ValueError(
                f"Invalid input shape for input {x}. Expected shape "
                f"{ref_shape}, but input has incompatible shape {x.shape}"
            )
        # Add back metadata. Pairing with `zip` keeps this in step with
        # `adjusted` even when `flat_inputs` holds more entries than there
        # are reference inputs.
        for original, new in zip(flat_inputs, adjusted):
            if hasattr(original, "_keras_history"):
                new._keras_history = original._keras_history
            if hasattr(original, "_keras_mask"):
                new._keras_mask = original._keras_mask
        return adjusted

    def _standardize_inputs(self, inputs):
        """Flatten `inputs`, then adjust the rank of each flattened input."""
        flat_inputs = self._flatten_to_reference_inputs(inputs)
        return self._adjust_input_rank(flat_inputs)

    def add_loss(self, loss):
        """Not implemented; always raises `NotImplementedError`."""
        # Symbolic only. TODO
        raise NotImplementedError


def operation_fn(operation, training):
    """Wrap `operation` in a callable that injects the `training` argument.

    Args:
        operation: The operation to call.
        training: Value passed as the `training` keyword argument when
            `operation` has a `_call_has_training_arg` method that returns
            a truthy value.

    Returns:
        A function that forwards its arguments to `operation`, overriding
        the `training` keyword argument when applicable.
    """

    def call(*args, **kwargs):
        if (
            hasattr(operation, "_call_has_training_arg")
            and operation._call_has_training_arg()
        ):
            kwargs["training"] = training
        return operation(*args, **kwargs)

    return call