"""Tests for Keras python-based idempotent saving functions."""
import json
import os
import warnings
import zipfile
from pathlib import Path
from unittest import mock

import numpy as np

import keras_core
from keras_core import operations as ops
from keras_core import testing
from keras_core.saving import saving_lib


@keras_core.saving.register_keras_serializable(package="my_custom_package")
class MyDense(keras_core.layers.Layer):
    """Custom layer wrapping a `Dense` layer plus extra scalar weights.

    The extra weights are stored in a list and in a dict so that the saving
    tests cover variables held in nested containers.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.nested_layer = keras_core.layers.Dense(self.units)

    def build(self, input_shape):
        """Create the additional weights and build the nested `Dense`."""
        self.additional_weights = [
            self.add_weight(
                shape=(),
                name="my_additional_weight",
                initializer="ones",
                trainable=True,
            ),
            self.add_weight(
                shape=(),
                name="my_additional_weight_2",
                initializer="ones",
                trainable=True,
            ),
        ]
        self.weights_in_dict = {
            "my_weight": self.add_weight(
                shape=(),
                name="my_dict_weight",
                initializer="ones",
                trainable=True,
            ),
        }
        self.nested_layer.build(input_shape)

    def call(self, inputs):
        return self.nested_layer(inputs)

    def two(self):
        return 2


ASSETS_DATA = "These are my assets"
VARIABLES_DATA = np.random.random((10,))


@keras_core.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomSaving(MyDense):
    """`MyDense` variant that overrides the asset/variable save-load hooks."""

    def build(self, input_shape):
        self.assets = ASSETS_DATA
        self.stored_variables = VARIABLES_DATA
        return super().build(input_shape)

    def save_assets(self, inner_path):
        """Write `self.assets` to `assets.txt` inside `inner_path`."""
        # Explicit encoding keeps the round trip independent of the locale.
        asset_path = os.path.join(inner_path, "assets.txt")
        with open(asset_path, "w", encoding="utf-8") as f:
            f.write(self.assets)

    def save_own_variables(self, store):
        store["variables"] = self.stored_variables

    def load_assets(self, inner_path):
        """Read `assets.txt` from `inner_path` back into `self.assets`."""
        asset_path = os.path.join(inner_path, "assets.txt")
        with open(asset_path, "r", encoding="utf-8") as f:
            text = f.read()
        self.assets = text

    def load_own_variables(self, store):
        self.stored_variables = np.array(store["variables"])


@keras_core.saving.register_keras_serializable(package="my_custom_package")
class CustomModelX(keras_core.Model):
    """Subclassed model chaining two `MyDense(1)` layers."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dense1 = MyDense(1)
        self.dense2 = MyDense(1)

    def call(self, inputs):
        out = self.dense1(inputs)
        return self.dense2(out)

    def one(self):
        return 1


@keras_core.saving.register_keras_serializable(package="my_custom_package")
class ModelWithCustomSaving(keras_core.Model):
    """Subclassed model holding a single `LayerWithCustomSaving` layer."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.custom_dense = LayerWithCustomSaving(1)

    def call(self, inputs):
        return self.custom_dense(inputs)


@keras_core.saving.register_keras_serializable(package="my_custom_package")
class CompileOverridingModel(keras_core.Model):
    """Subclassed model whose `compile()` override delegates to the parent."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dense1 = MyDense(1)

    def compile(self, *args, **kwargs):
        super().compile(*args, **kwargs)

    def call(self, inputs):
        return self.dense1(inputs)


@keras_core.saving.register_keras_serializable(package="my_custom_package")
class CompileOverridingSequential(keras_core.Sequential):
    """Sequential model whose `compile()` override delegates to the parent."""

    def compile(self, *args, **kwargs):
        super().compile(*args, **kwargs)


@keras_core.saving.register_keras_serializable(package="my_custom_package")
def my_mean_squared_error(y_true, y_pred):
    """Identical to built-in `mean_squared_error`, but as a custom fn."""
    return ops.mean(ops.square(y_pred - y_true), axis=-1)


