286 lines
9.9 KiB
Python
286 lines
9.9 KiB
Python
import os
|
|
import warnings
|
|
|
|
from keras_core import backend
|
|
from keras_core import utils
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.layers.layer import Layer
|
|
from keras_core.saving import saving_lib
|
|
from keras_core.utils import io_utils
|
|
from keras_core.utils import summary_utils
|
|
|
|
# Pick the backend-specific trainer implementation once, at import time,
# and expose it under the common name `Trainer`.
if backend.backend() == "tensorflow":
    from keras_core.backend.tensorflow.trainer import (
        TensorFlowTrainer as Trainer,
    )
elif backend.backend() == "jax":
    from keras_core.backend.jax.trainer import JAXTrainer as Trainer
else:
    # No trainer is bound for any other backend.
    # NOTE(review): `Model` below lists `Trainer` as a base class, so `None`
    # here presumably makes the class definition fail -- confirm whether
    # other backends are expected to import this module.
    Trainer = None
|
|
|
|
|
|
@keras_core_export(["keras_core.Model", "keras_core.models.Model"])
class Model(Trainer, Layer):
    """
    Combination of a Layer and Trainer. Adds:

    - layer surfacing
    - saving
    - export
    - summary

    Limitations:

    - call must have a single inputs argument
    - no masking support
    """

    def __new__(cls, *args, **kwargs):
        """Create the instance, returning a `Functional` when appropriate.

        When `Model` itself (not a subclass) is called with arguments that
        look like a functional-model signature, a `Functional` instance is
        built and returned instead of a plain `Model`.
        """
        # Signature detection for usage of `Model` as a `Functional`
        if functional_init_arguments(args, kwargs) and cls == Model:
            from keras_core.models import functional

            return functional.Functional(*args, **kwargs)
        return super().__new__(cls)

    def __init__(self, *args, **kwargs):
        """Initialize the trainer state, then the layer (or functional) state.

        Args:
            *args: Positional arguments forwarded to `Functional.__init__`
                or `Layer.__init__`.
            **kwargs: Keyword arguments forwarded likewise.
        """
        Trainer.__init__(self)
        from keras_core.models import functional

        # Signature detection for usage of a `Model` subclass
        # as a `Functional` subclass
        if functional_init_arguments(args, kwargs):
            inject_functional_model_class(self.__class__)
            functional.Functional.__init__(self, *args, **kwargs)
        else:
            Layer.__init__(self, *args, **kwargs)

    def call(self, inputs, training=False):
        """Forward pass placeholder; this base implementation always raises.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    @property
    def layers(self):
        """List of layers yielded by a non-recursive `_flatten_layers` call.

        A new list is built on every access (`include_self=False`,
        `recursive=False`).
        """
        return list(self._flatten_layers(include_self=False, recursive=False))

    @layers.setter
    def layers(self, _):
        """Reject any assignment to the reserved `layers` attribute."""
        raise AttributeError(
            "`Model.layers` attribute is reserved and should not be used. "
            "Please use another name."
        )

    def get_layer(self, name=None, index=None):
        """Retrieves a layer based on either its name (unique) or index.

        Exactly one of `name` and `index` must be provided; passing both
        is an error. When looking up by name, the first layer in
        `self.layers` whose name matches is returned.

        Args:
            name: String, name of layer.
            index: Integer, index of layer.

        Returns:
            A layer instance.

        Raises:
            ValueError: if both or neither of `name` and `index` are given,
                if `index` is greater than or equal to the number of layers,
                or if no layer is called `name`.
        """
        if index is not None and name is not None:
            raise ValueError(
                "Provide only a layer name or a layer index. Received: "
                f"index={index}, name={name}."
            )

        # `self.layers` builds a fresh list on every access; fetch it once.
        layers = self.layers

        if index is not None:
            if len(layers) <= index:
                raise ValueError(
                    f"Was asked to retrieve layer at index {index}"
                    f" but model only has {len(layers)}"
                    " layers."
                )
            return layers[index]

        if name is not None:
            for layer in layers:
                if layer.name == name:
                    return layer
            raise ValueError(
                f"No such layer: {name}. Existing layers are: "
                f"{[layer.name for layer in layers]}."
            )
        raise ValueError(
            "Provide either a layer name or layer index at `get_layer`."
        )

    def summary(
        self,
        line_length=None,
        positions=None,
        print_fn=None,
        expand_nested=False,
        show_trainable=False,
        layer_range=None,
    ):
        """Prints a string summary of the network.

        Args:
            line_length: Total length of printed lines
                (e.g. set this to adapt the display to different
                terminal window sizes).
            positions: Relative or absolute positions of log elements
                in each line. If not provided, becomes
                `[0.3, 0.6, 0.70, 1.]`. Defaults to `None`.
            print_fn: Print function to use. By default, prints to `stdout`.
                If `stdout` doesn't work in your environment, change to `print`.
                It will be called on each line of the summary.
                You can set it to a custom function
                in order to capture the string summary.
            expand_nested: Whether to expand the nested models.
                Defaults to `False`.
            show_trainable: Whether to show if a layer is trainable.
                Defaults to `False`.
            layer_range: a list or tuple of 2 strings,
                which is the starting layer name and ending layer name
                (both inclusive) indicating the range of layers to be printed
                in summary. It also accepts regex patterns instead of exact
                name. In such case, start predicate will be the first element
                it matches to `layer_range[0]` and the end predicate will be
                the last element it matches to `layer_range[1]`.
                By default `None` which considers all layers of model.

        Raises:
            ValueError: if `summary()` is called before the model is built.
        """
        summary_utils.print_summary(
            self,
            line_length=line_length,
            positions=positions,
            print_fn=print_fn,
            expand_nested=expand_nested,
            show_trainable=show_trainable,
            layer_range=layer_range,
        )

    def save(self, filepath, overwrite=True):
        """Saves the model to a `.keras` file via `saving_lib.save_model`.

        Args:
            filepath: Destination path; its string form must end in
                `.keras`.
            overwrite: If `False` and `filepath` already exists, the user
                is asked for confirmation and nothing is written when they
                decline.

        Raises:
            ValueError: if `filepath` does not end in `.keras`.
        """
        if not str(filepath).endswith(".keras"):
            raise ValueError(
                "The filename must end in `.keras`. "
                f"Received: filepath={filepath}"
            )
        try:
            exists = os.path.exists(filepath)
        except TypeError:
            # `filepath` is not something `os.path.exists` accepts; treat it
            # as not existing so no overwrite prompt is shown.
            exists = False
        if exists and not overwrite:
            proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
            if not proceed:
                return
        saving_lib.save_model(self, filepath)

    def save_weights(self, filepath, overwrite=True):
        """Saves the weights to a `.weights.h5` file.

        Delegates to `saving_lib.save_weights_only`.

        Args:
            filepath: Destination path; its string form must end in
                `.weights.h5`.
            overwrite: If `False` and `filepath` already exists, the user
                is asked for confirmation and nothing is written when they
                decline.

        Raises:
            ValueError: if `filepath` does not end in `.weights.h5`.
        """
        if not str(filepath).endswith(".weights.h5"):
            raise ValueError(
                "The filename must end in `.weights.h5`. "
                f"Received: filepath={filepath}"
            )
        try:
            exists = os.path.exists(filepath)
        except TypeError:
            # Same fallback as in `save`: unknown path types never prompt.
            exists = False
        if exists and not overwrite:
            proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
            if not proceed:
                return
        saving_lib.save_weights_only(self, filepath)

    def load_weights(self, filepath, skip_mismatch=False):
        """Loads weights from a `.keras` or `.weights.h5` file.

        Args:
            filepath: Path to the file to read.
            skip_mismatch: Forwarded unchanged to
                `saving_lib.load_weights_only`.

        Raises:
            ValueError: if `filepath` has neither supported extension.
        """
        # Both supported formats go through the same loader.
        if str(filepath).endswith((".keras", ".weights.h5")):
            saving_lib.load_weights_only(
                self, filepath, skip_mismatch=skip_mismatch
            )
        else:
            raise ValueError(
                f"File format not supported: filepath={filepath}. "
                "Keras Core only supports V3 `.keras` and `.weights.h5` "
                "files."
            )

    def build_from_config(self, config):
        """Builds the model from a previously recorded build config.

        Does nothing when `config` is empty. Otherwise the model is built
        from `config["input_shape"]` or `config["shapes_dict"]`, using the
        user-defined `build()` when it is overridden. A warning is emitted
        if the build does not succeed or if `config` contains neither key.

        Args:
            config: Dict holding an `"input_shape"` or a `"shapes_dict"`
                entry.
        """
        if not config:
            return

        # Start from "not built" so that a config holding neither of the
        # recognized keys reaches the warning below instead of failing on
        # an unbound local.
        status = False

        if "input_shape" in config:
            # Case: all inputs are in the first arg (possibly nested).
            if utils.is_default(self.build):
                status = self._build_by_run_for_single_pos_arg(
                    config["input_shape"]
                )
            else:
                try:
                    self.build(config["input_shape"])
                    status = True
                except Exception:
                    # Best effort: a failing custom `build()` is reported
                    # through the warning below. `KeyboardInterrupt` and
                    # `SystemExit` are deliberately not swallowed.
                    status = False
            self._build_shapes_dict = config

        elif "shapes_dict" in config:
            # Case: inputs were recorded as multiple keyword arguments.
            if utils.is_default(self.build):
                status = self._build_for_kwargs(config["shapes_dict"])
            else:
                try:
                    self.build(**config["shapes_dict"])
                    status = True
                except Exception:
                    status = False
            self._build_shapes_dict = config["shapes_dict"]

        if not status:
            warnings.warn(
                f"Model '{self.name}' had a build config, but the model "
                "cannot be built automatically in "
                "`build_from_config(config)`. "
                "You should implement "
                "`def build_from_config(self, config)`, "
                "and you might also want to implement the method "
                " that generates the config at saving time, "
                "`def get_build_config(self)`. "
                "The method `build_from_config()` is meant to "
                "create the state of the model (i.e. its variables) "
                "upon deserialization.",
                stacklevel=2,
            )

    def export(self, filepath):
        """Export placeholder; this implementation always raises.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
