From 7fc2f3320ebd6be75995253d18ea782f2f4f6bae Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Mon, 24 Apr 2023 14:58:38 -0700 Subject: [PATCH] Finally make functional subclassing work --- keras_core/models/functional.py | 57 ++++++++++++++++++--- keras_core/models/model.py | 33 +++++++++--- keras_core/saving/serialization_lib_test.py | 44 ++++++++-------- 3 files changed, 100 insertions(+), 34 deletions(-) diff --git a/keras_core/models/functional.py b/keras_core/models/functional.py index 6661c73a9..a5e66eb6c 100644 --- a/keras_core/models/functional.py +++ b/keras_core/models/functional.py @@ -11,7 +11,6 @@ from keras_core.models.model import Model from keras_core.operations.function import Function from keras_core.operations.function import make_node_key from keras_core.saving import serialization_lib -from keras_core.utils import python_utils from keras_core.utils import tracking @@ -29,12 +28,7 @@ class Functional(Function, Model): Symbolic add_loss """ - def __new__(cls, *args, **kwargs): - # Skip Model.__new__. - return Function.__new__(cls) - @tracking.no_automatic_dependency_tracking - @python_utils.default def __init__(self, inputs, outputs, name=None, **kwargs): if isinstance(inputs, dict): for k, v in inputs.items(): @@ -192,7 +186,7 @@ class Functional(Function, Model): # Subclassed networks are not serializable # (unless serialization is implemented by # the author of the subclassed network). - return Model.get_config() + return Model.get_config(self) config = { "name": self.name, @@ -266,6 +260,55 @@ class Functional(Function, Model): @classmethod def from_config(cls, config, custom_objects=None): + functional_config_keys = [ + "name", + "layers", + "input_layers", + "output_layers", + ] + is_functional_config = all( + key in config for key in functional_config_keys + ) + argspec = inspect.getfullargspec(cls.__init__) + functional_init_args = inspect.getfullargspec(Functional.__init__).args[ + 1: + ] + revivable_as_functional = ( + cls in {Functional, Model} + or argspec.args[1:] == functional_init_args + or (argspec.varargs == "args" and argspec.varkw == "kwargs") + ) + if is_functional_config and revivable_as_functional: + # Revive Functional model + # (but not Functional subclasses with a custom __init__) + return cls._from_config(config, custom_objects=custom_objects) + + # Either the model has a custom __init__, or the config + # does not contain all the information necessary to + # revive a Functional model. This happens when the user creates + # subclassed models where `get_config()` is returning + # insufficient information to be considered a Functional model. + # In this case, we fall back to provide all config into the + # constructor of the class. + try: + return cls(**config) + except TypeError as e: + raise TypeError( + "Unable to revive model from config. When overriding " + "the `get_config()` method, make sure that the " + "returned config contains all items used as arguments " + f"in the constructor to {cls}, " + "which is the default behavior. " + "You can override this default behavior by defining a " + "`from_config(cls, config)` class method to specify " + "how to create an " + f"instance of {cls.__name__} from its config.\n\n" + f"Received config={config}\n\n" + f"Error encountered during deserialization: {e}" + ) + + @classmethod + def _from_config(cls, config, custom_objects=None): """Instantiates a Model from its config (output of `get_config()`).""" # Layer instances created during # the graph reconstruction process diff --git a/keras_core/models/model.py b/keras_core/models/model.py index 8a2880396..886525365 100644 --- a/keras_core/models/model.py +++ b/keras_core/models/model.py @@ -1,7 +1,6 @@ from keras_core import backend from keras_core.api_export import keras_core_export from keras_core.layers.layer import Layer -from keras_core.utils import python_utils from keras_core.utils import summary_utils if backend.backend() == "tensorflow": @@ -32,8 +31,8 @@ class Model(Trainer, Layer): """ def __new__(cls, *args, **kwargs): - # Signature detection - if functional_init_arguments(args, kwargs): + # Signature detection for usage of `Model` as a `Functional` + if functional_init_arguments(args, kwargs) and cls == Model: from keras_core.models import functional return functional.Functional(*args, **kwargs) @@ -43,9 +42,10 @@ class Model(Trainer, Layer): Trainer.__init__(self) from keras_core.models import functional - if isinstance(self, functional.Functional) and python_utils.is_default( - self.__init__ - ): + # Signature detection for usage of a `Model` subclass + # as a `Functional` subclass + if functional_init_arguments(args, kwargs): + inject_functional_model_class(self.__class__) functional.Functional.__init__(self, *args, **kwargs) else: Layer.__init__(self, *args, **kwargs) @@ -166,3 +166,24 @@ def functional_init_arguments(args, kwargs): or (len(args) == 1 and "outputs" in kwargs) or ("inputs" in kwargs and "outputs" in kwargs) ) + + +def inject_functional_model_class(cls): + """Inject `Functional` into the hierarchy of this class if needed.""" + from keras_core.models import functional + + if cls == Model: + return functional.Functional + # In case there is any multiple inheritance, we stop injecting the + # class if keras model is not in its class hierarchy. + if cls == object: + return object + + cls.__bases__ = tuple( + inject_functional_model_class(base) for base in cls.__bases__ + ) + # Trigger any `__new__` class swapping that needed to happen on `Functional` + # but did not because functional was not in the class hierarchy. + cls.__new__(cls) + + return cls diff --git a/keras_core/saving/serialization_lib_test.py b/keras_core/saving/serialization_lib_test.py index cc5af6165..02fe105b9 100644 --- a/keras_core/saving/serialization_lib_test.py +++ b/keras_core/saving/serialization_lib_test.py @@ -210,29 +210,31 @@ class SerializationLibTest(testing.TestCase): new_model.set_weights(model.get_weights()) y2 = new_model(x) self.assertAllClose(y1, y2, atol=1e-5) - # TODO - # self.assertIsInstance(new_model, PlainFunctionalSubclass) + self.assertIsInstance(new_model, PlainFunctionalSubclass) - # TODO - # class FunctionalSubclassWCustomInit(keras_core.Model): - # def __init__(self, num_units=1, **kwargs): - # inputs = keras_core.Input((2,), batch_size=3) - # outputs = keras_core.layers.Dense(num_units)(inputs) - # super().__init__(inputs, outputs) + class FunctionalSubclassWCustomInit(keras_core.Model): + def __init__(self, num_units=2): + inputs = keras_core.Input((2,), batch_size=3) + outputs = keras_core.layers.Dense(num_units)(inputs) + super().__init__(inputs, outputs) + self.num_units = num_units - # model = FunctionalSubclassWCustomInit(num_units=2) - # x = ops.random.normal((2, 2)) - # y1 = model(x) - # _, new_model, _ = self.roundtrip( - # model, - # custom_objects={ - # "FunctionalSubclassWCustomInit": FunctionalSubclassWCustomInit - # }, - # ) - # new_model.set_weights(model.get_weights()) - # y2 = new_model(x) - # self.assertAllClose(y1, y2, atol=1e-5) - # self.assertIsInstance(new_model, FunctionalSubclassWCustomInit) + def get_config(self): + return {"num_units": self.num_units} + + model = FunctionalSubclassWCustomInit(num_units=3) + x = ops.random.normal((2, 2)) + y1 = model(x) + _, new_model, _ = self.roundtrip( + model, + custom_objects={ + "FunctionalSubclassWCustomInit": FunctionalSubclassWCustomInit + }, + ) + new_model.set_weights(model.get_weights()) + y2 = new_model(x) + self.assertAllClose(y1, y2, atol=1e-5) + self.assertIsInstance(new_model, FunctionalSubclassWCustomInit) def test_shared_object(self): class MyLayer(keras_core.layers.Layer):