From b9753e1a467ddc59ba92698ea21c8927f27fae7d Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Tue, 25 Apr 2023 15:03:16 -0700 Subject: [PATCH] Some progress on saving. --- keras_core/backend/common/variables.py | 2 +- keras_core/backend/jax/__init__.py | 22 ++++++++++++++++++++++ keras_core/backend/tensorflow/__init__.py | 15 +++++++++++++++ keras_core/models/model.py | 10 +++++----- keras_core/optimizers/__init__.py | 6 +----- 5 files changed, 44 insertions(+), 11 deletions(-) diff --git a/keras_core/backend/common/variables.py b/keras_core/backend/common/variables.py index cdd81bba6..f44e18949 100644 --- a/keras_core/backend/common/variables.py +++ b/keras_core/backend/common/variables.py @@ -30,7 +30,7 @@ class KerasVariable: def __repr__(self): return ( f"" + f"name={self.name}>" ) diff --git a/keras_core/backend/jax/__init__.py b/keras_core/backend/jax/__init__.py index 6c6c36099..d410c25af 100644 --- a/keras_core/backend/jax/__init__.py +++ b/keras_core/backend/jax/__init__.py @@ -59,6 +59,13 @@ class Variable(KerasVariable): def __init__(self, value, dtype=None, trainable=True, name=None): self.name = name or auto_name(self.__class__.__name__) dtype = standardize_dtype(dtype) + if in_stateless_scope(): + raise ValueError( + "You are attempting to create a variable " + "while in a stateless scope. This is disallowed. " + "Make sure that all variables are created " + "before you start using your layer/model objects. " + ) self._value = jnp.array(value, dtype=dtype) self._dtype = dtype self._shape = tuple(self._value.shape) @@ -299,3 +306,18 @@ def execute(op_name, *args, **kwargs): op = getattr(jnp, op_name) return op(*args, **kwargs) raise AttributeError(f"The JAX backend does not support op '{op_name}'") + + +def traceable_tensor(shape, dtype=None): + """Create a "traceable tensor". + + That's a tensor that can be passed as input + to a stateful backend-native function to + create state during the trace. + """ + shape = list(shape) + dtype = dtype or "float32" + for i, x in enumerate(shape): + if x is None: + shape[i] = 1 + return jnp.ones(shape, dtype=dtype) \ No newline at end of file diff --git a/keras_core/backend/tensorflow/__init__.py b/keras_core/backend/tensorflow/__init__.py index 0dd6a4db8..bc3aeb916 100644 --- a/keras_core/backend/tensorflow/__init__.py +++ b/keras_core/backend/tensorflow/__init__.py @@ -274,3 +274,18 @@ def execute(op_name, *args, **kwargs): raise AttributeError( f"The TensorFlow backend does not support op '{op_name}'" ) + + +def traceable_tensor(shape, dtype=None): + """Create a "traceable tensor". + + That's a tensor that can be passed as input + to a stateful backend-native function to + create state during the trace. + """ + shape = list(shape) + dtype = dtype or "float32" + for i, x in enumerate(shape): + if x is None: + shape[i] = 1 + return tf.ones(shape, dtype=dtype) diff --git a/keras_core/models/model.py b/keras_core/models/model.py index 56ff5767c..301ccc068 100644 --- a/keras_core/models/model.py +++ b/keras_core/models/model.py @@ -221,19 +221,19 @@ class Model(Trainer, Layer): input_shape = tuple(input_shape) if isinstance(input_shape, list): input_tensors = [ - backend.KerasTensor(shape) for shape in input_shape + backend.traceable_tensor(shape) for shape in input_shape ] elif isinstance(input_shape, dict): input_tensors = { - k: backend.KerasTensor(shape) + k: backend.traceable_tensor(shape) for k, shape in input_shape.items() } else: - input_tensors = backend.KerasTensor(input_shape) + input_tensors = backend.traceable_tensor(input_shape) try: self(input_tensors) self._build_shapes_dict = config - except: + except Exception as e: failure = True elif "shapes_dict" in config: # Case: inputs were recorded as multiple keyword arguments. @@ -242,7 +242,7 @@ class Model(Trainer, Layer): ): # Case: all input keyword arguments were plain tensors. input_tensors = { - k: backend.KerasTensor(v) + k: backend.traceable_tensor(v) for k, v in config["shapes_dict"].items() } try: diff --git a/keras_core/optimizers/__init__.py b/keras_core/optimizers/__init__.py index 0064074ff..82f519859 100644 --- a/keras_core/optimizers/__init__.py +++ b/keras_core/optimizers/__init__.py @@ -42,7 +42,6 @@ def deserialize(config, custom_objects=None): if config["class_name"].lower() in ALL_OBJECTS_DICT: config["class_name"] = config["class_name"].lower() - print("deserialize:", config) return serialization_lib.deserialize_keras_object( config, module_objects=ALL_OBJECTS_DICT, @@ -63,16 +62,13 @@ def get(identifier): Returns: A Keras Optimizer instance. """ - print("call get with", identifier) if isinstance(identifier, Optimizer): return identifier elif isinstance(identifier, dict): return deserialize(identifier) elif isinstance(identifier, str): config = {"class_name": identifier, "config": {}} - opt = deserialize(config) - print(opt) - return opt + return deserialize(config) else: raise ValueError( f"Could not interpret optimizer identifier: {identifier}"