Some progress on saving.

Francois Chollet 2023-04-25 15:03:16 -07:00
parent 6034134d95
commit b9753e1a46
5 changed files with 44 additions and 11 deletions

@@ -30,7 +30,7 @@ class KerasVariable:
     def __repr__(self):
         return (
             f"<KerasVariable shape={self.shape}, dtype={self.dtype}, "
-            "name={self.name}>"
+            f"name={self.name}>"
         )

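The change above fixes a missing f prefix on the second string literal: without it, Python keeps {self.name} as literal text instead of interpolating the variable's name. A self-contained illustration of the difference:

    name = "dense_1"
    print("name={name}>")   # no f prefix: prints the literal text name={name}>
    print(f"name={name}>")  # f-string: prints name=dense_1>
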
@@ -59,6 +59,13 @@ class Variable(KerasVariable):
     def __init__(self, value, dtype=None, trainable=True, name=None):
         self.name = name or auto_name(self.__class__.__name__)
         dtype = standardize_dtype(dtype)
+        if in_stateless_scope():
+            raise ValueError(
+                "You are attempting to create a variable "
+                "while in a stateless scope. This is disallowed. "
+                "Make sure that all variables are created "
+                "before you start using your layer/model objects. "
+            )
         self._value = jnp.array(value, dtype=dtype)
         self._dtype = dtype
         self._shape = tuple(self._value.shape)
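
To illustrate the new guard: variables may only be created while state creation is allowed, so eager construction works but instantiating a Variable inside a stateless scope raises the error above. A rough sketch, assuming the scope that in_stateless_scope() reports on is exposed as a StatelessScope context manager and that both names are importable from the backend module (the import path and context-manager name are assumptions, not part of this diff):

    from keras_core import backend   # assumed import path

    v = backend.Variable(1.0)        # fine: created outside any stateless scope

    with backend.StatelessScope():   # assumed context manager behind in_stateless_scope()
        try:
            backend.Variable(2.0)    # disallowed: no new state during a stateless trace
        except ValueError as err:
            print(err)
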
@@ -299,3 +306,18 @@ def execute(op_name, *args, **kwargs):
         op = getattr(jnp, op_name)
         return op(*args, **kwargs)
     raise AttributeError(f"The JAX backend does not support op '{op_name}'")
+
+
+def traceable_tensor(shape, dtype=None):
+    """Create a "traceable tensor".
+
+    That's a tensor that can be passed as input
+    to a stateful backend-native function to
+    create state during the trace.
+    """
+    shape = list(shape)
+    dtype = dtype or "float32"
+    for i, x in enumerate(shape):
+        if x is None:
+            shape[i] = 1
+    return jnp.ones(shape, dtype=dtype)
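
For reference, the helper replaces unknown (None) dimensions with 1 so that a concrete array of ones can be fed through a model to trigger state creation during tracing. A minimal usage sketch, assuming traceable_tensor has been imported from the JAX backend module above:

    x = traceable_tensor((None, 32))   # batch dimension unknown in the saved config
    print(x.shape, x.dtype)            # (1, 32) float32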

@@ -274,3 +274,18 @@ def execute(op_name, *args, **kwargs):
     raise AttributeError(
         f"The TensorFlow backend does not support op '{op_name}'"
     )
+
+
+def traceable_tensor(shape, dtype=None):
+    """Create a "traceable tensor".
+
+    That's a tensor that can be passed as input
+    to a stateful backend-native function to
+    create state during the trace.
+    """
+    shape = list(shape)
+    dtype = dtype or "float32"
+    for i, x in enumerate(shape):
+        if x is None:
+            shape[i] = 1
+    return tf.ones(shape, dtype=dtype)
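
The Model changes below reach this helper through the backend namespace, so the same call returns a jax.numpy array or a tf.Tensor depending on the active backend. A rough sketch, assuming the backend is selected with the KERAS_BACKEND environment variable and that the helper is re-exported as backend.traceable_tensor (both assumptions; only the calls in the Model hunks below confirm the attribute name):

    import os
    os.environ["KERAS_BACKEND"] = "tensorflow"   # or "jax"; assumed selection mechanism

    from keras_core import backend               # assumed import path

    x = backend.traceable_tensor((None, 28, 28, 3))
    print(type(x), x.shape)                      # backend-native tensor of shape (1, 28, 28, 3)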

@@ -221,19 +221,19 @@ class Model(Trainer, Layer):
                 input_shape = tuple(input_shape)
             if isinstance(input_shape, list):
                 input_tensors = [
-                    backend.KerasTensor(shape) for shape in input_shape
+                    backend.traceable_tensor(shape) for shape in input_shape
                 ]
             elif isinstance(input_shape, dict):
                 input_tensors = {
-                    k: backend.KerasTensor(shape)
+                    k: backend.traceable_tensor(shape)
                     for k, shape in input_shape.items()
                 }
             else:
-                input_tensors = backend.KerasTensor(input_shape)
+                input_tensors = backend.traceable_tensor(input_shape)
             try:
                 self(input_tensors)
                 self._build_shapes_dict = config
-            except:
+            except Exception as e:
                 failure = True
         elif "shapes_dict" in config:
             # Case: inputs were recorded as multiple keyword arguments.
@@ -242,7 +242,7 @@ class Model(Trainer, Layer):
             ):
                 # Case: all input keyword arguments were plain tensors.
                 input_tensors = {
-                    k: backend.KerasTensor(v)
+                    k: backend.traceable_tensor(v)
                     for k, v in config["shapes_dict"].items()
                 }
                 try:

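Switching from symbolic KerasTensor placeholders to traceable tensors means restoring built state actually executes the model: build_from_config materializes concrete ones-tensors from the recorded shapes and calls the model on them so every layer creates its weights. A hypothetical end-to-end sketch; the subclassed model and the import paths are illustrative assumptions:

    from keras_core import layers
    from keras_core.models import Model   # assumed import paths

    class TinyModel(Model):
        def __init__(self):
            super().__init__()
            self.dense = layers.Dense(4)

        def call(self, x):
            return self.dense(x)

    model = TinyModel()
    build_config = {"input_shape": (None, 16)}   # e.g. the shape recorded at save time
    model.build_from_config(build_config)        # calls the model on a (1, 16) ones tensor, creating its weights
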
@@ -42,7 +42,6 @@ def deserialize(config, custom_objects=None):
     if config["class_name"].lower() in ALL_OBJECTS_DICT:
         config["class_name"] = config["class_name"].lower()
-    print("deserialize:", config)
     return serialization_lib.deserialize_keras_object(
         config,
         module_objects=ALL_OBJECTS_DICT,
@@ -63,16 +62,13 @@ def get(identifier):
     Returns:
         A Keras Optimizer instance.
     """
-    print("call get with", identifier)
     if isinstance(identifier, Optimizer):
         return identifier
     elif isinstance(identifier, dict):
         return deserialize(identifier)
     elif isinstance(identifier, str):
         config = {"class_name": identifier, "config": {}}
-        opt = deserialize(config)
-        print(opt)
-        return opt
+        return deserialize(config)
     else:
         raise ValueError(
             f"Could not interpret optimizer identifier: {identifier}"
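
With the debug prints removed, get() still resolves the three accepted identifier forms the same way. Roughly, assuming the module is exposed as keras_core.optimizers and Adam is among the registered optimizers:

    from keras_core import optimizers   # assumed import path

    optimizers.get("adam")                                 # string: lowercased and looked up in ALL_OBJECTS_DICT
    optimizers.get({"class_name": "adam", "config": {}})   # dict: passed through deserialize()
    optimizers.get(optimizers.Adam(learning_rate=1e-3))    # Optimizer instance: returned unchanged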