keras/keras_core/saving/saving_api.py
2023-05-08 11:34:55 -07:00

131 lines
4.5 KiB
Python

import os
import zipfile
from tensorflow.io import gfile
from keras_core.api_export import keras_core_export
from keras_core.saving import saving_lib
from keras_core.saving.legacy import legacy_h5_format
try:
import h5py
except ImportError:
h5py = None
@keras_core_export(
    ["keras_core.saving.load_model", "keras_core.models.load_model"]
)
def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
    """Loads a model saved via `model.save()`.

    Only `.keras` zip archives are accepted. A remote, non-directory path
    is first copied into a local temporary directory and loaded from there
    if the copy turns out to be a zip file.

    Args:
        filepath: `str` or `pathlib.Path` object, path to the saved model
            file.
        custom_objects: Optional dictionary mapping names (strings) to
            custom classes or functions to be considered during
            deserialization.
        compile: Boolean, whether to compile the model after loading.
        safe_mode: Boolean, whether to disallow unsafe `lambda`
            deserialization. When `safe_mode=False`, loading an object has
            the potential to trigger arbitrary code execution. This
            argument is only applicable to the Keras v3 model format.
            Defaults to True.

    Returns:
        A Keras model instance. If the original model was compiled,
        and the argument `compile=True` is set, then the returned model
        will be compiled. Otherwise, the model will be left uncompiled.

    Raises:
        ValueError: If `filepath` is not recognized as a `.keras` zip
            archive.

    Example:

    ```python
    model = keras_core.Sequential([
        keras_core.layers.Dense(5, input_shape=(3,)),
        keras_core.layers.Softmax()])
    model.save("model.keras")
    loaded_model = keras_core.saving.load_model("model.keras")
    x = np.random.random((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that the model variables may have different name values
    (`var.name` property, e.g. `"dense_1/kernel:0"`) after being reloaded.
    It is recommended that you use layer attributes to
    access specific variables, e.g. `model.get_layer("dense_1").kernel`.
    """
    from keras_core.saving import saving_lib

    # A candidate must carry the `.keras` suffix AND be a readable zip.
    archive_ok = False
    if str(filepath).endswith(".keras"):
        archive_ok = zipfile.is_zipfile(filepath)

    # Support for remote zip files: stage them locally before loading.
    if saving_lib.is_remote_path(filepath):
        if not (gfile.isdir(filepath) or archive_ok):
            staging_dir = saving_lib.get_temp_dir()
            staged_copy = os.path.join(
                staging_dir, os.path.basename(filepath)
            )
            # Copy from remote to temporary local directory.
            gfile.copy(filepath, staged_copy, overwrite=True)
            # Load from the local copy only if it really is a zip archive.
            if zipfile.is_zipfile(staged_copy):
                filepath, archive_ok = staged_copy, True

    if not archive_ok:
        raise ValueError(
            "The file you are trying to load does not appear to "
            "be a valid `.keras` model file. If it is a legacy "
            "`.h5` or SavedModel file, do note that these formats "
            "are not supported in Keras Core."
        )

    return saving_lib.load_model(
        filepath,
        custom_objects=custom_objects,
        compile=compile,
        safe_mode=safe_mode,
    )
def load_weights(model, filepath, skip_mismatch=False, **kwargs):
    """Loads weights into `model` from `filepath`, dispatching on extension.

    Supported extensions:

    - `.keras` and `.weights.h5`: handled by
      `saving_lib.load_weights_only`.
    - `.h5` and `.hdf5` (legacy): read with `h5py` and handled by
      `legacy_h5_format`.

    Args:
        model: The model to load weights into; passed through unchanged to
            the underlying loader.
        filepath: `str` or path-like object. Only its string form is
            inspected to pick the format.
        skip_mismatch: Forwarded to `saving_lib.load_weights_only` for the
            V3 formats and to the by-name loader for legacy files. It is
            NOT forwarded when loading a legacy file with `by_name=False`.
        **kwargs: Only `by_name` (default `False`) is accepted, and only
            for legacy `.h5`/`.hdf5` files.

    Raises:
        ValueError: If unexpected keyword arguments are passed, or if the
            file extension is not one of the supported formats.
        ImportError: If a legacy H5 file is requested but `h5py` is not
            installed.
    """
    path = str(filepath)

    # NOTE: `.weights.h5` has to be matched before the generic `.h5`
    # suffix handled by the legacy branch below.
    if path.endswith((".keras", ".weights.h5")):
        if kwargs:
            raise ValueError(f"Invalid keyword arguments: {kwargs}")
        saving_lib.load_weights_only(
            model, filepath, skip_mismatch=skip_mismatch
        )
    elif path.endswith((".h5", ".hdf5")):
        by_name = kwargs.pop("by_name", False)
        # Anything left over after removing `by_name` is unsupported.
        if kwargs:
            raise ValueError(f"Invalid keyword arguments: {kwargs}")
        if not h5py:
            raise ImportError(
                "Loading a H5 file requires `h5py` to be installed."
            )
        with h5py.File(filepath, "r") as h5_file:
            group = h5_file
            # If the root has no `layer_names` attribute but contains a
            # `model_weights` entry, read the weights from that entry.
            if "layer_names" not in h5_file.attrs and (
                "model_weights" in h5_file
            ):
                group = h5_file["model_weights"]
            if by_name:
                legacy_h5_format.load_weights_from_hdf5_group_by_name(
                    group, model, skip_mismatch
                )
            else:
                legacy_h5_format.load_weights_from_hdf5_group(group, model)
    else:
        raise ValueError(
            f"File format not supported: filepath={filepath}. "
            "Keras Core only supports V3 `.keras` and `.weights.h5` "
            "files, or legacy V1/V2 `.h5`/`.hdf5` files."
        )