class SavingTest(testing.TestCase):
    """Save/load round-trip tests for `saving_lib`."""

    def _compile_for_test(self, model):
        """Compile `model` with the settings shared by every model factory."""
        model.compile(
            optimizer="adam",
            loss=my_mean_squared_error,
            metrics=[keras_core.metrics.Hinge(), "mse"],
        )

    def _assert_all_close_pairwise(self, expected, actual):
        """Assert both sequences have equal length and close elements.

        The explicit length check matters because `zip()` stops at the
        shorter input and would otherwise hide missing entries.
        """
        expected = list(expected)
        actual = list(actual)
        self.assertEqual(len(expected), len(actual))
        for ref_value, value in zip(expected, actual):
            self.assertAllClose(ref_value, value)

    def _get_subclassed_model(self, compile=True):
        subclassed_model = CustomModelX()
        if compile:
            self._compile_for_test(subclassed_model)
        return subclassed_model

    def _get_custom_sequential_model(self, compile=True):
        sequential_model = keras_core.Sequential([MyDense(1), MyDense(1)])
        if compile:
            self._compile_for_test(sequential_model)
        return sequential_model

    def _get_basic_sequential_model(self, compile=True):
        sequential_model = keras_core.Sequential(
            [keras_core.layers.Dense(1), keras_core.layers.Dense(1)]
        )
        if compile:
            self._compile_for_test(sequential_model)
        return sequential_model

    def _get_custom_functional_model(self, compile=True):
        inputs = keras_core.Input(shape=(4,), batch_size=2)
        x = MyDense(1, name="first_dense")(inputs)
        outputs = MyDense(1, name="second_dense")(x)
        functional_model = keras_core.Model(inputs, outputs)
        if compile:
            self._compile_for_test(functional_model)
        return functional_model

    def _get_basic_functional_model(self, compile=True):
        inputs = keras_core.Input(shape=(4,), batch_size=2)
        x = keras_core.layers.Dense(1, name="first_dense")(inputs)
        outputs = keras_core.layers.Dense(1, name="second_dense")(x)
        functional_model = keras_core.Model(inputs, outputs)
        if compile:
            self._compile_for_test(functional_model)
        return functional_model

    def _test_inference_after_instantiation(self, model):
        """Save an uncompiled `model`, reload it, and compare outputs."""
        x_ref = np.random.random((2, 4))
        y_ref = model(x_ref)
        temp_filepath = os.path.join(self.get_temp_dir(), "my_model.keras")
        model.save(temp_filepath)

        loaded_model = saving_lib.load_model(temp_filepath)
        self.assertFalse(model.compiled)
        # The round trip must not introduce a compiled state either.
        self.assertFalse(loaded_model.compiled)
        self._assert_all_close_pairwise(
            model.variables, loaded_model.variables
        )
        self.assertAllClose(y_ref, loaded_model(x_ref))

    def test_inference_after_instantiation_subclassed(self):
        model = self._get_subclassed_model(compile=False)
        self._test_inference_after_instantiation(model)

    def test_inference_after_instantiation_basic_sequential(self):
        model = self._get_basic_sequential_model(compile=False)
        self._test_inference_after_instantiation(model)

    def test_inference_after_instantiation_basic_functional(self):
        model = self._get_basic_functional_model(compile=False)
        self._test_inference_after_instantiation(model)

    def test_inference_after_instantiation_custom_sequential(self):
        model = self._get_custom_sequential_model(compile=False)
        self._test_inference_after_instantiation(model)

    def test_inference_after_instantiation_custom_functional(self):
        model = self._get_custom_functional_model(compile=False)
        self._test_inference_after_instantiation(model)

    def _test_compile_preserved(self, model):
        """Fit a compiled `model`, reload it, and compare training state."""
        x_ref = np.random.random((2, 4))
        y_ref = np.random.random((2, 1))

        model.fit(x_ref, y_ref)
        out_ref = model(x_ref)
        ref_metrics = model.evaluate(x_ref, y_ref)
        temp_filepath = os.path.join(self.get_temp_dir(), "my_model.keras")
        model.save(temp_filepath)

        loaded_model = saving_lib.load_model(temp_filepath)
        self.assertTrue(model.compiled)
        # The property under test is the compiled state of the loaded model.
        self.assertTrue(loaded_model.compiled)
        self.assertTrue(loaded_model.built)
        self._assert_all_close_pairwise(
            model.variables, loaded_model.variables
        )
        self.assertAllClose(out_ref, loaded_model(x_ref))

        self.assertEqual(
            model.optimizer.__class__, loaded_model.optimizer.__class__
        )
        self.assertEqual(
            model.optimizer.get_config(), loaded_model.optimizer.get_config()
        )
        self._assert_all_close_pairwise(
            model.optimizer.variables, loaded_model.optimizer.variables
        )

        new_metrics = loaded_model.evaluate(x_ref, y_ref)
        self._assert_all_close_pairwise(ref_metrics, new_metrics)

    def test_compile_preserved_subclassed(self):
        model = self._get_subclassed_model(compile=True)
        self._test_compile_preserved(model)

    def test_compile_preserved_basic_sequential(self):
        model = self._get_basic_sequential_model(compile=True)
        self._test_compile_preserved(model)

    def test_compile_preserved_custom_sequential(self):
        model = self._get_custom_sequential_model(compile=True)
        self._test_compile_preserved(model)

    def test_compile_preserved_basic_functional(self):
        model = self._get_basic_functional_model(compile=True)
        self._test_compile_preserved(model)

    def test_compile_preserved_custom_functional(self):
        model = self._get_custom_functional_model(compile=True)
        self._test_compile_preserved(model)

    def test_saving_preserve_unbuilt_state(self):
        temp_filepath = os.path.join(self.get_temp_dir(), "my_model.keras")
        subclassed_model = CustomModelX()
        subclassed_model.save(temp_filepath)
        loaded_model = saving_lib.load_model(temp_filepath)
        self.assertEqual(subclassed_model.compiled, loaded_model.compiled)
        self.assertFalse(subclassed_model.built)
        self.assertFalse(loaded_model.built)

    def test_saved_module_paths_and_class_names(self):
        temp_filepath = os.path.join(self.get_temp_dir(), "my_model.keras")
        subclassed_model = self._get_subclassed_model()
        x = np.random.random((100, 32))
        y = np.random.random((100, 1))
        subclassed_model.fit(x, y, epochs=1)
        subclassed_model.save(temp_filepath)

        # Inspect the JSON config stored inside the saved archive.
        with zipfile.ZipFile(temp_filepath, "r") as z:
            with z.open(saving_lib._CONFIG_FILENAME, "r") as c:
                config_json = c.read()
        config_dict = json.loads(config_json)
        self.assertEqual(
            config_dict["registered_name"], "my_custom_package>CustomModelX"
        )
        self.assertEqual(
            config_dict["compile_config"]["optimizer"],
            "adam",
        )
        self.assertEqual(
            config_dict["compile_config"]["loss"]["config"],
            "my_mean_squared_error",
        )

    def test_saving_custom_assets_and_variables(self):
        temp_filepath = os.path.join(self.get_temp_dir(), "my_model.keras")
        model = ModelWithCustomSaving()
        model.compile(
            optimizer="adam",
            loss="mse",
        )
        x = np.random.random((100, 32))
        y = np.random.random((100, 1))
        model.fit(x, y, epochs=1)

        # Assert that the archive has not been saved.
        self.assertFalse(os.path.exists(temp_filepath))

        model.save(temp_filepath)

        loaded_model = saving_lib.load_model(temp_filepath)
        self.assertEqual(loaded_model.custom_dense.assets, ASSETS_DATA)
        self.assertEqual(
            loaded_model.custom_dense.stored_variables.tolist(),
            VARIABLES_DATA.tolist(),
        )

    def _test_compile_overridden_warnings(self, model_type):
        """Check that loading a model with a custom `compile()` warns."""
        temp_filepath = os.path.join(self.get_temp_dir(), "my_model.keras")
        model = (
            CompileOverridingModel()
            if model_type == "subclassed"
            else CompileOverridingSequential(
                [keras_core.layers.Embedding(4, 1), MyDense(1), MyDense(1)]
            )
        )
        model.compile("sgd", "mse")
        model.save(temp_filepath)

        with mock.patch.object(warnings, "warn") as mock_warn:
            saving_lib.load_model(temp_filepath)
        self.assertTrue(mock_warn.call_args_list, "Did not warn.")
        self.assertIn(
            "`compile()` was not called as part of model loading "
            "because the model's `compile()` method is custom. ",
            mock_warn.call_args_list[0][0][0],
        )

    # TODO: need Embedding layer
    # def test_compile_overridden_warnings_sequential(self):
    #     self._test_compile_overridden_warnings("sequential")

    # def test_compile_overridden_warnings_subclassed(self):
    #     self._test_compile_overridden_warnings("subclassed")

    def test_metadata(self):
        temp_filepath = Path(
            os.path.join(self.get_temp_dir(), "my_model.keras")
        )
        model = CompileOverridingModel()
        model.save(temp_filepath)
        with zipfile.ZipFile(temp_filepath, "r") as z:
            with z.open(saving_lib._METADATA_FILENAME, "r") as c:
                metadata_json = c.read()
        metadata = json.loads(metadata_json)
        self.assertIn("keras_version", metadata)
        self.assertIn("date_saved", metadata)

    # def test_gfile_copy_local_called(self):
    #     temp_filepath = Path(
    #         os.path.join(self.get_temp_dir(), "my_model.keras")
    #     )
    #     model = CompileOverridingModel()
    #     with mock.patch("re.match", autospec=True) as mock_re_match, mock.patch(
    #         "tensorflow.compat.v2.io.gfile.copy", autospec=True
    #     ) as mock_copy:
    #         # Mock Remote Path check to true to test gfile copy logic
    #         mock_re_match.return_value = True
    #         model.save(temp_filepath)
    #         mock_re_match.assert_called()
    #         mock_copy.assert_called()
    #         self.assertIn(str(temp_filepath), mock_re_match.call_args.args)
    #         self.assertIn(str(temp_filepath), mock_copy.call_args.args)

    # def test_load_model_api_endpoint(self):
    #     temp_filepath = Path(os.path.join(self.get_temp_dir(), "mymodel.keras"))
    #     model = self._get_functional_model()
    #     ref_input = np.random.random((10, 32))
    #     ref_output = model.predict(ref_input)
    #     model.save(temp_filepath)
    #     model = keras_core.models.load_model(temp_filepath)
    #     self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)

    def test_save_load_weights_only(self):
        temp_filepath = Path(
            os.path.join(self.get_temp_dir(), "mymodel.weights.h5")
        )
        model = self._get_basic_functional_model()
        ref_input = np.random.random((2, 4))
        ref_output = model.predict(ref_input)
        saving_lib.save_weights_only(model, temp_filepath)
        model = self._get_basic_functional_model()
        saving_lib.load_weights_only(model, temp_filepath)
        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)
        # Test with Model method
        model = self._get_basic_functional_model()
        model.load_weights(temp_filepath)
        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)

    def test_load_weights_only_with_keras_file(self):
        # Test loading weights from whole saved model
        temp_filepath = Path(os.path.join(self.get_temp_dir(), "mymodel.keras"))
        model = self._get_basic_functional_model()
        ref_input = np.random.random((2, 4))
        ref_output = model.predict(ref_input)
        saving_lib.save_model(model, temp_filepath)
        model = self._get_basic_functional_model()
        saving_lib.load_weights_only(model, temp_filepath)
        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)
        # Test with Model method
        model = self._get_basic_functional_model()
        model.load_weights(temp_filepath)
        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)

    def test_compile_arg(self):
        temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras")
        model = self._get_basic_functional_model()
        model.compile("sgd", "mse")
        model.fit(np.random.random((2, 4)), np.random.random((2, 1)))
        saving_lib.save_model(model, temp_filepath)

        model = saving_lib.load_model(temp_filepath)
        self.assertTrue(model.compiled)
        model = saving_lib.load_model(temp_filepath, compile=False)
        self.assertFalse(model.compiled)

    # def test_overwrite(self):
    #     temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras")
    #     model = self._get_basic_functional_model()
    #     model.save(temp_filepath)
    #     model.save(temp_filepath, overwrite=True)
    #     with self.assertRaises(EOFError):
    #         model.save(temp_filepath, overwrite=False)

    #     temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.weights.h5")
    #     model = self._get_basic_functional_model()
    #     model.save_weights(temp_filepath)
    #     model.save_weights(temp_filepath, overwrite=True)
    #     with self.assertRaises(EOFError):
    #         model.save_weights(temp_filepath, overwrite=False)

    def test_partial_load(self):
        temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras")
        original_model = keras_core.Sequential(
            [
                keras_core.Input(shape=(3,), batch_size=2),
                keras_core.layers.Dense(4),
                keras_core.layers.Dense(5),
            ]
        )
        original_model.save(temp_filepath)

        # Test with a model that has a differently shaped layer
        new_model = keras_core.Sequential(
            [
                keras_core.Input(shape=(3,), batch_size=2),
                keras_core.layers.Dense(4),
                keras_core.layers.Dense(6),
            ]
        )
        new_layer_kernel_value = np.array(new_model.layers[1].kernel)
        with self.assertRaisesRegex(ValueError, "must match"):
            # Doesn't work by default
            new_model.load_weights(temp_filepath)
        # Now it works
        new_model.load_weights(temp_filepath, skip_mismatch=True)
        self._assert_all_close_pairwise(
            original_model.layers[0].get_weights(),
            new_model.layers[0].get_weights(),
        )
        # The mismatched layer must keep its freshly initialized kernel.
        self.assertAllClose(
            np.array(new_model.layers[1].kernel), new_layer_kernel_value
        )

        # Test with a model that has a new layer at the end
        new_model = keras_core.Sequential(
            [
                keras_core.Input(shape=(3,), batch_size=2),
                keras_core.layers.Dense(4),
                keras_core.layers.Dense(5),
                keras_core.layers.Dense(5),
            ]
        )
        new_layer_kernel_value = np.array(new_model.layers[2].kernel)
        with self.assertRaisesRegex(ValueError, "received 0 variables"):
            # Doesn't work by default
            new_model.load_weights(temp_filepath)
        # Now it works
        new_model.load_weights(temp_filepath, skip_mismatch=True)
        for layer_index in [0, 1]:
            self._assert_all_close_pairwise(
                original_model.layers[layer_index].get_weights(),
                new_model.layers[layer_index].get_weights(),
            )
        self.assertAllClose(
            np.array(new_model.layers[2].kernel), new_layer_kernel_value
        )

    # def test_safe_mode(self):
    #     temp_filepath = os.path.join(self.get_temp_dir(), "unsafe_model.keras")
    #     model = keras_core.Sequential(
    #         [
    #             keras_core.Input(shape=(3,)),
    #             keras_core.layers.Dense(2, activation=lambda x: x * 2),
    #         ]
    #     )
    #     model.save(temp_filepath)
    #     with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
    #         model = saving_lib.load_model(temp_filepath)
    #     model = saving_lib.load_model(temp_filepath, safe_mode=False)

    # def test_normalization_kpl(self):
    #     # With adapt
    #     temp_filepath = os.path.join(self.get_temp_dir(), "norm_model.keras")
    #     model = keras_core.Sequential(
    #         [
    #             keras_core.Input(shape=(3,)),
    #             keras_core.layers.Normalization(),
    #         ]
    #     )
    #     data = np.random.random((3, 3))
    #     model.layers[0].adapt(data)
    #     ref_out = model(data)
    #     model.save(temp_filepath)
    #     model = saving_lib.load_model(temp_filepath)
    #     out = model(data)
    #     self.assertAllClose(ref_out, out, atol=1e-6)

    #     # Without adapt
    #     model = keras_core.Sequential(
    #         [
    #             keras_core.Input(shape=(3,)),
    #             keras_core.layers.Normalization(
    #                 mean=np.random.random((3,)),
    #                 variance=np.random.random((3,))
    #             ),
    #         ]
    #     )
    #     ref_out = model(data)
    #     model.save(temp_filepath)
    #     model = saving_lib.load_model(temp_filepath)
    #     out = model(data)
    #     self.assertAllClose(ref_out, out, atol=1e-6)


