157 lines
5.2 KiB
Python
157 lines
5.2 KiB
Python
from keras_core.api_export import keras_core_export
|
|
from keras_core.layers.layer import Layer
|
|
from keras_core.trainers.trainer import Trainer
|
|
|
|
|
|
@keras_core_export(["keras_core.Model", "keras_core.models.Model"])
class Model(Layer, Trainer):
    """
    Combination of a Layer and Trainer. Adds:

    - layer surfacing
    - saving
    - export
    - summary

    Limitations:

    - call must have a single inputs argument
    - no masking support
    """

    # NOTE: functional-model dispatch is currently disabled; the sketch below
    # is kept for reference. It relies on `functional_init_arguments`.
    # def __new__(cls, *args, **kwargs):
    #     # Signature detection
    #     if functional_init_arguments(args, kwargs) and cls == Model:
    #         # Functional model
    #         from keras_core.models import functional
    #
    #         return functional.Functional(*args, **kwargs, skip_init=True)
    #     else:
    #         return super(Model, cls).__new__(cls, *args, **kwargs)

    def call(self, inputs, training=False):
        """Forward pass placeholder.

        The base class provides no implementation.

        Args:
            inputs: The single inputs argument of the model.
            training: Boolean flag, `False` by default.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    @property
    def layers(self):
        """A new list of the layers yielded by `_flatten_layers`.

        The list is built with `include_self=False, recursive=False` on every
        access, so callers needing it several times should keep a local
        reference.
        """
        return list(self._flatten_layers(include_self=False, recursive=False))

    @layers.setter
    def layers(self, _):
        """Reject any assignment to `layers`.

        Raises:
            AttributeError: always; the attribute name is reserved.
        """
        raise AttributeError(
            "`Model.layers` attribute is reserved and should not be used. "
            "Please use another name."
        )

    def get_layer(self, name=None, index=None):
        """Retrieves a layer based on either its name or its index.

        Exactly one of `name` and `index` must be provided.
        Indices refer to positions in `self.layers`; negative indices count
        from the end, as with regular list indexing. When several layers
        share the requested name, the first match in `self.layers` is
        returned.

        Args:
            name: String, name of layer.
            index: Integer, index of layer.

        Returns:
            A layer instance.

        Raises:
            ValueError: if both or neither of `name` and `index` are given,
                if `index` is out of range, or if no layer is called `name`.
        """
        if index is not None and name is not None:
            raise ValueError(
                "Provide only a layer name or a layer index. Received: "
                f"index={index}, name={name}."
            )
        if index is None and name is None:
            raise ValueError(
                "Provide either a layer name or layer index at `get_layer`."
            )

        # `self.layers` rebuilds its list on each access; read it only once.
        layers = self.layers

        if index is not None:
            layer_count = len(layers)
            # Check both bounds so that an out-of-range negative index gets
            # the same descriptive error instead of a bare IndexError.
            if not (-layer_count <= index < layer_count):
                raise ValueError(
                    f"Was asked to retrieve layer at index {index}"
                    f" but model only has {layer_count}"
                    " layers."
                )
            return layers[index]

        for layer in layers:
            if layer.name == name:
                return layer
        raise ValueError(
            f"No such layer: {name}. Existing layers are: "
            f"{[layer.name for layer in layers]}."
        )

    def summary(
        self,
        line_length=None,
        positions=None,
        print_fn=None,
        expand_nested=False,
        show_trainable=False,
        layer_range=None,
    ):
        """Prints a string summary of the network.

        All arguments are forwarded unchanged to `layer_utils.print_summary`.

        Args:
            line_length: Total length of printed lines
                (e.g. set this to adapt the display to different
                terminal window sizes).
            positions: Relative or absolute positions of log elements
                in each line. If not provided, becomes
                `[0.3, 0.6, 0.70, 1.]`. Defaults to `None`.
            print_fn: Print function to use. By default, prints to `stdout`.
                If `stdout` doesn't work in your environment, change to
                `print`. It will be called on each line of the summary.
                You can set it to a custom function
                in order to capture the string summary.
            expand_nested: Whether to expand the nested models.
                Defaults to `False`.
            show_trainable: Whether to show if a layer is trainable.
                Defaults to `False`.
            layer_range: a list or tuple of 2 strings,
                which is the starting layer name and ending layer name
                (both inclusive) indicating the range of layers to be printed
                in summary. It also accepts regex patterns instead of exact
                name. In such case, start predicate will be the first element
                it matches to `layer_range[0]` and the end predicate will be
                the last element it matches to `layer_range[1]`.
                By default `None` which considers all layers of model.

        Raises:
            ValueError: if `summary()` is called before the model is built.
        """
        if not self.built:
            raise ValueError(
                "This model has not yet been built. "
                "Build the model first by calling `build()` or by calling "
                "the model on a batch of data."
            )
        # NOTE(review): `layer_utils` is not among the imports visible at the
        # top of this file -- confirm it is brought into scope, otherwise this
        # call fails with NameError once the model is built.
        layer_utils.print_summary(
            self,
            line_length=line_length,
            positions=positions,
            print_fn=print_fn,
            expand_nested=expand_nested,
            show_trainable=show_trainable,
            layer_range=layer_range,
        )

    def save(self, filepath):
        """Saving placeholder.

        Args:
            filepath: Destination of the saved model (currently unused).

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def export(self, filepath):
        """Export placeholder.

        Args:
            filepath: Destination of the exported model (currently unused).

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
|
|
|
|
|
|
def functional_init_arguments(args, kwargs):
    """Tell whether constructor arguments match a functional-style signature.

    A match is any one of:

    - exactly two positional arguments;
    - exactly one positional argument plus an `outputs` keyword;
    - both `inputs` and `outputs` given as keywords.

    Args:
        args: Sequence of positional arguments.
        kwargs: Mapping of keyword arguments.

    Returns:
        `True` if one of the patterns above matches, `False` otherwise.
    """
    positional_count = len(args)
    if positional_count == 2:
        return True
    if positional_count == 1 and "outputs" in kwargs:
        return True
    # Remaining case: everything was passed by keyword.
    return "inputs" in kwargs and "outputs" in kwargs
|