Some progress on saving.
This commit is contained in:
parent
6034134d95
commit
b9753e1a46
@ -30,7 +30,7 @@ class KerasVariable:
|
||||
def __repr__(self):
    """Return a debug representation showing shape, dtype and name.

    Fix: the diff artifact left two variants of the name line — one
    plain string literal `"name={self.name}>"` (which would render the
    braces literally, since it lacks the `f` prefix) and the corrected
    f-string. Only the f-string variant is kept.
    """
    return (
        f"<KerasVariable shape={self.shape}, dtype={self.dtype}, "
        f"name={self.name}>"
    )
|
||||
|
||||
|
||||
|
@ -59,6 +59,13 @@ class Variable(KerasVariable):
|
||||
def __init__(self, value, dtype=None, trainable=True, name=None):
|
||||
self.name = name or auto_name(self.__class__.__name__)
|
||||
dtype = standardize_dtype(dtype)
|
||||
if in_stateless_scope():
|
||||
raise ValueError(
|
||||
"You are attempting to create a variable "
|
||||
"while in a stateless scope. This is disallowed. "
|
||||
"Make sure that all variables are created "
|
||||
"before you start using your layer/model objects. "
|
||||
)
|
||||
self._value = jnp.array(value, dtype=dtype)
|
||||
self._dtype = dtype
|
||||
self._shape = tuple(self._value.shape)
|
||||
@ -299,3 +306,18 @@ def execute(op_name, *args, **kwargs):
|
||||
op = getattr(jnp, op_name)
|
||||
return op(*args, **kwargs)
|
||||
raise AttributeError(f"The JAX backend does not support op '{op_name}'")
|
||||
|
||||
|
||||
def traceable_tensor(shape, dtype=None):
    """Create a "traceable tensor".

    That's a tensor that can be passed as input
    to a stateful backend-native function to
    create state during the trace.

    Unknown (``None``) dimensions are replaced with 1 so a concrete
    array can be materialized; dtype defaults to ``"float32"``.
    """
    resolved_dtype = dtype or "float32"
    # Substitute 1 for every undefined dimension to get a concrete shape.
    concrete_shape = [1 if dim is None else dim for dim in shape]
    return jnp.ones(concrete_shape, dtype=resolved_dtype)
|
@ -274,3 +274,18 @@ def execute(op_name, *args, **kwargs):
|
||||
raise AttributeError(
|
||||
f"The TensorFlow backend does not support op '{op_name}'"
|
||||
)
|
||||
|
||||
|
||||
def traceable_tensor(shape, dtype=None):
    """Create a "traceable tensor".

    That's a tensor that can be passed as input
    to a stateful backend-native function to
    create state during the trace.

    Unknown (``None``) dimensions are replaced with 1 so a concrete
    tensor can be materialized; dtype defaults to ``"float32"``.
    """
    # Replace each undefined dimension with 1 to obtain a static shape.
    dims = [1 if dim is None else dim for dim in shape]
    return tf.ones(dims, dtype=dtype or "float32")
|
||||
|
@ -221,19 +221,19 @@ class Model(Trainer, Layer):
|
||||
input_shape = tuple(input_shape)
|
||||
if isinstance(input_shape, list):
|
||||
input_tensors = [
|
||||
backend.KerasTensor(shape) for shape in input_shape
|
||||
backend.traceable_tensor(shape) for shape in input_shape
|
||||
]
|
||||
elif isinstance(input_shape, dict):
|
||||
input_tensors = {
|
||||
k: backend.KerasTensor(shape)
|
||||
k: backend.traceable_tensor(shape)
|
||||
for k, shape in input_shape.items()
|
||||
}
|
||||
else:
|
||||
input_tensors = backend.KerasTensor(input_shape)
|
||||
input_tensors = backend.traceable_tensor(input_shape)
|
||||
try:
|
||||
self(input_tensors)
|
||||
self._build_shapes_dict = config
|
||||
except:
|
||||
except Exception as e:
|
||||
failure = True
|
||||
elif "shapes_dict" in config:
|
||||
# Case: inputs were recorded as multiple keyword arguments.
|
||||
@ -242,7 +242,7 @@ class Model(Trainer, Layer):
|
||||
):
|
||||
# Case: all input keyword arguments were plain tensors.
|
||||
input_tensors = {
|
||||
k: backend.KerasTensor(v)
|
||||
k: backend.traceable_tensor(v)
|
||||
for k, v in config["shapes_dict"].items()
|
||||
}
|
||||
try:
|
||||
|
@ -42,7 +42,6 @@ def deserialize(config, custom_objects=None):
|
||||
if config["class_name"].lower() in ALL_OBJECTS_DICT:
|
||||
config["class_name"] = config["class_name"].lower()
|
||||
|
||||
print("deserialize:", config)
|
||||
return serialization_lib.deserialize_keras_object(
|
||||
config,
|
||||
module_objects=ALL_OBJECTS_DICT,
|
||||
@ -63,16 +62,13 @@ def get(identifier):
|
||||
Returns:
|
||||
A Keras Optimizer instance.
|
||||
"""
|
||||
print("call get with", identifier)
|
||||
if isinstance(identifier, Optimizer):
|
||||
return identifier
|
||||
elif isinstance(identifier, dict):
|
||||
return deserialize(identifier)
|
||||
elif isinstance(identifier, str):
|
||||
config = {"class_name": identifier, "config": {}}
|
||||
opt = deserialize(config)
|
||||
print(opt)
|
||||
return opt
|
||||
return deserialize(config)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Could not interpret optimizer identifier: {identifier}"
|
||||
|
Loading…
Reference in New Issue
Block a user