285 lines
10 KiB
Python
285 lines
10 KiB
Python
from tensorflow import nest
|
|
|
|
from keras_core import backend
|
|
from keras_core import utils
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.layers import Input
|
|
from keras_core.layers import InputLayer
|
|
from keras_core.models.functional import Functional
|
|
from keras_core.models.functional import functional_like_constructor
|
|
from keras_core.models.sequential import Sequential
|
|
from keras_core.saving import serialization_lib
|
|
|
|
|
|
@keras_core_export("keras_core.models.clone_model")
def clone_model(model, input_tensors=None, clone_function=None):
    """Clone a Functional or Sequential `Model` instance.

    Model cloning is similar to calling a model on new inputs,
    except that it creates new layers (and thus new weights) instead
    of sharing the weights of the existing layers.

    Note that
    `clone_model` will not preserve the uniqueness of shared objects within the
    model (e.g. a single variable attached to two distinct layers will be
    restored as two separate variables).

    Args:
        model: Instance of `Model`
            (could be a Functional model or a Sequential model).
        input_tensors: optional list of input tensors (`KerasTensor`
            instances) to build the model upon. If not provided,
            new `Input` objects will be created.
        clone_function: Callable to be used to clone each layer in the target
            model (except `Input` instances). It takes as argument the
            layer instance to be cloned, and returns the corresponding layer
            instance to be used in the model copy. If unspecified, this callable
            becomes the following serialization/deserialization function:
            `lambda layer: layer.__class__.from_config(layer.get_config())`.
            By passing a custom callable, you can customize your copy of the
            model, e.g. by wrapping certain layers of interest (you might want
            to replace all `LSTM` instances with equivalent
            `Bidirectional(LSTM(...))` instances, for example).
            Defaults to `None`.

    Returns:
        An instance of `Model` reproducing the behavior
        of the original model, on top of new inputs tensors,
        using newly instantiated weights. The cloned model may behave
        differently from the original model if a custom `clone_function`
        modifies the layer.

    Raises:
        ValueError: If `clone_function` or `input_tensors` is passed for a
            model that is neither a `Sequential` model nor a `Functional`
            model, or if the underlying Sequential/Functional cloning helper
            rejects the arguments.

    Examples:

    Basic usage:

    ```python
    # Create a test Sequential model.
    model = keras_core.Sequential([
        keras_core.layers.Input(shape=(728,)),
        keras_core.layers.Dense(32, activation='relu'),
        keras_core.layers.Dense(1, activation='sigmoid'),
    ])
    # Create a copy of the test model (with freshly initialized weights).
    new_model = clone_model(model)
    ```

    Using a `clone_function` to make a model deterministic by setting the
    random seed everywhere:

    ```python
    def clone_function(layer):
        config = layer.get_config()
        if "seed" in config:
            config["seed"] = 1337
        return layer.__class__.from_config(config)

    new_model = clone_model(model, clone_function=clone_function)
    ```

    Note that subclassed models cannot be cloned by default,
    since their internal layer structure is not known.
    To achieve equivalent functionality
    as `clone_model` in the case of a subclassed model, simply make sure
    that the model class implements `get_config()`
    (and optionally `from_config()`), and call:

    ```python
    new_model = model.__class__.from_config(model.get_config())
    ```

    In the case of a subclassed model, you cannot use a custom
    `clone_function`.
    """
    if isinstance(model, Sequential):
        return _clone_sequential_model(
            model, input_tensors=input_tensors, clone_function=clone_function
        )

    # Compare against `None` explicitly instead of relying on truthiness:
    # `input_tensors` may be a single symbolic tensor, and `bool()` on a
    # symbolic tensor is not meaningful.
    custom_args_passed = (
        clone_function is not None or input_tensors is not None
    )

    if isinstance(model, Functional):
        # If the get_config() method is the same as a regular Functional
        # model, we're safe to use _clone_functional_model (which relies
        # on a Functional constructor). In the case where the get_config
        # is custom, this may not necessarily work, but if clone_function
        # or input_tensors are passed, we attempt it anyway
        # in order to preserve backwards compatibility.
        if utils.is_default(model.get_config) or custom_args_passed:
            return _clone_functional_model(
                model,
                input_tensors=input_tensors,
                clone_function=clone_function,
            )

    # Case of a custom model class: only a plain config round-trip is
    # possible, so per-layer customisation arguments cannot be honoured.
    if custom_args_passed:
        raise ValueError(
            "Arguments clone_function and input_tensors "
            "are only supported for Sequential models "
            "or Functional models. Received model of "
            f"type '{model.__class__.__name__}', with "
            f"clone_function={clone_function} and "
            f"input_tensors={input_tensors}"
        )
    config = serialization_lib.serialize_keras_object(model)
    # Register the model's own class so deserialization can resolve it even
    # when it is not a registered/built-in Keras class.
    return serialization_lib.deserialize_keras_object(
        config, custom_objects={model.__class__.__name__: model.__class__}
    )
