diff --git a/keras_core/layers/activations/activation_test.py b/keras_core/layers/activations/activation_test.py index 9f828ec20..49be3cf35 100644 --- a/keras_core/layers/activations/activation_test.py +++ b/keras_core/layers/activations/activation_test.py @@ -4,7 +4,7 @@ from keras_core import testing class ActivationTest(testing.TestCase): - def test_dense_basics(self): + def test_activation_basics(self): self.run_layer_test( layers.Activation, init_kwargs={ diff --git a/keras_core/models/functional.py b/keras_core/models/functional.py index 8fbb42384..847610368 100644 --- a/keras_core/models/functional.py +++ b/keras_core/models/functional.py @@ -24,15 +24,16 @@ class Functional(Function, Model): Symbolic add_loss """ + def __new__(cls, *args, **kwargs): + # Skip Model.__new__. + return Function.__new__(cls) + @tracking.no_automatic_dependency_tracking def __init__(self, inputs, outputs, name=None, **kwargs): # This is used by the Model class, since we have some logic to swap the # class in the __new__ method, which will lead to __init__ get invoked # twice. 
Using the skip_init to skip one of the invocations of __init__ # to avoid any side effects - skip_init = kwargs.pop("skip_init", False) - if skip_init: - return if isinstance(inputs, dict): for k, v in inputs.items(): if not isinstance(v, backend.KerasTensor): diff --git a/keras_core/models/model.py b/keras_core/models/model.py index d3de3c092..03e307b19 100644 --- a/keras_core/models/model.py +++ b/keras_core/models/model.py @@ -1,6 +1,9 @@ +import inspect + from keras_core import backend from keras_core.api_export import keras_core_export from keras_core.layers.layer import Layer +from keras_core.utils import python_utils from keras_core.utils import summary_utils if backend.backend() == "tensorflow": @@ -32,11 +35,11 @@ class Model(Trainer, Layer): def __new__(cls, *args, **kwargs): # Signature detection - if functional_init_arguments(args, kwargs) and cls == Model: + if functional_init_arguments(args, kwargs): # Functional model from keras_core.models import functional - return functional.Functional(*args, **kwargs, skip_init=True) + return functional.Functional(*args, **kwargs) return Layer.__new__(cls) def __init__(self, trainable=True, name=None, dtype=None): @@ -152,6 +155,32 @@ class Model(Trainer, Layer): def export(self, filepath): raise NotImplementedError + @python_utils.default + def get_config(self): + # Prepare base arguments + config = { + "name": self.name, + "trainable": self.trainable, + } + # Check whether the class has a constructor compatible with a Functional + # model or if it has a custom constructor. + if functional_like_constructor(self.__class__): + # Only return a Functional config if the constructor is the same + # as that of a Functional model. This excludes subclassed Functional + # models with a custom __init__. 
+ config = {**config, **get_functional_config(self)} + else: + # Try to autogenerate config + xtra_args = set(config.keys()) + if getattr(self, "_auto_get_config", False): + config.update(self._auto_config.config) + # Remove args not explicitly supported + argspec = inspect.getfullargspec(self.__init__) + if argspec.varkw != "kwargs": + for key in xtra_args - xtra_args.intersection(argspec.args[1:]): + config.pop(key, None) + return config + def functional_init_arguments(args, kwargs): return ( @@ -159,3 +188,11 @@ or (len(args) == 1 and "outputs" in kwargs) or ("inputs" in kwargs and "outputs" in kwargs) ) + + +def functional_like_constructor(cls): + raise NotImplementedError + + +def get_functional_config(model): + raise NotImplementedError diff --git a/keras_core/operations/operation.py b/keras_core/operations/operation.py index d4690ca3e..3e72612f9 100644 --- a/keras_core/operations/operation.py +++ b/keras_core/operations/operation.py @@ -1,12 +1,26 @@ +import inspect +import textwrap + +from tensorflow import nest + from keras_core import backend from keras_core.backend.keras_tensor import any_symbolic_tensors from keras_core.operations.node import Node +from keras_core.saving import serialization_lib +from keras_core.utils import python_utils from keras_core.utils.naming import auto_name class Operation: def __init__(self, name=None): - self.name = name or auto_name(self.__class__.__name__) + if name is None: + name = auto_name(self.__class__.__name__) + if not isinstance(name, str): + raise ValueError( + "Argument `name` should be a string. " + f"Received instead: name={name} (of type {type(name)})" + ) + self.name = name self._inbound_nodes = [] self._outbound_nodes = [] @@ -44,12 +58,123 @@ class Operation: f"Error encountered: {e}" ) + def __new__(cls, *args, **kwargs): + """We override __new__ to save serializable constructor arguments. 
+ + These arguments are used to auto-generate an object serialization config, + which enables user-created subclasses to be serializable out of the box + in most cases without forcing the user to manually implement + `get_config()`. + """ + # Generate a config to be returned by default by `get_config()`. + arg_names = inspect.getfullargspec(cls.__init__).args + kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args))) + instance = super(Operation, cls).__new__(cls) + # For safety, we only rely on auto-configs for a small set of + # serializable types. + supported_types = (str, int, float, bool, type(None)) + try: + flat_arg_values = nest.flatten(kwargs) + auto_config = True + for value in flat_arg_values: + if not isinstance(value, supported_types): + auto_config = False + break + except TypeError: + auto_config = False + try: + if auto_config: + instance._auto_config = serialization_lib.SerializableDict( + **kwargs + ) + else: + instance._auto_config = None + except RecursionError: + # Setting an instance attribute in __new__ has the potential + # to trigger an infinite recursion if a subclass overrides + # setattr in an unsafe way. + pass + return instance + + @python_utils.default def get_config(self): - return {"name": self.name} + """Returns the config of the object. + + An object config is a Python dictionary (serializable) + containing the information needed to re-instantiate it. + + Returns: + Python dictionary. + """ + config = { + "name": self.name, + } + + if not python_utils.is_default(self.get_config): + # In this case the subclass implements get_config() + return config + + # In this case the subclass doesn't implement get_config(): + # Let's see if we can autogenerate it. 
+ if getattr(self, "_auto_config", None) is not None: + xtra_args = set(config.keys()) + config.update(self._auto_config.config) + # Remove args not explicitly supported + argspec = inspect.getfullargspec(self.__init__) + if argspec.varkw != "kwargs": + for key in xtra_args - xtra_args.intersection(argspec.args[1:]): + config.pop(key, None) + return config + else: + raise NotImplementedError( + textwrap.dedent( + f""" + Object {self.__class__.__name__} was created by passing + non-serializable argument values in `__init__()`, + and therefore the object must override `get_config()` in + order to be serializable. Please implement `get_config()`. + + Example: + + class CustomLayer(keras.layers.Layer): + def __init__(self, arg1, arg2, **kwargs): + super().__init__(**kwargs) + self.arg1 = arg1 + self.arg2 = arg2 + + def get_config(self): + config = super().get_config() + config.update({ + "arg1": self.arg1, + "arg2": self.arg2, + }) + return config""" + ) + ) @classmethod def from_config(cls, config): - return cls(**config) + """Creates a layer from its config. + + This method is the reverse of `get_config`, + capable of instantiating the same layer from the config + dictionary. It does not handle layer connectivity + (handled by Network), nor weights (handled by `set_weights`). + + Args: + config: A Python dictionary, typically the + output of get_config. + + Returns: + A layer instance. 
+ """ + try: + return cls(**config) + except Exception as e: + raise TypeError( + f"Error when deserializing class '{cls.__name__}' using " + f"config={config}.\n\nException encountered: {e}" + ) def __repr__(self): return f"" diff --git a/keras_core/operations/operation_test.py b/keras_core/operations/operation_test.py index e025995ee..94a9fd6b1 100644 --- a/keras_core/operations/operation_test.py +++ b/keras_core/operations/operation_test.py @@ -26,6 +26,21 @@ class OpWithMultipleOutputs(operation.Operation): ) +class OpWithCustomConstructor(operation.Operation): + def __init__(self, alpha, mode="foo"): + super().__init__() + self.alpha = alpha + self.mode = mode + + def call(self, x): + if self.mode == "foo": + return x + return self.alpha * x + + def compute_output_spec(self, x): + return keras_tensor.KerasTensor(x.shape, x.dtype) + + class OperationTest(testing.TestCase): def test_symbolic_call(self): x = keras_tensor.KerasTensor(shape=(2, 3), name="x") @@ -118,6 +133,13 @@ class OperationTest(testing.TestCase): op = OpWithMultipleOutputs.from_config(config) self.assertEqual(op.name, "test_op") + def test_autoconfig(self): + op = OpWithCustomConstructor(alpha=0.2, mode="bar") + config = op.get_config() + self.assertEqual(config, {"alpha": 0.2, "mode": "bar"}) + revived = OpWithCustomConstructor.from_config(config) + self.assertEqual(revived.get_config(), config) + def test_input_conversion(self): x = np.ones((2,)) y = np.ones((2,))