Chen Qian 6376740aca Fix layer test in Torch backend (#360)
* init

* Fix dropout ops

* Small fixes

* fix norm layers

* fix tests

* more fixes

* clean up

* fix

* fix format

* fix

* fix comments

* fix comments

* remove redundant copy

* revert jax change to dodge the odd recursion issue
2023-06-15 18:45:51 -07:00

766 lines
27 KiB

"""Object config serialization and deserialization logic."""
import importlib
import inspect
import types
import warnings
import jax
import numpy as np
import tensorflow as tf
from keras_core import api_export
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.backend.common import global_state
from keras_core.saving import object_registration
from keras_core.utils import python_utils
# Python scalar types that are serialized and deserialized unchanged.
PLAIN_TYPES = (str, int, float, bool)

# List of Keras modules with built-in string representations for Keras defaults
# (searched as `keras_core.<module>.<name>` when resolving functions that were
# serialized with module "builtins").
BUILTIN_MODULES = (
    "activations",
    "constraints",
    "initializers",
    "losses",
    "metrics",
    "optimizers",
    "regularizers",
)
class SerializableDict:
    """Bundle keyword arguments into a single serializable config.

    The keyword arguments given to the constructor are stored, as a dict, in
    the `config` attribute.
    """

    def __init__(self, **config):
        self.config = config

    def serialize(self):
        """Return `serialize_keras_object()` applied to the stored config."""
        payload = self.config
        return serialize_keras_object(payload)
class SafeModeScope:
    """Scope to propagate safe mode flag to nested deserialization calls.

    On entry the global ``safe_mode_saving`` setting is set to `safe_mode`,
    and the value it had before is remembered. On exit that previous value
    is written back.
    """

    def __init__(self, safe_mode=True):
        self.safe_mode = safe_mode

    def __enter__(self):
        # Capture the current flag first so __exit__ can restore it.
        self.original_value = in_safe_mode()
        global_state.set_global_setting("safe_mode_saving", self.safe_mode)

    def __exit__(self, *args, **kwargs):
        restored = self.original_value
        global_state.set_global_setting("safe_mode_saving", restored)
def enable_unsafe_deserialization():
    """Disables safe mode globally, allowing deserialization of lambdas.

    Sets the global ``safe_mode_saving`` setting to `False`. While that
    setting is not `None`, `deserialize_keras_object()` uses it in place of
    its own `safe_mode` argument.
    """
    setting_name = "safe_mode_saving"
    global_state.set_global_setting(setting_name, False)
def in_safe_mode():
    """Return the current global ``safe_mode_saving`` setting.

    `deserialize_keras_object()` treats a `None` result as "no global
    override" and then uses its own `safe_mode` argument. Presumably `None`
    is what `global_state` returns when the setting was never set; verify
    against `global_state.get_global_setting`.
    """
    current_setting = global_state.get_global_setting("safe_mode_saving")
    return current_setting
class ObjectSharingScope:
    """Scope to enable detection and reuse of previously seen objects.

    On entry, two empty dicts are installed in `global_state`:
    ``"shared_objects/id_to_obj_map"`` (filled while deserializing) and
    ``"shared_objects/id_to_config_map"`` (filled while serializing). On exit
    both are reset to `None`. The object-sharing helpers treat `None` as
    "not in a sharing scope".
    """

    def __enter__(self):
        global_state.set_global_attribute("shared_objects/id_to_obj_map", {})
        global_state.set_global_attribute("shared_objects/id_to_config_map", {})

    def __exit__(self, *args, **kwargs):
        global_state.set_global_attribute("shared_objects/id_to_obj_map", None)
        # Both maps must be cleared. A stale config map would keep tagging
        # configs with `shared_object_id` after the scope has ended.
        global_state.set_global_attribute(
            "shared_objects/id_to_config_map", None
        )
def get_shared_object(obj_id):
    """Retrieve an object previously seen during deserialization.

    Args:
        obj_id: The `shared_object_id` value stored in a serialized config.

    Returns:
        The object recorded under `obj_id` by
        `record_object_after_deserialization()`, or `None` if nothing was
        recorded under that id or no `ObjectSharingScope` is active.
    """
    id_to_obj_map = global_state.get_global_attribute(
        "shared_objects/id_to_obj_map"
    )
    if id_to_obj_map is None:
        return None  # Not in a sharing scope
    return id_to_obj_map.get(obj_id, None)
def record_object_after_serialization(obj, config):
    """Call after serializing an object, to keep track of its config.

    Mutates `config` in place:

    - A `"module"` of `"__main__"` is replaced by `None`.
    - Inside an `ObjectSharingScope`, the first config produced for an object
      (keyed by `id(obj)`) is remembered. If the same object is serialized
      again, both the new config and the remembered one get a
      `"shared_object_id"` entry, so that deserialization can reuse a single
      instance.

    Args:
        obj: The object that was just serialized.
        config: The config dict produced for `obj`; must have a `"module"`
            key.
    """
    if config["module"] == "__main__":
        config["module"] = None  # Ensures module is None when no module found
    id_to_config_map = global_state.get_global_attribute(
        "shared_objects/id_to_config_map"
    )
    if id_to_config_map is None:
        return  # Not in a sharing scope
    obj_id = int(id(obj))
    if obj_id not in id_to_config_map:
        # First occurrence: just remember it; it is not (yet) shared.
        id_to_config_map[obj_id] = config
    else:
        # Repeated occurrence: mark both configs as referring to one object.
        config["shared_object_id"] = obj_id
        prev_config = id_to_config_map[obj_id]
        prev_config["shared_object_id"] = obj_id
def record_object_after_deserialization(obj, obj_id):
    """Call after deserializing an object, to keep track of it in the future.

    Inside an `ObjectSharingScope`, stores `obj` under `obj_id` so that a later
    `get_shared_object(obj_id)` returns the same instance. Does nothing
    outside a sharing scope.

    Args:
        obj: The freshly deserialized object.
        obj_id: The `shared_object_id` read from its config.
    """
    id_to_obj_map = global_state.get_global_attribute(
        "shared_objects/id_to_obj_map"
    )
    if id_to_obj_map is None:
        return  # Not in a sharing scope
    id_to_obj_map[obj_id] = obj
def serialize_keras_object(obj):
    """Retrieve the config dict by serializing the Keras object.

    `serialize_keras_object()` serializes a Keras object to a python dictionary
    that represents the object, and is a reciprocal function of
    `deserialize_keras_object()`. See `deserialize_keras_object()` for more
    information about the config format.

    Args:
        obj: the Keras object to serialize.

    Returns:
        A python dict that represents the object. The python dict can be
        deserialized via `deserialize_keras_object()`. `None` and plain
        scalars (`str`, `int`, `float`, `bool`) are returned unchanged.
        Lists, tuples and dicts are serialized element-wise, keeping their
        container type.

    Raises:
        TypeError: If `obj` is a class instance without a usable
            `get_config()` (raised by `_get_class_or_fn_config()`).
    """
    if obj is None:
        return obj

    if isinstance(obj, PLAIN_TYPES):
        return obj

    if isinstance(obj, (list, tuple)):
        config_arr = [serialize_keras_object(x) for x in obj]
        return tuple(config_arr) if isinstance(obj, tuple) else config_arr
    if isinstance(obj, dict):
        return serialize_dict(obj)

    # Special cases:
    if isinstance(obj, bytes):
        return {
            "class_name": "__bytes__",
            "config": {"value": obj.decode("utf-8")},
        }
    if isinstance(obj, backend.KerasTensor):
        history = getattr(obj, "_keras_history", None)
        if history:
            # Replace the first history entry by its `.name` so the history
            # holds only serializable values.
            history = list(history)
            history[0] = history[0].name
        return {
            "class_name": "__keras_tensor__",
            "config": {
                "shape": obj.shape,
                "dtype": obj.dtype,
                "keras_history": history,
            },
        }
    if isinstance(obj, tf.TensorShape):
        return obj.as_list() if obj._dims is not None else None
    # NumPy values are handled before the generic tensor check below.
    # NumPy 2.x arrays expose a `device` attribute (Array API), which would
    # otherwise route them to "__tensor__" instead of "__numpy__".
    if type(obj).__module__ == np.__name__:
        if isinstance(obj, np.ndarray) and obj.ndim > 0:
            return {
                "class_name": "__numpy__",
                "config": {
                    "value": obj.tolist(),
                    "dtype": backend.standardize_dtype(obj.dtype),
                },
            }
        # Treat numpy floats / etc as plain types.
        return obj.item()
    if isinstance(obj, (tf.Tensor, jax.numpy.ndarray)) or hasattr(
        obj, "device"
    ):
        # Import torch creates circular dependency, so we use
        # `hasattr(obj, "device")` to check if obj is a torch tensor.
        return {
            "class_name": "__tensor__",
            "config": {
                "value": backend.convert_to_numpy(obj).tolist(),
                "dtype": backend.standardize_dtype(obj.dtype),
            },
        }
    if isinstance(obj, tf.DType):
        return obj.name
    if isinstance(obj, types.FunctionType) and obj.__name__ == "<lambda>":
        try:
            lambda_source = inspect.getsource(obj)
        except (OSError, TypeError):
            # Source is unavailable, e.g. for lambdas defined in a REPL or
            # via `exec`. Fall back to the repr so that building the warning
            # message does not abort serialization.
            lambda_source = repr(obj)
        warnings.warn(
            "The object being serialized includes a `lambda`. This is unsafe. "
            "In order to reload the object, you will have to pass "
            "`safe_mode=False` to the loading function. "
            "Please avoid using `lambda` in the "
            "future, and use named Python functions instead. "
            f"This is the `lambda` being serialized: {lambda_source}",
            stacklevel=2,
        )
        return {
            "class_name": "__lambda__",
            "config": {
                "value": python_utils.func_dump(obj),
            },
        }
    if isinstance(obj, tf.TypeSpec):
        ts_config = obj._serialize()
        # TensorShape and tf.DType conversion
        ts_config = [
            x.as_list()
            if isinstance(x, tf.TensorShape)
            else (x.name if isinstance(x, tf.DType) else x)
            for x in ts_config
        ]
        return {
            "class_name": "__typespec__",
            "spec_name": obj.__class__.__name__,
            "module": obj.__class__.__module__,
            "config": ts_config,
            "registered_name": None,
        }

    inner_config = _get_class_or_fn_config(obj)
    config_with_public_class = serialize_with_public_class(
        obj.__class__, inner_config
    )

    if config_with_public_class is not None:
        get_build_and_compile_config(obj, config_with_public_class)
        record_object_after_serialization(obj, config_with_public_class)
        return config_with_public_class

    # Any custom object or otherwise non-exported object
    if isinstance(obj, types.FunctionType):
        module = obj.__module__
    else:
        module = obj.__class__.__module__
    class_name = obj.__class__.__name__

    if module == "builtins":
        registered_name = None
    elif isinstance(obj, types.FunctionType):
        registered_name = object_registration.get_registered_name(obj)
    else:
        registered_name = object_registration.get_registered_name(
            obj.__class__
        )

    config = {
        "module": module,
        "class_name": class_name,
        "config": inner_config,
        "registered_name": registered_name,
    }
    get_build_and_compile_config(obj, config)
    record_object_after_serialization(obj, config)
    return config
def get_build_and_compile_config(obj, config):
    """Attach `obj`'s build and compile configs to `config`, in place.

    For each of `get_build_config()` and `get_compile_config()` that `obj`
    defines and that returns something other than `None`, the result is
    passed through `serialize_dict()`. It is stored under `"build_config"` or
    `"compile_config"` respectively. Returns `None`.
    """
    for getter_name, config_key in (
        ("get_build_config", "build_config"),
        ("get_compile_config", "compile_config"),
    ):
        if not hasattr(obj, getter_name):
            continue
        sub_config = getattr(obj, getter_name)()
        if sub_config is not None:
            config[config_key] = serialize_dict(sub_config)
def serialize_with_public_class(cls, inner_config=None):
    """Serializes classes from public Keras API or object registration.

    Called to check and retrieve the config of any class that has a public
    Keras API or has been registered as serializable via
    `keras_core.saving.register_keras_serializable()`.

    Args:
        cls: The class to look up.
        inner_config: The already-serialized instance config, stored under
            the `"config"` key of the result.

    Returns:
        A dict with `module`, `class_name`, `config` and `registered_name`
        keys. Returns `None` if `cls` is neither exported under a Keras API
        name nor registered.
    """
    # Exported `keras_core.*` name, such as "keras_core.optimizers.Adam".
    keras_api_name = api_export.get_name_from_symbol(cls)

    if keras_api_name is not None:
        # Public class: everything before the last dot is the module.
        module_name, _, public_class_name = keras_api_name.rpartition(".")
        return {
            "module": module_name,
            "class_name": public_class_name,
            "config": inner_config,
            "registered_name": None,
        }

    # Custom or unknown class: serializable only if it was registered.
    registered_name = object_registration.get_registered_name(cls)
    if registered_name is None:
        return None
    return {
        "module": cls.__module__,
        "class_name": cls.__name__,
        "config": inner_config,
        "registered_name": registered_name,
    }
def serialize_with_public_fn(fn, config, fn_module_name=None):
    """Serializes functions from public Keras API or object registration.

    Called to check and retrieve the config of any function that has a public
    Keras API or has been registered as serializable via
    `keras_core.saving.register_keras_serializable()`. If function's module name
    is already known, returns corresponding config.

    Args:
        fn: The function to serialize.
        config: The function's identifying string. It is also used as the
            `registered_name` for known-module and public Keras functions.
        fn_module_name: Optional module name. When truthy, it is used
            directly and no lookup is performed.

    Returns:
        A dict with `module`, `class_name` (always `"function"`), `config` and
        `registered_name` keys. Returns `None` if `fn` is neither public, nor
        registered, nor from the `builtins` module.
    """

    def _function_config(module_name, registered_name):
        return {
            "module": module_name,
            "class_name": "function",
            "config": config,
            "registered_name": registered_name,
        }

    if fn_module_name:
        return _function_config(fn_module_name, config)

    keras_api_name = api_export.get_name_from_symbol(fn)
    if keras_api_name:
        return _function_config(keras_api_name.rpartition(".")[0], config)

    registered_name = object_registration.get_registered_name(fn)
    if not registered_name and fn.__module__ != "builtins":
        return None
    return _function_config(fn.__module__, registered_name)
def _get_class_or_fn_config(obj):
"""Return the object's config depending on its type."""
# Functions / lambdas:
if isinstance(obj, types.FunctionType):
return obj.__name__
# All classes:
if hasattr(obj, "get_config"):
config = obj.get_config()
if not isinstance(config, dict):
raise TypeError(
f"The `get_config()` method of {obj} should return "
f"a dict. It returned: {config}"
return serialize_dict(config)
elif hasattr(obj, "__name__"):
return object_registration.get_registered_name(obj)
raise TypeError(
f"Cannot serialize object {obj} of type {type(obj)}. "
"To be serializable, "
"a class must implement the `get_config()` method."
def serialize_dict(obj):
    """Return a new dict with every value passed to `serialize_keras_object`.

    Keys are kept as-is.
    """
    serialized = {}
    for key, value in obj.items():
        serialized[key] = serialize_keras_object(value)
    return serialized
def deserialize_keras_object(
    config, custom_objects=None, safe_mode=True, **kwargs
):
    """Retrieve the object by deserializing the config dict.

    The config dict is a Python dictionary that consists of a set of key-value
    pairs, and represents a Keras object, such as an `Optimizer`, `Layer`,
    `Metrics`, etc. The saving and loading library uses the following keys to
    record information of a Keras object:

    - `class_name`: String. This is the name of the class,
      as exactly defined in the source
      code, such as "LossesContainer".
    - `config`: Dict. Library-defined or user-defined key-value pairs that store
      the configuration of the object, as obtained by `object.get_config()`.
    - `module`: String. The path of the python module. Built-in Keras classes
      expect to have prefix `keras_core`.
    - `registered_name`: String. The key the class is registered under via
      `keras_core.saving.register_keras_serializable(package, name)` API. The
      key has the format of '{package}>{name}', where `package` and `name` are
      the arguments passed to `register_keras_serializable()`. If `name` is not
      provided, it uses the class name. If `registered_name` successfully
      resolves to a class (that was registered), `class_name` is not used to
      look the class up; `config` is still passed to its `from_config()`.
      `registered_name` is only used for non-built-in classes.

    For example, the following dictionary represents the built-in Adam optimizer
    with the relevant config:

    ```python
    dict_structure = {
        "class_name": "Adam",
        "config": {
            "amsgrad": False,
            "beta_1": 0.8999999761581421,
            "beta_2": 0.9990000128746033,
            "decay": 0.0,
            "epsilon": 1e-07,
            "learning_rate": 0.0010000000474974513,
            "name": "Adam"
        },
        "module": "keras_core.optimizers",
        "registered_name": None
    }
    # Returns an `Adam` instance identical to the original one.
    deserialize_keras_object(dict_structure)
    ```

    If the class does not have an exported Keras namespace, the library tracks
    it by its `module` and `class_name`. For example:

    ```python
    dict_structure = {
        "class_name": "MetricsList",
        "config": {
            ...
        },
        "module": "keras_core.trainers.compile_utils",
        "registered_name": "MetricsList"
    }

    # Returns a `MetricsList` instance identical to the original one.
    deserialize_keras_object(dict_structure)
    ```

    And the following dictionary represents a user-customized `MeanSquaredError`
    loss:

    ```python
    @keras_core.saving.register_keras_serializable(package="my_package")
    class ModifiedMeanSquaredError(keras_core.losses.MeanSquaredError):
        ...

    dict_structure = {
        "class_name": "ModifiedMeanSquaredError",
        "config": {
            "fn": "mean_squared_error",
            "name": "mean_squared_error",
            "reduction": "auto"
        },
        "registered_name": "my_package>ModifiedMeanSquaredError"
    }
    # Returns the `ModifiedMeanSquaredError` object
    deserialize_keras_object(dict_structure)
    ```

    Args:
        config: Python dict describing the object.
        custom_objects: Python dict containing a mapping between custom
            object names and the corresponding classes or functions.
        safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.
            When `safe_mode=False`, loading an object has the potential to
            trigger arbitrary code execution. This argument is only
            applicable to the Keras v3 model format. Defaults to `True`.
            A global setting (see `SafeModeScope`) takes precedence when set.
        **kwargs: Only `module_objects` is read. It is an optional dict of
            names to classes/functions that string or dict configs are
            resolved against first.

    Returns:
        The object described by the `config` dictionary.

    Raises:
        ValueError: If a `lambda` is requested while in safe mode, or a dict
            config without `class_name` is given together with
            `module_objects`.
        TypeError: If the config cannot be parsed or the class cannot be
            located or reconstructed.
    """
    safe_scope_arg = in_safe_mode()  # Enforces SafeModeScope
    safe_mode = safe_scope_arg if safe_scope_arg is not None else safe_mode

    module_objects = kwargs.pop("module_objects", None)
    custom_objects = custom_objects or {}
    tlco = global_state.get_global_attribute("custom_objects_scope_dict", {})
    gco = object_registration.GLOBAL_CUSTOM_OBJECTS
    custom_objects = {**custom_objects, **tlco, **gco}

    if config is None:
        return None

    if (
        isinstance(config, str)
        and custom_objects
        and custom_objects.get(config) is not None
    ):
        # This is to deserialize plain functions which are serialized as
        # string names by legacy saving formats.
        return custom_objects[config]

    if isinstance(config, (list, tuple)):
        return [
            deserialize_keras_object(
                x, custom_objects=custom_objects, safe_mode=safe_mode
            )
            for x in config
        ]

    if module_objects is not None:
        inner_config, fn_module_name, has_custom_object = None, None, False

        if isinstance(config, dict):
            if "config" in config:
                inner_config = config["config"]
            if "class_name" not in config:
                raise ValueError(
                    f"Unknown `config` as a `dict`, config={config}"
                )

            # Check case where config is function or class and in custom objects
            if custom_objects and (
                config["class_name"] in custom_objects
                or config.get("registered_name") in custom_objects
                or (
                    isinstance(inner_config, str)
                    and inner_config in custom_objects
                )
            ):
                has_custom_object = True

            # Case where config is function but not in custom objects
            elif config["class_name"] == "function":
                fn_module_name = config["module"]
                if fn_module_name == "builtins":
                    config = config["config"]
                else:
                    config = config["registered_name"]

            # Case where config is class but not in custom objects
            else:
                if config.get("module", "_") is None:
                    raise TypeError(
                        "Cannot deserialize object of type "
                        f"`{config['class_name']}`. If "
                        f"`{config['class_name']}` is a custom class, please "
                        "register it using the "
                        "`@keras_core.saving.register_keras_serializable()` "
                        "decorator."
                    )
                config = config["class_name"]

        if not has_custom_object:
            # Return if not found in either module objects or custom objects
            if config not in module_objects:
                # Object has already been deserialized
                return config
            # Propagate `safe_mode` like the other recursive calls do.
            if isinstance(module_objects[config], types.FunctionType):
                return deserialize_keras_object(
                    serialize_with_public_fn(
                        module_objects[config], config, fn_module_name
                    ),
                    custom_objects=custom_objects,
                    safe_mode=safe_mode,
                )
            return deserialize_keras_object(
                serialize_with_public_class(
                    module_objects[config], inner_config=inner_config
                ),
                custom_objects=custom_objects,
                safe_mode=safe_mode,
            )

    if isinstance(config, PLAIN_TYPES):
        return config
    if not isinstance(config, dict):
        raise TypeError(f"Could not parse config: {config}")

    if "class_name" not in config or "config" not in config:
        # A plain dict (not an object config): deserialize its values.
        return {
            key: deserialize_keras_object(
                value, custom_objects=custom_objects, safe_mode=safe_mode
            )
            for key, value in config.items()
        }

    class_name = config["class_name"]
    inner_config = config["config"] or {}

    # Special cases:
    if class_name == "__keras_tensor__":
        obj = backend.KerasTensor(
            inner_config["shape"], dtype=inner_config["dtype"]
        )
        obj._pre_serialization_keras_history = inner_config["keras_history"]
        return obj

    if class_name == "__tensor__":
        return backend.convert_to_tensor(
            inner_config["value"], dtype=inner_config["dtype"]
        )
    if class_name == "__numpy__":
        return np.array(inner_config["value"], dtype=inner_config["dtype"])
    if class_name == "__bytes__":
        return inner_config["value"].encode("utf-8")
    if class_name == "__lambda__":
        if safe_mode:
            raise ValueError(
                "Requested the deserialization of a `lambda` object. "
                "This carries a potential risk of arbitrary code execution "
                "and thus it is disallowed by default. If you trust the "
                "source of the saved model, you can pass `safe_mode=False` to "
                "the loading function in order to allow `lambda` loading, "
                "or call `keras_core.config.enable_unsafe_deserialization()`."
            )
        return python_utils.func_load(inner_config["value"])
    if class_name == "__typespec__":
        obj = _retrieve_class_or_fn(
            config["spec_name"],
            config["registered_name"],
            config["module"],
            obj_type="class",
            full_config=config,
            custom_objects=custom_objects,
        )
        # Conversion to TensorShape and tf.DType
        spec_args = tuple(
            tf.TensorShape(x)
            if isinstance(x, list)
            else (getattr(tf, x) if hasattr(tf.dtypes, str(x)) else x)
            for x in inner_config
        )
        return obj._deserialize(spec_args)

    # Below: classes and functions.
    module = config.get("module", None)
    registered_name = config.get("registered_name", class_name)

    if class_name == "function":
        fn_name = inner_config
        return _retrieve_class_or_fn(
            fn_name,
            registered_name,
            module,
            obj_type="function",
            full_config=config,
            custom_objects=custom_objects,
        )

    # Below, handling of all classes.
    # First, is it a shared object?
    if "shared_object_id" in config:
        obj = get_shared_object(config["shared_object_id"])
        if obj is not None:
            return obj

    cls = _retrieve_class_or_fn(
        class_name,
        registered_name,
        module,
        obj_type="class",
        full_config=config,
        custom_objects=custom_objects,
    )

    if isinstance(cls, types.FunctionType):
        return cls
    if not hasattr(cls, "from_config"):
        raise TypeError(
            f"Unable to reconstruct an instance of '{class_name}' because "
            f"the class is missing a `from_config()` method. "
            f"Full object config: {config}"
        )

    # Instantiate the class from its config inside a custom object scope
    # so that we can catch any custom objects that the config refers to.
    custom_obj_scope = object_registration.CustomObjectScope(custom_objects)
    safe_mode_scope = SafeModeScope(safe_mode)
    with custom_obj_scope, safe_mode_scope:
        instance = cls.from_config(inner_config)
        build_config = config.get("build_config", None)
        if build_config and not instance.built:
            instance.build_from_config(build_config)
            instance.built = True
        compile_config = config.get("compile_config", None)
        if compile_config:
            instance.compile_from_config(compile_config)
            instance.compiled = True

    if "shared_object_id" in config:
        record_object_after_deserialization(
            instance, config["shared_object_id"]
        )
    return instance
def _retrieve_class_or_fn(
    name, registered_name, module, obj_type, full_config, custom_objects=None
):
    """Resolve a class or function from its serialized identifiers.

    Lookup order:

    1. `object_registration.get_registered_object()`, given `custom_objects`
       (looked up by `name` for functions, by `registered_name` otherwise).
    2. For `keras_core` modules, the exported symbol `module + "." + name`.
    3. For functions whose module is `"builtins"`, `keras_core.<mod>.<name>`
       for each `mod` in `BUILTIN_MODULES`. Then the first `custom_objects`
       entry whose key ends with `full_config["config"]`.
    4. Importing `module` and reading `name` (or, failing that,
       `registered_name`) from it.

    Args:
        name: Class or function name.
        registered_name: Registration key, if any.
        module: Module path recorded in the config, or `None`.
        obj_type: `"function"` or `"class"`; used for lookup and messages.
        full_config: The full config dict, used in error messages.
        custom_objects: Optional dict of custom names to objects.

    Returns:
        The resolved class or function.

    Raises:
        TypeError: If `module` cannot be imported or nothing is found.
    """
    # If there is a custom object registered via
    # `register_keras_serializable()`, that takes precedence.
    if obj_type == "function":
        custom_obj = object_registration.get_registered_object(
            name, custom_objects=custom_objects
        )
    else:
        custom_obj = object_registration.get_registered_object(
            registered_name, custom_objects=custom_objects
        )
    if custom_obj is not None:
        return custom_obj

    if module:
        # If it's a Keras built-in object,
        # we cannot always use direct import, because the exported
        # module name might not match the package structure
        # (e.g. experimental symbols).
        if module == "keras_core" or module.startswith("keras_core."):
            api_name = module + "." + name
            obj = api_export.get_symbol_from_name(api_name)
            if obj is not None:
                return obj

        # Configs of Keras built-in functions do not contain identifying
        # information other than their name (e.g. 'acc' or 'tanh'). This special
        # case searches the Keras modules that contain built-ins to retrieve
        # the corresponding function from the identifying string.
        if obj_type == "function" and module == "builtins":
            for mod in BUILTIN_MODULES:
                obj = api_export.get_symbol_from_name(
                    "keras_core." + mod + "." + name
                )
                if obj is not None:
                    return obj

            # Retrieval of registered custom function in a package.
            # `custom_objects` defaults to None, so guard before `.items()`.
            filtered_dict = {
                k: v
                for k, v in (custom_objects or {}).items()
                if k.endswith(full_config["config"])
            }
            if filtered_dict:
                return next(iter(filtered_dict.values()))

        # Otherwise, attempt to retrieve the class object given the `module`
        # and `class_name`. Import the module, find the class.
        try:
            mod = importlib.import_module(module)
        except ModuleNotFoundError as e:
            raise TypeError(
                f"Could not deserialize {obj_type} '{name}' because "
                f"its parent module {module} cannot be imported. "
                f"Full object config: {full_config}"
            ) from e
        obj = vars(mod).get(name, None)

        # Special case for keras.metrics.metrics
        if obj is None and registered_name is not None:
            obj = vars(mod).get(registered_name, None)

        if obj is not None:
            return obj

    raise TypeError(
        f"Could not locate {obj_type} '{name}'. "
        "Make sure custom classes are decorated with "
        "`@keras_core.saving.register_keras_serializable()`. "
        f"Full object config: {full_config}"
    )