keras/keras_core/models/model.py

286 lines
9.9 KiB
Python
Raw Normal View History

2023-04-25 01:54:23 +00:00
import os
2023-04-25 19:59:32 +00:00
import warnings
2023-04-25 01:54:23 +00:00
2023-04-18 21:49:38 +00:00
from keras_core import backend
2023-04-26 18:59:47 +00:00
from keras_core import utils
2023-04-09 19:21:45 +00:00
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
2023-04-25 01:54:23 +00:00
from keras_core.saving import saving_lib
from keras_core.utils import io_utils
2023-04-13 17:59:51 +00:00
from keras_core.utils import summary_utils
# Select the backend-specific trainer class that `Model` inherits from.
# The choice is made once, at import time, from `backend.backend()`.
if backend.backend() == "tensorflow":
    from keras_core.backend.tensorflow.trainer import (
        TensorFlowTrainer as Trainer,
    )
elif backend.backend() == "jax":
    from keras_core.backend.jax.trainer import JAXTrainer as Trainer
else:
    # NOTE(review): with `Trainer = None`, a class statement that lists
    # `Trainer` as a base will fail at import time for any other backend.
    # Confirm whether an explicit error would be preferable here.
    Trainer = None
2023-04-09 19:21:45 +00:00
@keras_core_export(["keras_core.Model", "keras_core.models.Model"])
class Model(Trainer, Layer):
    """
    Combination of a Layer and Trainer. Adds:

    - layer surfacing
    - saving
    - export
    - summary

    Limitations:

    - call must have a single inputs argument
    - no masking support
    """

    def __new__(cls, *args, **kwargs):
        # Signature detection for usage of `Model` as a `Functional`.
        # Only the bare `Model` class is redirected here; subclasses are
        # handled in `__init__`.
        if functional_init_arguments(args, kwargs) and cls == Model:
            # Imported lazily, inside the call rather than at module level.
            from keras_core.models import functional

            return functional.Functional(*args, **kwargs)
        return super().__new__(cls)

    def __init__(self, *args, **kwargs):
        Trainer.__init__(self)
        from keras_core.models import functional

        # Signature detection for usage of a `Model` subclass
        # as a `Functional` subclass
        if functional_init_arguments(args, kwargs):
            inject_functional_model_class(self.__class__)
            functional.Functional.__init__(self, *args, **kwargs)
        else:
            Layer.__init__(self, *args, **kwargs)

    def call(self, inputs, training=False):
        """Forward pass placeholder; this base implementation always raises.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    @property
    def layers(self):
        """List of the layers returned by a non-recursive `_flatten_layers`.

        A new list is built on every access.
        """
        return list(self._flatten_layers(include_self=False, recursive=False))

    @layers.setter
    def layers(self, _):
        raise AttributeError(
            "`Model.layers` attribute is reserved and should not be used. "
            "Please use another name."
        )

    def get_layer(self, name=None, index=None):
        """Retrieves a layer based on either its name (unique) or index.

        Exactly one of `name` and `index` must be provided.
        Indices are based on order of horizontal graph traversal (bottom-up).

        Args:
            name: String, name of layer.
            index: Integer, index of layer. Negative values index from the
                end, as with a Python list.

        Returns:
            A layer instance.

        Raises:
            ValueError: if both or neither of `name` and `index` are
                provided, if `index` is out of range, or if no layer is
                named `name`.
        """
        if index is not None and name is not None:
            raise ValueError(
                "Provide only a layer name or a layer index. Received: "
                f"index={index}, name={name}."
            )

        # `self.layers` builds a fresh list on each access; read it once.
        layers = self.layers

        if index is not None:
            num_layers = len(layers)
            # Reject out-of-range indices on both sides so that callers get
            # a consistent `ValueError` instead of a raw `IndexError`.
            if not -num_layers <= index < num_layers:
                raise ValueError(
                    f"Was asked to retrieve layer at index {index}"
                    f" but model only has {num_layers}"
                    " layers."
                )
            return layers[index]

        if name is not None:
            for layer in layers:
                if layer.name == name:
                    return layer
            raise ValueError(
                f"No such layer: {name}. Existing layers are: "
                f"{list(layer.name for layer in layers)}."
            )

        raise ValueError(
            "Provide either a layer name or layer index at `get_layer`."
        )

    def summary(
        self,
        line_length=None,
        positions=None,
        print_fn=None,
        expand_nested=False,
        show_trainable=False,
        layer_range=None,
    ):
        """Prints a string summary of the network.

        Args:
            line_length: Total length of printed lines
                (e.g. set this to adapt the display to different
                terminal window sizes).
            positions: Relative or absolute positions of log elements
                in each line. If not provided, becomes
                `[0.3, 0.6, 0.70, 1.]`. Defaults to `None`.
            print_fn: Print function to use. By default, prints to `stdout`.
                If `stdout` doesn't work in your environment, change to
                `print`. It will be called on each line of the summary.
                You can set it to a custom function
                in order to capture the string summary.
            expand_nested: Whether to expand the nested models.
                Defaults to `False`.
            show_trainable: Whether to show if a layer is trainable.
                Defaults to `False`.
            layer_range: a list or tuple of 2 strings,
                which is the starting layer name and ending layer name
                (both inclusive) indicating the range of layers to be printed
                in summary. It also accepts regex patterns instead of exact
                name. In such case, start predicate will be the first element
                it matches to `layer_range[0]` and the end predicate will be
                the last element it matches to `layer_range[1]`.
                By default `None` which considers all layers of model.

        Raises:
            ValueError: if `summary()` is called before the model is built.
        """
        summary_utils.print_summary(
            self,
            line_length=line_length,
            positions=positions,
            print_fn=print_fn,
            expand_nested=expand_nested,
            show_trainable=show_trainable,
            layer_range=layer_range,
        )

    def save(self, filepath, overwrite=True):
        """Saves the model to `filepath` via `saving_lib.save_model`.

        Args:
            filepath: Destination path. `str(filepath)` must end in
                `.keras`.
            overwrite: If `False` and the path already exists,
                `io_utils.ask_to_proceed_with_overwrite` is consulted and
                nothing is written unless it returns a truthy value.

        Raises:
            ValueError: if the path does not end in `.keras`.
        """
        if not str(filepath).endswith(".keras"):
            raise ValueError(
                "The filename must end in `.keras`. "
                f"Received: filepath={filepath}"
            )
        try:
            exists = os.path.exists(filepath)
        except TypeError:
            # `filepath` is not something `os.path.exists` accepts; treat it
            # as not existing.
            exists = False
        if exists and not overwrite:
            proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
            if not proceed:
                return
        saving_lib.save_model(self, filepath)

    def save_weights(self, filepath, overwrite=True):
        """Saves the weights to `filepath` via `saving_lib.save_weights_only`.

        Args:
            filepath: Destination path. `str(filepath)` must end in
                `.weights.h5`.
            overwrite: If `False` and the path already exists,
                `io_utils.ask_to_proceed_with_overwrite` is consulted and
                nothing is written unless it returns a truthy value.

        Raises:
            ValueError: if the path does not end in `.weights.h5`.
        """
        if not str(filepath).endswith(".weights.h5"):
            raise ValueError(
                "The filename must end in `.weights.h5`. "
                f"Received: filepath={filepath}"
            )
        try:
            exists = os.path.exists(filepath)
        except TypeError:
            # `filepath` is not something `os.path.exists` accepts; treat it
            # as not existing.
            exists = False
        if exists and not overwrite:
            proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
            if not proceed:
                return
        saving_lib.save_weights_only(self, filepath)

    def load_weights(self, filepath, skip_mismatch=False):
        """Loads weights from `filepath` via `saving_lib.load_weights_only`.

        Args:
            filepath: Source path. `str(filepath)` must end in `.keras` or
                `.weights.h5`.
            skip_mismatch: Forwarded unchanged to
                `saving_lib.load_weights_only`.

        Raises:
            ValueError: if the path has neither supported suffix.
        """
        # Both supported formats go through the same loader.
        if not str(filepath).endswith((".keras", ".weights.h5")):
            raise ValueError(
                f"File format not supported: filepath={filepath}. "
                "Keras Core only supports V3 `.keras` and `.weights.h5` "
                "files."
            )
        saving_lib.load_weights_only(
            self, filepath, skip_mismatch=skip_mismatch
        )

    def build_from_config(self, config):
        """Builds the model from a recorded build config.

        Does nothing when `config` is empty. Otherwise tries to build from
        `config["input_shape"]` or `config["shapes_dict"]`, and emits a
        warning if the build did not succeed or if neither key is present.

        Args:
            config: Dict holding either an `"input_shape"` entry or a
                `"shapes_dict"` entry.
        """
        if not config:
            return

        # Default to "not built" so that a config with neither recognised
        # key falls through to the warning instead of leaving `status`
        # unbound.
        status = False

        if "input_shape" in config:
            # Case: all inputs are in the first arg (possibly nested).
            if utils.is_default(self.build):
                status = self._build_by_run_for_single_pos_arg(
                    config["input_shape"]
                )
            else:
                try:
                    self.build(config["input_shape"])
                    status = True
                except Exception:
                    # Best effort: a failing user-defined `build` is
                    # reported through the warning below. Interrupts and
                    # `SystemExit` are deliberately not swallowed.
                    status = False
            self._build_shapes_dict = config
        elif "shapes_dict" in config:
            # Case: inputs were recorded as multiple keyword arguments.
            if utils.is_default(self.build):
                status = self._build_for_kwargs(config["shapes_dict"])
            else:
                try:
                    self.build(**config["shapes_dict"])
                    status = True
                except Exception:
                    # Same best-effort handling as the `input_shape` case.
                    status = False
            self._build_shapes_dict = config["shapes_dict"]

        if not status:
            warnings.warn(
                f"Model '{self.name}' had a build config, but the model "
                "cannot be built automatically in "
                "`build_from_config(config)`. "
                "You should implement "
                "`def build_from_config(self, config)`, "
                "and you might also want to implement the method "
                " that generates the config at saving time, "
                "`def get_build_config(self)`. "
                "The method `build_from_config()` is meant to "
                "create the state of the model (i.e. its variables) "
                "upon deserialization.",
                stacklevel=2,
            )

    def export(self, filepath):
        """Export placeholder; not implemented.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