# # This custom class lacks custom object registration.
# class CustomRNN(keras_core.layers.Layer):
#     def __init__(self, units):
#         super(CustomRNN, self).__init__()
#         self.units = units
#         self.projection_1 = keras_core.layers.Dense(units=units, activation="tanh")
#         self.projection_2 = keras_core.layers.Dense(units=units, activation="tanh")
#         self.classifier = keras_core.layers.Dense(1)

#     def call(self, inputs):
#         outputs = []
#         state = ops.zeros(shape=(inputs.shape[0], self.units))
#         for t in range(inputs.shape[1]):
#             x = inputs[:, t, :]
#             h = self.projection_1(x)
#             y = h + self.projection_2(state)
#             state = y
#             outputs.append(y)
#         features = ops.stack(outputs, axis=1)
#         return self.classifier(features)


# # This class is properly registered with a `get_config()` method.
# # However, since it does not subclass keras_core.layers.Layer, it lacks
# # `from_config()` for deserialization.
# @keras_core.saving.register_keras_serializable()
# class GrowthFactor:
#     def __init__(self, factor):
#         self.factor = factor

#     def __call__(self, inputs):
#         return inputs * self.factor

#     def get_config(self):
#         return {"factor": self.factor}


# @keras_core.saving.register_keras_serializable(package="Complex")
# class FactorLayer(keras_core.layers.Layer):
#     def __init__(self, factor):
#         super().__init__()
#         self.factor = factor