|
|
|
|
|
|
def _clone_sequential_model(model, input_tensors=None, clone_function=None):
    """Clone a `Sequential` model instance.

    Model cloning is similar to calling a model on new inputs,
    except that it creates new layers (and thus new weights) instead
    of sharing the weights of the existing layers.

    Args:
        model: Instance of `Sequential`.
        input_tensors: optional `KerasTensor`, or a list/tuple containing
            exactly one `KerasTensor`, to build the model upon. If not
            provided and the original model starts with an `InputLayer`,
            a new `Input` with the same batch shape, dtype and name is
            created; otherwise the clone has no explicit input.
        clone_function: callable to be applied on non-input layers in the model.
            By default, it clones the layer (without copying the weights).

    Returns:
        An instance of `Sequential` reproducing the behavior
        of the original model, on top of new inputs tensors,
        using newly instantiated weights.

    Raises:
        ValueError: If `model` is not a `Sequential` instance, if
            `clone_function` is not callable, or if `input_tensors` is not a
            single `KerasTensor` (possibly wrapped in a one-element
            list/tuple).
    """
    if clone_function is None:

        def _clone_layer(layer):
            return layer.__class__.from_config(layer.get_config())

        clone_function = _clone_layer

    if not isinstance(model, Sequential):
        raise ValueError(
            "Expected `model` argument "
            "to be a `Sequential` model instance. "
            f"Received: model={model}"
        )

    if not callable(clone_function):
        raise ValueError(
            "Expected `clone_function` argument to be a callable. "
            f"Received: clone_function={clone_function}"
        )

    # `model.layers` presumably excludes the leading `InputLayer` (a fresh
    # input is prepended below instead) -- NOTE(review): confirm against
    # `Sequential.layers`.
    new_layers = [clone_function(layer) for layer in model.layers]

    # An empty Sequential model has no `_layers[0]`; guard before indexing.
    if model._layers and isinstance(model._layers[0], InputLayer):
        ref_input_layer = model._layers[0]
        input_name = ref_input_layer.name
        input_batch_shape = ref_input_layer.batch_shape
        input_dtype = ref_input_layer._dtype
    else:
        input_name = None
        input_dtype = None
        input_batch_shape = None

    # Use an explicit `None` check: a single symbolic tensor must not be
    # evaluated for truthiness.
    if input_tensors is not None:
        if isinstance(input_tensors, (list, tuple)):
            if len(input_tensors) != 1:
                raise ValueError(
                    "Argument `input_tensors` must contain a single tensor."
                )
            input_tensors = input_tensors[0]
        if not isinstance(input_tensors, backend.KerasTensor):
            raise ValueError(
                "Argument `input_tensors` must be a KerasTensor. "
                f"Received invalid value: input_tensors={input_tensors}"
            )
        inputs = Input(tensor=input_tensors, name=input_name)
        new_layers = [inputs] + new_layers
    elif input_batch_shape is not None:
        # Recreate the original input layer's signature on a fresh Input.
        inputs = Input(
            batch_shape=input_batch_shape,
            dtype=input_dtype,
            name=input_name,
        )
        new_layers = [inputs] + new_layers
    return Sequential(new_layers, name=model.name, trainable=model.trainable)
|
|
|
|
|
|
def _clone_functional_model(model, input_tensors=None, clone_function=None):
    """Clone a `Functional` model instance.

    Model cloning is similar to calling a model on new inputs,
    except that it creates new layers (and thus new weights) instead
    of sharing the weights of the existing layers.

    Input layers are always cloned.

    Args:
        model: Instance of `Functional`.
        input_tensors: optional (possibly nested) structure of `KerasTensor`
            inputs to build the model upon; it is flattened to a list before
            use. If not provided, new `Input` placeholders mirroring the
            original model inputs (batch shape, dtype, name) are created.
        clone_function: callable to be applied on non-input layers in the model.
            By default, it clones the layer (without copying the weights),
            reusing the same clone each time the same layer is encountered.

    Returns:
        An instance of `Functional` reproducing the behavior
        of the original model, on top of new inputs tensors,
        using newly instantiated weights. If the original model's class
        accepts a Functional-style constructor, the clone has that class;
        otherwise it is a plain `Functional`.

    Raises:
        ValueError: If `clone_function` is not callable, if `model` is not a
            `Functional` instance, or if any entry of `input_tensors` is not
            a `KerasTensor`.
    """
    if clone_function is None:
        seen = {}

        def _clone_layer(layer):
            # Memoize: asking for the same layer again (e.g. a layer invoked
            # at several nodes of the graph) returns the same clone.
            if layer in seen:
                return seen[layer]
            new_layer = layer.__class__.from_config(layer.get_config())
            seen[layer] = new_layer
            return new_layer

        clone_function = _clone_layer

    if not callable(clone_function):
        raise ValueError(
            "Expected `clone_function` argument to be a callable. "
            f"Received: clone_function={clone_function}"
        )

    if not isinstance(model, Functional):
        raise ValueError(
            "Expected `model` argument "
            f"to be a Functional Model instance. Received: model={model}"
        )

    if input_tensors is not None:
        input_tensors = nest.flatten(input_tensors)
        if not all(isinstance(x, backend.KerasTensor) for x in input_tensors):
            raise ValueError(
                "All entries in `input_tensors` must be KerasTensors. "
                f"Received invalid values: inputs_tensors={input_tensors}"
            )
    else:
        # A model input tensor's `shape` already includes the batch axis, so
        # it must be passed as `batch_shape`; passing it as `shape` would
        # prepend a second batch axis to every recreated input.
        input_tensors = nest.map_structure(
            lambda x: Input(batch_shape=x.shape, dtype=x.dtype, name=x.name),
            model.input,
        )

    # Replay the original graph on the new inputs, substituting each
    # operation with the layer returned by `clone_function`.
    output_tensors = model._run_through_graph(
        input_tensors, operation_fn=clone_function
    )

    if functional_like_constructor(model.__class__):
        new_model = model.__class__(
            input_tensors, output_tensors, name=model.name
        )
    else:
        # This may be incorrect: the new model will end up having a different
        # class than the original. However various existing models rely
        # on this behavior, so we keep it.
        new_model = Functional(input_tensors, output_tensors, name=model.name)

    return new_model
|