190 lines
6.3 KiB
Python
190 lines
6.3 KiB
Python
from keras_core import backend
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.layers.layer import Layer
|
|
from keras_core.utils import summary_utils
|
|
|
|
# Select the backend-specific trainer class that `Model` uses as a base.
# The imports are conditional so that only the active backend's trainer
# module is loaded.
if backend.backend() == "tensorflow":
    from keras_core.backend.tensorflow.trainer import (
        TensorFlowTrainer as Trainer,
    )
elif backend.backend() == "jax":
    from keras_core.backend.jax.trainer import JAXTrainer as Trainer
else:
    # Fail fast with an actionable message. Binding `Trainer = None` here
    # would only postpone the failure to the `class Model(Trainer, Layer)`
    # statement, where Python raises an opaque `TypeError` because `None`
    # cannot be used as a base class.
    raise RuntimeError(
        f"Backend '{backend.backend()}' does not provide a `Trainer` "
        "implementation, which `Model` requires as a base class. "
        "Supported backends: 'tensorflow', 'jax'."
    )
|
|
|
|
|
|
@keras_core_export(["keras_core.Model", "keras_core.models.Model"])
class Model(Trainer, Layer):
    """
    Combination of a Layer and Trainer. Adds:

    - layer surfacing
    - saving
    - export
    - summary

    Limitations:

    - call must have a single inputs argument
    - no masking support

    When the constructor arguments look like a functional-model signature
    (see `functional_init_arguments`), construction is redirected to
    `keras_core.models.functional.Functional`.
    """

    def __new__(cls, *args, **kwargs):
        """Create the instance, dispatching `Model(...)` to `Functional`.

        If `Model` itself (not a subclass) is called with functional-style
        arguments, a `Functional` instance built from those arguments is
        returned instead of a plain `Model`.
        """
        # Signature detection for usage of `Model` as a `Functional`
        if functional_init_arguments(args, kwargs) and cls == Model:
            # Imported lazily, inside the method rather than at module
            # level.
            from keras_core.models import functional

            return functional.Functional(*args, **kwargs)
        return super().__new__(cls)

    def __init__(self, *args, **kwargs):
        """Initialize trainer state, then the layer or functional state.

        Args:
            *args: Positional arguments forwarded to `Functional.__init__`
                (functional-style signature) or `Layer.__init__`.
            **kwargs: Keyword arguments forwarded the same way.
        """
        Trainer.__init__(self)
        from keras_core.models import functional

        # Signature detection for usage of a `Model` subclass
        # as a `Functional` subclass
        if functional_init_arguments(args, kwargs):
            # Rewrites the class hierarchy of `self.__class__` so that
            # `Functional` takes the place of `Model` among its bases.
            inject_functional_model_class(self.__class__)
            functional.Functional.__init__(self, *args, **kwargs)
        else:
            Layer.__init__(self, *args, **kwargs)

    def call(self, inputs, training=False):
        """Forward pass; this base implementation always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @property
    def layers(self):
        """List of layers from a non-recursive flatten, excluding `self`.

        A new list is built on every access.
        """
        return list(self._flatten_layers(include_self=False, recursive=False))

    @layers.setter
    def layers(self, _):
        """Reject assignment: `layers` is a reserved, read-only attribute.

        Raises:
            AttributeError: Always.
        """
        raise AttributeError(
            "`Model.layers` attribute is reserved and should not be used. "
            "Please use another name."
        )

    def get_layer(self, name=None, index=None):
        """Retrieves a layer based on either its name or its index.

        Exactly one of `name` and `index` must be provided. `index` is a
        position in `Model.layers`. When looking up by name, the first
        layer in `Model.layers` whose `name` matches is returned.

        Args:
            name: String, name of layer.
            index: Integer, index of layer.

        Returns:
            A layer instance.

        Raises:
            ValueError: If both `name` and `index` are provided, if neither
                is provided, if `index` is greater than or equal to the
                number of layers, or if no layer is named `name`.
            IndexError: If `index` is negative and out of range.
        """
        if index is not None and name is not None:
            raise ValueError(
                "Provide only a layer name or a layer index. Received: "
                f"index={index}, name={name}."
            )

        if index is None and name is None:
            raise ValueError(
                "Provide either a layer name or layer index at `get_layer`."
            )

        # `self.layers` rebuilds the list on every access, so read it once.
        layers = self.layers

        if index is not None:
            if len(layers) <= index:
                raise ValueError(
                    f"Was asked to retrieve layer at index {index}"
                    f" but model only has {len(layers)}"
                    " layers."
                )
            return layers[index]

        for layer in layers:
            if layer.name == name:
                return layer
        raise ValueError(
            f"No such layer: {name}. Existing layers are: "
            f"{[layer.name for layer in layers]}."
        )

    def summary(
        self,
        line_length=None,
        positions=None,
        print_fn=None,
        expand_nested=False,
        show_trainable=False,
        layer_range=None,
    ):
        """Prints a string summary of the network.

        All arguments are forwarded unchanged to
        `summary_utils.print_summary`.

        Args:
            line_length: Total length of printed lines
                (e.g. set this to adapt the display to different
                terminal window sizes).
            positions: Relative or absolute positions of log elements
                in each line. If not provided, becomes
                `[0.3, 0.6, 0.70, 1.]`. Defaults to `None`.
            print_fn: Print function to use. By default, prints to `stdout`.
                If `stdout` doesn't work in your environment, change to
                `print`. It will be called on each line of the summary.
                You can set it to a custom function
                in order to capture the string summary.
            expand_nested: Whether to expand the nested models.
                Defaults to `False`.
            show_trainable: Whether to show if a layer is trainable.
                Defaults to `False`.
            layer_range: a list or tuple of 2 strings,
                which is the starting layer name and ending layer name
                (both inclusive) indicating the range of layers to be
                printed in summary. It also accepts regex patterns instead
                of exact name. In such case, start predicate will be the
                first element it matches to `layer_range[0]` and the end
                predicate will be the last element it matches to
                `layer_range[1]`. By default `None` which considers all
                layers of model.

        Raises:
            ValueError: if `summary()` is called before the model is built.
        """
        summary_utils.print_summary(
            self,
            line_length=line_length,
            positions=positions,
            print_fn=print_fn,
            expand_nested=expand_nested,
            show_trainable=show_trainable,
            layer_range=layer_range,
        )

    def save(self, filepath):
        """Placeholder for saving; this implementation always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def export(self, filepath):
        """Placeholder for export; this implementation always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
|
|
|
|
|
|
def functional_init_arguments(args, kwargs):
    """Tell whether constructor arguments match a functional-model call.

    The arguments are treated as functional-style when any of these holds:

    - exactly two positional arguments are given;
    - exactly one positional argument is given together with an
      `outputs` keyword;
    - both `inputs` and `outputs` are given as keywords.

    Args:
        args: Sequence of positional constructor arguments.
        kwargs: Mapping of keyword constructor arguments.

    Returns:
        `True` if the arguments match one of the forms above, else `False`.
    """
    positional_count = len(args)
    if positional_count == 2:
        return True
    if positional_count == 1 and "outputs" in kwargs:
        return True
    return "inputs" in kwargs and "outputs" in kwargs
|
|
|
|
|
|
def inject_functional_model_class(cls):
    """Inject `Functional` into the hierarchy of this class if needed.

    Walks the bases of `cls` recursively and rewrites `__bases__` in place
    so that every occurrence of `Model` is replaced by `Functional`.

    Args:
        cls: The class whose hierarchy should be rewritten.

    Returns:
        `Functional` when `cls` is `Model`, `object` when `cls` is
        `object`, otherwise `cls` itself (with its `__bases__` rewritten).
    """
    from keras_core.models import functional

    # Base case: `Model` itself is what gets swapped out.
    if cls == Model:
        return functional.Functional

    # In case there is any multiple inheritance, we stop injecting the
    # class if keras model is not in its class hierarchy.
    if cls == object:
        return object

    # Recurse into every base and install the rewritten bases on `cls`.
    cls.__bases__ = tuple(map(inject_functional_model_class, cls.__bases__))

    # Trigger any `__new__` class swapping that needed to happen on
    # `Functional` but did not because functional was not in the class
    # hierarchy.
    cls.__new__(cls)

    return cls
|