#     def call(self, x):
#         return x * self.factor

#     def get_config(self):
#         return {"factor": self.factor}


# # This custom model does not explicitly deserialize the layers it includes
# # in its `get_config`. Explicit deserialization in a `from_config` override
# # or `__init__` is needed here, or an error will be thrown at loading time.
# @keras_core.saving.register_keras_serializable(package="Complex")
# class ComplexModel(keras_core.layers.Layer):
#     def __init__(self, first_layer, second_layer=None, **kwargs):
#         super().__init__(**kwargs)
#         self.first_layer = first_layer
#         if second_layer is not None:
#             self.second_layer = second_layer
#         else:
#             self.second_layer = keras_core.layers.Dense(8)

#     def get_config(self):
#         config = super().get_config()
#         config.update(
#             {
#                 "first_layer": self.first_layer,
#                 "second_layer": self.second_layer,
#             }
#         )
#         return config

#     def call(self, inputs):
#         return self.first_layer(self.second_layer(inputs))


# class SavingBattleTest(testing.TestCase):
#     def test_custom_model_without_registration_error(self):
#         temp_filepath = os.path.join(
#             self.get_temp_dir(), "my_custom_model.keras"
#         )
#         timesteps = 10
#         input_dim = 5
#         batch_size = 16

#         inputs = keras_core.Input(batch_shape=(batch_size, timesteps, input_dim))
#         x = keras_core.layers.Conv1D(32, 3)(inputs)
#         outputs = CustomRNN(32)(x)

