import os
import warnings

from keras_core import backend
from keras_core import utils
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.saving import saving_lib
from keras_core.utils import io_utils
from keras_core.utils import summary_utils

if backend.backend() == "tensorflow":
    from keras_core.backend.tensorflow.trainer import (
        TensorFlowTrainer as Trainer,
    )
elif backend.backend() == "jax":
    from keras_core.backend.jax.trainer import JAXTrainer as Trainer
else:
    Trainer = None


@keras_core_export(["keras_core.Model", "keras_core.models.Model"])
class Model(Trainer, Layer):
    """Combination of a `Layer` and a `Trainer`.

    Adds:

    - layer surfacing
    - saving
    - export
    - summary

    Limitations:

    - `call()` must have a single `inputs` argument
    - no masking support
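
    Example (an illustrative sketch, assuming the standard `keras_core.Input`
    and `keras_core.layers.Dense` API; the layer choices are placeholders):

    ```python
    import keras_core

    # Functional usage: `Model(inputs, outputs)` builds a `Functional` model.
    inputs = keras_core.Input(shape=(16,))
    outputs = keras_core.layers.Dense(4)(inputs)
    model = keras_core.Model(inputs=inputs, outputs=outputs)

    # Subclassing usage: override `call()`, which must take a single
    # `inputs` argument.
    class MyModel(keras_core.Model):
        def __init__(self):
            super().__init__()
            self.dense = keras_core.layers.Dense(4)

        def call(self, inputs, training=False):
            return self.dense(inputs)
    ```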
    """

    def __new__(cls, *args, **kwargs):
        # Signature detection for usage of `Model` as a `Functional`
        if functional_init_arguments(args, kwargs) and cls == Model:
            from keras_core.models import functional

            return functional.Functional(*args, **kwargs)
        return super().__new__(cls)

    def __init__(self, *args, **kwargs):
        Trainer.__init__(self)
        from keras_core.models import functional

        # Signature detection for usage of a `Model` subclass
        # as a `Functional` subclass
        if functional_init_arguments(args, kwargs):
            inject_functional_model_class(self.__class__)
            functional.Functional.__init__(self, *args, **kwargs)
        else:
            Layer.__init__(self, *args, **kwargs)

    def call(self, inputs, training=False):
        raise NotImplementedError

    @property
    def layers(self):
        return list(self._flatten_layers(include_self=False, recursive=False))

    @layers.setter
    def layers(self, _):
        raise AttributeError(
            "`Model.layers` attribute is reserved and should not be used. "
            "Please use another name."
        )

    def get_layer(self, name=None, index=None):
        """Retrieves a layer based on either its name (unique) or index.

        If `name` and `index` are both provided, `index` will take precedence.
        Indices are based on order of horizontal graph traversal (bottom-up).

        Args:
            name: String, name of layer.
            index: Integer, index of layer.

        Returns:
            A layer instance.
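
        Example (an illustrative sketch; `"dense"` is a placeholder layer
        name, not one guaranteed to exist on your model):

        ```python
        first_layer = model.get_layer(index=0)
        dense_layer = model.get_layer(name="dense")
        ```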
        """
        if index is not None and name is not None:
            raise ValueError(
                "Provide only a layer name or a layer index. Received: "
                f"index={index}, name={name}."
            )
        if index is not None:
            if len(self.layers) <= index:
                raise ValueError(
                    f"Was asked to retrieve layer at index {index}"
                    f" but model only has {len(self.layers)}"
                    " layers."
                )
            else:
                return self.layers[index]

        if name is not None:
            for layer in self.layers:
                if layer.name == name:
                    return layer
            raise ValueError(
                f"No such layer: {name}. Existing layers are: "
                f"{list(layer.name for layer in self.layers)}."
            )
        raise ValueError(
            "Provide either a layer name or layer index at `get_layer`."
        )

    def summary(
        self,
        line_length=None,
        positions=None,
        print_fn=None,
        expand_nested=False,
        show_trainable=False,
        layer_range=None,
    ):
        """Prints a string summary of the network.

        Args:
            line_length: Total length of printed lines
                (e.g. set this to adapt the display to different
                terminal window sizes).
            positions: Relative or absolute positions of log elements
                in each line. If not provided, becomes
                `[0.3, 0.6, 0.70, 1.]`. Defaults to `None`.
            print_fn: Print function to use. By default, prints to `stdout`.
                If `stdout` doesn't work in your environment, change to
                `print`. It will be called on each line of the summary.
                You can set it to a custom function
                in order to capture the string summary.
            expand_nested: Whether to expand the nested models.
                Defaults to `False`.
            show_trainable: Whether to show if a layer is trainable.
                Defaults to `False`.
            layer_range: a list or tuple of 2 strings,
                which is the starting layer name and ending layer name
                (both inclusive) indicating the range of layers to be printed
                in the summary. It also accepts regex patterns instead of
                exact names. In that case, the start predicate will be the
                first element that matches `layer_range[0]` and the end
                predicate will be the last element that matches
                `layer_range[1]`. By default `None`, which considers all
                layers of the model.

        Raises:
            ValueError: if `summary()` is called before the model is built.
        """
        summary_utils.print_summary(
            self,
            line_length=line_length,
            positions=positions,
            print_fn=print_fn,
            expand_nested=expand_nested,
            show_trainable=show_trainable,
            layer_range=layer_range,
        )

    def save(self, filepath, overwrite=True):
        # Saves the whole model as a single `.keras` archive via `saving_lib`.
        if not str(filepath).endswith(".keras"):
            raise ValueError(
                "The filename must end in `.keras`. "
                f"Received: filepath={filepath}"
            )
        try:
            exists = os.path.exists(filepath)
        except TypeError:
            exists = False
        if exists and not overwrite:
            proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
            if not proceed:
                return
        saving_lib.save_model(self, filepath)

    def save_weights(self, filepath, overwrite=True):
        # Saves only the model's weights to a `.weights.h5` file.
        if not str(filepath).endswith(".weights.h5"):
            raise ValueError(
                "The filename must end in `.weights.h5`. "
                f"Received: filepath={filepath}"
            )
        try:
            exists = os.path.exists(filepath)
        except TypeError:
            exists = False
        if exists and not overwrite:
            proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
            if not proceed:
                return
        saving_lib.save_weights_only(self, filepath)

    def load_weights(self, filepath, skip_mismatch=False):
        # Loads weights from either a `.keras` archive or a `.weights.h5`
        # file; in both cases only the weights are restored.
        if str(filepath).endswith(".keras"):
            saving_lib.load_weights_only(
                self, filepath, skip_mismatch=skip_mismatch
            )
        elif str(filepath).endswith(".weights.h5"):
            saving_lib.load_weights_only(
                self, filepath, skip_mismatch=skip_mismatch
            )
        else:
            raise ValueError(
                f"File format not supported: filepath={filepath}. "
                "Keras Core only supports V3 `.keras` and `.weights.h5` "
                "files."
            )

    def build_from_config(self, config):
        if not config:
            return
        if "input_shape" in config:
            # Case: all inputs are in the first arg (possibly nested).
            if utils.is_default(self.build):
                status = self._build_by_run_for_single_pos_arg(
                    config["input_shape"]
                )
            else:
                try:
                    self.build(config["input_shape"])
                    status = True
                except:
                    status = False
            self._build_shapes_dict = config

        elif "shapes_dict" in config:
            # Case: inputs were recorded as multiple keyword arguments.
            if utils.is_default(self.build):
                status = self._build_for_kwargs(config["shapes_dict"])
            else:
                try:
                    self.build(**config["shapes_dict"])
                    status = True
                except:
                    status = False
            self._build_shapes_dict = config["shapes_dict"]

        if not status:
            warnings.warn(
                f"Model '{self.name}' had a build config, but the model "
                "cannot be built automatically in "
                "`build_from_config(config)`. "
                "You should implement "
                "`def build_from_config(self, config)`, "
                "and you might also want to implement the method "
                "that generates the config at saving time, "
                "`def get_build_config(self)`. "
                "The method `build_from_config()` is meant to "
                "create the state of the model (i.e. its variables) "
                "upon deserialization.",
                stacklevel=2,
            )

    def export(self, filepath):
        raise NotImplementedError


def functional_init_arguments(args, kwargs):
    return (
        (len(args) == 2)
        or (len(args) == 1 and "outputs" in kwargs)
        or ("inputs" in kwargs and "outputs" in kwargs)
    )


def inject_functional_model_class(cls):
    """Inject `Functional` into the hierarchy of this class if needed."""
    from keras_core.models import functional

    if cls == Model:
        return functional.Functional
    # In case there is any multiple inheritance, we stop injecting the
    # class if keras model is not in its class hierarchy.
    if cls == object:
        return object

    cls.__bases__ = tuple(
        inject_functional_model_class(base) for base in cls.__bases__
    )
    # Trigger any `__new__` class swapping that needed to happen on
    # `Functional` but did not because functional was not in the class
    # hierarchy.
    cls.__new__(cls)

    return cls
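

# Illustrative sketch (not part of the module's behavior; `MyModel`, `x`, and
# `y` are placeholders): the signature detection above means that constructing
# a `Model` subclass with `inputs`/`outputs` routes it through the functional
# API, while a no-argument construction leaves it as a plain subclassed model:
#
#     class MyModel(Model):
#         pass
#
#     MyModel(inputs=x, outputs=y)  # `Functional` is injected into the bases
#     MyModel()                     # subclassed model; should override `call()`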