"""Tests for switching the ``trainable`` flag on Keras layers and models."""
from __future__ import absolute_import
from __future__ import print_function

import pytest

from keras.utils.test_utils import keras_test
from keras.models import Model, Sequential
from keras.layers import Dense, Input


@keras_test
def test_layer_trainability_switch():
    """Freezing a layer removes its weights from the model's trainable set.

    Covers both the ``trainable=False`` constructor argument and setting
    ``layer.trainable`` after the layer has been added, for ``Sequential``
    and functional ``Model`` alike.
    """
    # with constructor argument, in Sequential
    model = Sequential()
    model.add(Dense(2, trainable=False, input_dim=1))
    assert model.trainable_weights == []

    # by setting the `trainable` argument, in Sequential
    model = Sequential()
    layer = Dense(2, input_dim=1)
    model.add(layer)
    assert model.trainable_weights == layer.trainable_weights
    layer.trainable = False
    assert model.trainable_weights == []

    # with constructor argument, in Model
    x = Input(shape=(1,))
    y = Dense(2, trainable=False)(x)
    model = Model(x, y)
    assert model.trainable_weights == []

    # by setting the `trainable` argument, in Model
    x = Input(shape=(1,))
    layer = Dense(2)
    y = layer(x)
    model = Model(x, y)
    assert model.trainable_weights == layer.trainable_weights
    layer.trainable = False
    assert model.trainable_weights == []


@keras_test
def test_model_trainability_switch():
    """Setting ``model.trainable = False`` hides every weight of the model."""
    # a non-trainable model has no trainable weights
    x = Input(shape=(1,))
    y = Dense(2)(x)
    model = Model(x, y)
    model.trainable = False
    assert model.trainable_weights == []

    # same for Sequential
    model = Sequential()
    model.add(Dense(2, input_dim=1))
    model.trainable = False
    assert model.trainable_weights == []


def _check_nested_trainability(inner_model, outer_model):
    """Assert that ``outer_model`` follows trainability changes of ``inner_model``.

    ``inner_model`` must already be wired into ``outer_model``, and its last
    layer must be the one that owns its trainable weights. This helper
    mutates ``inner_model``: on return the inner model is trainable, but its
    last layer is frozen.
    """
    # Precondition: without it, the equality below could pass vacuously
    # on two empty lists.
    assert inner_model.trainable_weights
    assert outer_model.trainable_weights == inner_model.trainable_weights

    # Freezing the whole inner model hides all of its weights from the outer one.
    inner_model.trainable = False
    assert outer_model.trainable_weights == []

    # Re-enabling the inner model while freezing its last layer must still
    # leave the outer model with nothing to train.
    inner_model.trainable = True
    inner_model.layers[-1].trainable = False
    assert outer_model.trainable_weights == []


@keras_test
def test_nested_model_trainability():
    """Trainability changes propagate through every container nesting.

    Covers all four combinations of ``Sequential`` and functional ``Model``
    used as the inner and the outer container.
    """
    # a Sequential inside a Model
    inner_model = Sequential()
    inner_model.add(Dense(2, input_dim=1))
    x = Input(shape=(1,))
    y = inner_model(x)
    outer_model = Model(x, y)
    _check_nested_trainability(inner_model, outer_model)

    # a Sequential inside a Sequential
    inner_model = Sequential()
    inner_model.add(Dense(2, input_dim=1))
    outer_model = Sequential()
    outer_model.add(inner_model)
    _check_nested_trainability(inner_model, outer_model)

    # a Model inside a Model
    x = Input(shape=(1,))
    y = Dense(2)(x)
    inner_model = Model(x, y)
    x = Input(shape=(1,))
    y = inner_model(x)
    outer_model = Model(x, y)
    _check_nested_trainability(inner_model, outer_model)

    # a Model inside a Sequential
    x = Input(shape=(1,))
    y = Dense(2)(x)
    inner_model = Model(x, y)
    outer_model = Sequential()
    outer_model.add(inner_model)
    _check_nested_trainability(inner_model, outer_model)


if __name__ == '__main__':
    pytest.main([__file__])