#         model = keras_core.Model(inputs, outputs)

#         with self.assertRaisesRegex(
#             TypeError, "is a custom class, please register it"
#         ):
#             model.save(temp_filepath)
#             _ = keras_core.models.load_model(temp_filepath)

#     def test_custom_object_without_from_config(self):
#         temp_filepath = os.path.join(
#             self.get_temp_dir(), "custom_fn_model.keras"
#         )

#         inputs = keras_core.Input(shape=(4, 4))
#         outputs = keras_core.layers.Dense(1, activation=GrowthFactor(0.5))(inputs)
#         model = keras_core.Model(inputs, outputs)

#         model.save(temp_filepath)

#         with self.assertRaisesRegex(
#             TypeError, "Unable to reconstruct an instance"
#         ):
#             _ = keras_core.models.load_model(temp_filepath)

#     def test_complex_model_without_explicit_deserialization(self):
#         temp_filepath = os.path.join(self.get_temp_dir(), "complex_model.keras")

#         inputs = keras_core.Input((32,))
#         outputs = ComplexModel(first_layer=FactorLayer(0.5))(inputs)
#         model = keras_core.Model(inputs, outputs)

#         model.save(temp_filepath)

#         with self.assertRaisesRegex(TypeError, "are explicitly deserialized"):
#             _ = keras_core.models.load_model(temp_filepath)