def functional_init_arguments(args, kwargs):
    """Return whether constructor arguments match the Functional signature.

    The arguments are treated as a Functional-style call when one of the
    following holds:

    - exactly two positional arguments are given;
    - exactly one positional argument is given together with an
      `outputs` keyword argument;
    - both `inputs` and `outputs` are given as keyword arguments.

    Args:
        args: Sequence of positional arguments passed to the constructor.
        kwargs: Dict of keyword arguments passed to the constructor.

    Returns:
        `True` if the arguments match one of the patterns above,
        `False` otherwise.
    """
    return (
        (len(args) == 2)
        or (len(args) == 1 and "outputs" in kwargs)
        or ("inputs" in kwargs and "outputs" in kwargs)
    )
2023-04-21 22:01:17 +00:00
def inject_functional_model_class(cls):
    """Swap `Model` for `Functional` in the base classes of `cls`.

    Walks the `__bases__` of `cls` recursively. `Model` itself is replaced
    by `functional.Functional`, `object` is left as is, and every other
    class has its `__bases__` rewritten in place with the processed bases.

    Args:
        cls: The class whose hierarchy should be patched.

    Returns:
        `functional.Functional` if `cls` is `Model`, `object` if `cls` is
        `object`, otherwise `cls` itself (with updated `__bases__`).
    """
    from keras_core.models import functional

    if cls == Model:
        return functional.Functional
    # With multiple inheritance some branches never reach `Model`;
    # `object` is where the recursion stops for those.
    if cls == object:
        return object

    patched_bases = tuple(map(inject_functional_model_class, cls.__bases__))
    cls.__bases__ = patched_bases
    # Run `__new__` once so that any class swapping `Functional` performs
    # there happens now that it is part of the hierarchy.
    cls.__new__(cls)
    return cls