From 703d5a1298befcdcc2f87cbe3080f18fb866b63d Mon Sep 17 00:00:00 2001 From: fchollet Date: Thu, 24 Nov 2016 23:59:51 -0800 Subject: [PATCH] Add dynamic trainability lightweight test --- tests/test_dynamic_trainability.py | 114 +++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 tests/test_dynamic_trainability.py diff --git a/tests/test_dynamic_trainability.py b/tests/test_dynamic_trainability.py new file mode 100644 index 000000000..7aa5ead01 --- /dev/null +++ b/tests/test_dynamic_trainability.py @@ -0,0 +1,114 @@ +from __future__ import absolute_import +from __future__ import print_function +import pytest + +from keras.utils.test_utils import keras_test +from keras.models import Model, Sequential +from keras.layers import Dense, Input + + +@keras_test +def test_layer_trainability_switch(): + # with constructor argument, in Sequential + model = Sequential() + model.add(Dense(2, trainable=False, input_dim=1)) + assert model.trainable_weights == [] + + # by setting the `trainable` argument, in Sequential + model = Sequential() + layer = Dense(2, input_dim=1) + model.add(layer) + assert model.trainable_weights == layer.trainable_weights + layer.trainable = False + assert model.trainable_weights == [] + + # with constructor argument, in Model + x = Input(shape=(1,)) + y = Dense(2, trainable=False)(x) + model = Model(x, y) + assert model.trainable_weights == [] + + # by setting the `trainable` argument, in Model + x = Input(shape=(1,)) + layer = Dense(2) + y = layer(x) + model = Model(x, y) + assert model.trainable_weights == layer.trainable_weights + layer.trainable = False + assert model.trainable_weights == [] + + +@keras_test +def test_model_trainability_switch(): + # a non-trainable model has no trainable weights + x = Input(shape=(1,)) + y = Dense(2)(x) + model = Model(x, y) + model.trainable = False + assert model.trainable_weights == [] + + # same for Sequential + model = Sequential() + model.add(Dense(2, input_dim=1)) + model.trainable = False + assert model.trainable_weights == [] + + +@keras_test +def test_nested_model_trainability(): + # a Sequential inside a Model + inner_model = Sequential() + inner_model.add(Dense(2, input_dim=1)) + + x = Input(shape=(1,)) + y = inner_model(x) + outer_model = Model(x, y) + assert outer_model.trainable_weights == inner_model.trainable_weights + inner_model.trainable = False + assert outer_model.trainable_weights == [] + inner_model.trainable = True + inner_model.layers[-1].trainable = False + assert outer_model.trainable_weights == [] + + # a Sequential inside a Sequential + inner_model = Sequential() + inner_model.add(Dense(2, input_dim=1)) + outer_model = Sequential() + outer_model.add(inner_model) + assert outer_model.trainable_weights == inner_model.trainable_weights + inner_model.trainable = False + assert outer_model.trainable_weights == [] + inner_model.trainable = True + inner_model.layers[-1].trainable = False + assert outer_model.trainable_weights == [] + + # a Model inside a Model + x = Input(shape=(1,)) + y = Dense(2)(x) + inner_model = Model(x, y) + x = Input(shape=(1,)) + y = inner_model(x) + outer_model = Model(x, y) + assert outer_model.trainable_weights == inner_model.trainable_weights + inner_model.trainable = False + assert outer_model.trainable_weights == [] + inner_model.trainable = True + inner_model.layers[-1].trainable = False + assert outer_model.trainable_weights == [] + + # a Model inside a Sequential + x = Input(shape=(1,)) + y = Dense(2)(x) + inner_model = Model(x, y) + outer_model = Sequential() + outer_model.add(inner_model) + assert outer_model.trainable_weights == inner_model.trainable_weights + inner_model.trainable = False + assert outer_model.trainable_weights == [] + inner_model.trainable = True + inner_model.layers[-1].trainable = False + assert outer_model.trainable_weights == [] + + +if __name__ == '__main__': + pytest.main([__file__])