|
|
|
|
|
|
def functional_init_arguments(args, kwargs):
    """Tell whether constructor arguments match a functional-model call.

    The arguments are treated as functional when any of these holds:

    - exactly two positional arguments are given;
    - exactly one positional argument is given together with an
      `outputs` keyword;
    - both `inputs` and `outputs` are given as keywords.

    Args:
        args: Tuple of positional constructor arguments.
        kwargs: Dict of keyword constructor arguments.

    Returns:
        `True` if the signature looks functional, `False` otherwise.
    """
    positional_count = len(args)
    if positional_count == 2:
        return True
    has_outputs = "outputs" in kwargs
    if positional_count == 1 and has_outputs:
        return True
    return has_outputs and "inputs" in kwargs
|
|
|
|
|
|
def inject_functional_model_class(cls):
    """Inject `Functional` into the hierarchy of this class if needed.

    The bases of `cls` are visited recursively and every occurrence of
    `Model` is replaced by `Functional`. `cls.__bases__` is reassigned in
    place, so the change applies to the class object itself, not to a copy.

    Args:
        cls: Class whose hierarchy is rewritten.

    Returns:
        `functional.Functional` when `cls` is `Model`, `object` when `cls`
        is `object`, and otherwise `cls` itself after its bases have been
        rewritten.
    """
    # Imported inside the function rather than at module level.
    # NOTE(review): presumably to avoid a circular import with
    # `keras_core.models.functional` -- confirm.
    from keras_core.models import functional

    # Base case: `Model` itself is what gets swapped for `Functional`.
    if cls == Model:
        return functional.Functional
    # In case there is any multiple inheritance, we stop injecting the
    # class if keras model is not in its class hierarchy.
    if cls == object:
        return object

    # Recurse into every base and rebuild the bases tuple from the results.
    cls.__bases__ = tuple(
        inject_functional_model_class(base) for base in cls.__bases__
    )
    # Trigger any `__new__` class swapping that needed to happen on `Functional`
    # but did not because functional was not in the class hierarchy.
    cls.__new__(cls)

    return cls
|