"""Python-based idempotent model-saving functionality.""" import datetime import io import json import os import re import tempfile import warnings import zipfile import numpy as np from tensorflow.io import gfile from keras_core.layers.layer import Layer from keras_core.losses.loss import Loss from keras_core.metrics.metric import Metric from keras_core.optimizers.optimizer import Optimizer from keras_core.saving.serialization_lib import ObjectSharingScope from keras_core.saving.serialization_lib import deserialize_keras_object from keras_core.saving.serialization_lib import serialize_keras_object from keras_core.utils import naming keras_version = "0.0.0" # TODO try: import h5py except ImportError: h5py = None _CONFIG_FILENAME = "config.json" _METADATA_FILENAME = "metadata.json" _VARS_FNAME = "model.weights" # Will become e.g. "model.weights.h5" _ASSETS_DIRNAME = "assets" ATTR_SKIPLIST = frozenset( { "_operations", "_layers", "_functional", "_losses", "_inbound_nodes", "_outbound_nodes", "_variables", "weights", "non_trainable_weights", "trainable_weights", "variables", "non_trainable_variables", "trainable_variables", } ) def save_model(model, filepath, weights_format="h5"): """Save a zip-archive representing a Keras model to the given filepath. The zip-based archive contains the following structure: - JSON-based configuration file (config.json): Records of model, layer, and other trackables' configuration. - NPZ-based trackable state files, found in respective directories, such as model/states.npz, model/dense_layer/states.npz, etc. - Metadata file. The states of Keras trackables (layers, optimizers, loss, and metrics) are automatically saved as long as they can be discovered through the attributes returned by `dir(Model)`. Typically, the state includes the variables associated with the trackable, but some specially purposed layers may contain more such as the vocabularies stored in the hashmaps. The trackables define how their states are saved by exposing `save_state()` and `load_state()` APIs. For the case of layer states, the variables will be visited as long as they are either 1) referenced via layer attributes, or 2) referenced via a container (list, tuple, or dict), and the container is referenced via a layer attribute. """ filepath = str(filepath) if not filepath.endswith(".keras"): raise ValueError( "Invalid `filepath` argument: expected a `.keras` extension. " f"Received: filepath={filepath}" ) if weights_format == "h5" and h5py is None: raise ImportError("h5py must be installed in order to save a model.") if not model.built: warnings.warn( "You are saving a model that has not yet been built. " "It might not contain any weights yet. " "Consider building the model first by calling it " "on some data.", stacklevel=2, ) with ObjectSharingScope(): serialized_model_dict = serialize_keras_object(model) config_json = json.dumps(serialized_model_dict) metadata_json = json.dumps( { "keras_version": keras_version, "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"), } ) # TODO(rameshsampath): Need a better logic for local vs remote path if is_remote_path(filepath): # Remote path. Zip to local drive and copy to remote zip_filepath = os.path.join(get_temp_dir(), "tmp_model.keras") else: zip_filepath = filepath with zipfile.ZipFile(zip_filepath, "w") as zf: with zf.open(_METADATA_FILENAME, "w") as f: f.write(metadata_json.encode()) with zf.open(_CONFIG_FILENAME, "w") as f: f.write(config_json.encode()) if weights_format == "h5": weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="w") elif weights_format == "npz": weights_store = NpzIOStore( _VARS_FNAME + ".npz", archive=zf, mode="w" ) else: raise ValueError( "Unknown `weights_format` argument. " "Expected 'h5' or 'npz'. " f"Received: weights_format={weights_format}" ) asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="w") _save_state( model, weights_store=weights_store, assets_store=asset_store, inner_path="", visited_trackables=set(), ) weights_store.close() asset_store.close() if is_remote_path(filepath): # Using gfile context manager doesn't close zip file when # writing to GCS. Hence writing to local and copying to filepath. gfile.copy(zip_filepath, filepath, overwrite=True) os.remove(zip_filepath) def load_model(filepath, custom_objects=None, compile=True, safe_mode=True): """Load a zip archive representing a Keras model.""" filepath = str(filepath) if not filepath.endswith(".keras"): raise ValueError( "Invalid filename: expected a `.keras` extension. " f"Received: filepath={filepath}" ) with gfile.GFile(filepath, mode="r+b") as gfile_handle, zipfile.ZipFile( gfile_handle, "r" ) as zf: with zf.open(_CONFIG_FILENAME, "r") as f: config_json = f.read() # Note: we should NOT use a custom JSON decoder. Anything that # needs custom decoding must be handled in deserialize_keras_object. config_dict = json.loads(config_json) if not compile: # Disable compilation config_dict["compile_config"] = None # Construct the model from the configuration file in the archive. with ObjectSharingScope(): model = deserialize_keras_object( config_dict, custom_objects, safe_mode=safe_mode ) all_filenames = zf.namelist() if _VARS_FNAME + ".h5" in all_filenames: weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="r") elif _VARS_FNAME + ".npz" in all_filenames: weights_store = NpzIOStore( _VARS_FNAME + ".npz", archive=zf, mode="r" ) else: raise ValueError( f"Expected a {_VARS_FNAME}.h5 or {_VARS_FNAME}.npz file." ) if len(all_filenames) > 3: asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="r") else: asset_store = None _load_state( model, weights_store=weights_store, assets_store=asset_store, inner_path="", visited_trackables=set(), ) weights_store.close() if asset_store: asset_store.close() return model def save_weights_only(model, filepath): """Save only the weights of a model to a target filepath (.weights.h5). Note: only supports h5 for now. """ # TODO: if h5 filepath is remote, create the file in a temporary directory # then upload it filepath = str(filepath) if not filepath.endswith(".weights.h5"): raise ValueError( "Invalid `filepath` argument: expected a `.weights.h5` extension. " f"Received: filepath={filepath}" ) weights_store = H5IOStore(filepath, mode="w") _save_state( model, weights_store=weights_store, assets_store=None, inner_path="", visited_trackables=set(), ) weights_store.close() def load_weights_only(model, filepath, skip_mismatch=False): """Load the weights of a model from a filepath (.keras or .weights.h5). Note: only supports h5 for now. """ temp_dir = None archive = None filepath = str(filepath) if filepath.endswith(".weights.h5"): # TODO: download file if h5 filepath is remote weights_store = H5IOStore(filepath, mode="r") elif filepath.endswith(".keras"): archive = zipfile.ZipFile(filepath, "r") weights_store = H5IOStore( _VARS_FNAME + ".h5", archive=archive, mode="r" ) _load_state( model, weights_store=weights_store, assets_store=None, inner_path="", skip_mismatch=skip_mismatch, visited_trackables=set(), ) weights_store.close() if temp_dir and gfile.exists(temp_dir): gfile.rmtree(temp_dir) if archive: archive.close() def is_remote_path(filepath): if re.match(r"^(/cns|/cfs|/gcs|.*://).*$", str(filepath)): return True return False def _write_to_zip_recursively(zipfile_to_save, system_path, zip_path): if not gfile.isdir(system_path): zipfile_to_save.write(system_path, zip_path) else: for file_name in gfile.listdir(system_path): system_file_path = gfile.join(system_path, file_name) zip_file_path = gfile.join(zip_path, file_name) _write_to_zip_recursively( zipfile_to_save, system_file_path, zip_file_path ) def _walk_trackable(trackable): for child_attr in dir(trackable): if child_attr.startswith("__") or child_attr in ATTR_SKIPLIST: continue try: child_obj = getattr(trackable, child_attr) except Exception: # Avoid raising the exception when visiting the attributes. continue yield child_attr, child_obj def _save_state( trackable, weights_store, assets_store, inner_path, visited_trackables ): # If the trackable has already been saved, skip it. if id(trackable) in visited_trackables: return if hasattr(trackable, "save_own_variables") and weights_store: trackable.save_own_variables(weights_store.make(inner_path)) if hasattr(trackable, "save_assets") and assets_store: trackable.save_assets(assets_store.make(inner_path)) visited_trackables.add(id(trackable)) # Recursively save state of children trackables (layers, optimizers, etc.) for child_attr, child_obj in _walk_trackable(trackable): if _is_keras_trackable(child_obj): _save_state( child_obj, weights_store, assets_store, inner_path=gfile.join(inner_path, child_attr), visited_trackables=visited_trackables, ) elif isinstance(child_obj, (list, dict, tuple, set)): _save_container_state( child_obj, weights_store, assets_store, inner_path=gfile.join(inner_path, child_attr), visited_trackables=visited_trackables, ) def _load_state( trackable, weights_store, assets_store, inner_path, skip_mismatch=False, visited_trackables=None, ): if visited_trackables and id(trackable) in visited_trackables: return if hasattr(trackable, "load_own_variables") and weights_store: if skip_mismatch: try: trackable.load_own_variables(weights_store.get(inner_path)) except Exception as e: warnings.warn( f"Could not load weights in object {trackable}. " "Skipping object. " f"Exception encountered: {e}", stacklevel=2, ) else: trackable.load_own_variables(weights_store.get(inner_path)) if hasattr(trackable, "load_assets") and assets_store: if skip_mismatch: try: trackable.load_assets(assets_store.get(inner_path)) except Exception as e: warnings.warn( f"Could not load assets in object {trackable}. " "Skipping object. " f"Exception encountered: {e}", stacklevel=2, ) else: trackable.load_assets(assets_store.get(inner_path)) if visited_trackables is not None: visited_trackables.add(id(trackable)) # Recursively load states for Keras trackables such as layers/optimizers. for child_attr, child_obj in _walk_trackable(trackable): if _is_keras_trackable(child_obj): _load_state( child_obj, weights_store, assets_store, inner_path=gfile.join(inner_path, child_attr), skip_mismatch=skip_mismatch, visited_trackables=visited_trackables, ) elif isinstance(child_obj, (list, dict, tuple, set)): _load_container_state( child_obj, weights_store, assets_store, inner_path=gfile.join(inner_path, child_attr), skip_mismatch=skip_mismatch, visited_trackables=visited_trackables, ) def _save_container_state( container, weights_store, assets_store, inner_path, visited_trackables ): used_names = {} if isinstance(container, dict): container = list(container.values()) for trackable in container: if _is_keras_trackable(trackable): # Do NOT address the trackable via `trackable.name`, since # names are usually autogenerated and thus not reproducible # (i.e. they may vary across two instances of the same model). name = naming.to_snake_case(trackable.__class__.__name__) if name in used_names: used_names[name] += 1 name = f"{name}_{used_names[name]}" else: used_names[name] = 0 _save_state( trackable, weights_store, assets_store, inner_path=gfile.join(inner_path, name), visited_trackables=visited_trackables, ) def _load_container_state( container, weights_store, assets_store, inner_path, skip_mismatch, visited_trackables, ): used_names = {} if isinstance(container, dict): container = list(container.values()) for trackable in container: if _is_keras_trackable(trackable): name = naming.to_snake_case(trackable.__class__.__name__) if name in used_names: used_names[name] += 1 name = f"{name}_{used_names[name]}" else: used_names[name] = 0 _load_state( trackable, weights_store, assets_store, inner_path=gfile.join(inner_path, name), skip_mismatch=skip_mismatch, visited_trackables=visited_trackables, ) class DiskIOStore: """Asset store backed by disk storage. If `archive` is specified, then `root_path` refers to the filename inside the archive. If `archive` is not specified, then `root_path` refers to the full path of the target directory. """ def __init__(self, root_path, archive=None, mode=None): self.mode = mode self.root_path = root_path self.archive = archive self.tmp_dir = None if self.archive: self.tmp_dir = get_temp_dir() if self.mode == "r": self.archive.extractall(path=self.tmp_dir) self.working_dir = gfile.join(self.tmp_dir, self.root_path) if self.mode == "w": gfile.makedirs(self.working_dir) else: if mode == "r": self.working_dir = root_path else: self.tmp_dir = get_temp_dir() self.working_dir = gfile.join(self.tmp_dir, self.root_path) gfile.makedirs(self.working_dir) def make(self, path): if not path: return self.working_dir path = gfile.join(self.working_dir, path) if not gfile.exists(path): gfile.makedirs(path) return path def get(self, path): if not path: return self.working_dir path = gfile.join(self.working_dir, path) if gfile.exists(path): return path return None def close(self): if self.mode == "w" and self.archive: _write_to_zip_recursively( self.archive, self.working_dir, self.root_path ) if self.tmp_dir and gfile.exists(self.tmp_dir): gfile.rmtree(self.tmp_dir) class H5IOStore: def __init__(self, root_path, archive=None, mode="r"): """Numerical variable store backed by HDF5. If `archive` is specified, then `root_path` refers to the filename inside the archive. If `archive` is not specified, then `root_path` refers to the path of the h5 file on disk. """ self.root_path = root_path self.mode = mode self.archive = archive self.io_file = None if self.archive: if self.mode == "w": self.io_file = io.BytesIO() else: self.io_file = self.archive.open(self.root_path, "r") self.h5_file = h5py.File(self.io_file, mode=self.mode) else: self.h5_file = h5py.File(root_path, mode=self.mode) def make(self, path): if not path: return self.h5_file.create_group("vars") return self.h5_file.create_group(path).create_group("vars") def get(self, path): if not path: return self.h5_file["vars"] if path in self.h5_file and "vars" in self.h5_file[path]: return self.h5_file[path]["vars"] return {} def close(self): self.h5_file.close() if self.mode == "w" and self.archive: self.archive.writestr(self.root_path, self.io_file.getvalue()) if self.io_file: self.io_file.close() class NpzIOStore: def __init__(self, root_path, archive=None, mode="r"): """Numerical variable store backed by NumPy.savez/load. If `archive` is specified, then `root_path` refers to the filename inside the archive. If `archive` is not specified, then `root_path` refers to the path of the npz file on disk. """ self.root_path = root_path self.mode = mode self.archive = archive if mode == "w": self.contents = {} else: if self.archive: self.f = archive.open(root_path, mode="r") else: self.f = open(root_path, mode="rb") self.contents = np.load(self.f, allow_pickle=True) def make(self, path): if not path: self.contents["__root__"] = {} return self.contents["__root__"] self.contents[path] = {} return self.contents[path] def get(self, path): if not path: if "__root__" in self.contents: return dict(self.contents["__root__"]) return {} if path in self.contents: return self.contents[path].tolist() return {} def close(self): if self.mode == "w": if self.archive: self.f = self.archive.open( self.root_path, mode="w", force_zip64=True ) else: self.f = open(self.root_path, mode="wb") np.savez(self.f, **self.contents) self.f.close() def get_temp_dir(): temp_dir = tempfile.mkdtemp() testfile = tempfile.TemporaryFile(dir=temp_dir) testfile.close() return temp_dir def _is_keras_trackable(obj): return isinstance( obj, ( Layer, Optimizer, Metric, Loss, ), )