Finally make functional subclassing work

This commit is contained in:
Francois Chollet 2023-04-24 14:58:38 -07:00
parent 295f9a5f5f
commit 7fc2f3320e
3 changed files with 100 additions and 34 deletions

@@ -11,7 +11,6 @@ from keras_core.models.model import Model
from keras_core.operations.function import Function from keras_core.operations.function import Function
from keras_core.operations.function import make_node_key from keras_core.operations.function import make_node_key
from keras_core.saving import serialization_lib from keras_core.saving import serialization_lib
from keras_core.utils import python_utils
from keras_core.utils import tracking from keras_core.utils import tracking
@@ -29,12 +28,7 @@ class Functional(Function, Model):
Symbolic add_loss Symbolic add_loss
""" """
def __new__(cls, *args, **kwargs):
# Skip Model.__new__.
return Function.__new__(cls)
@tracking.no_automatic_dependency_tracking @tracking.no_automatic_dependency_tracking
@python_utils.default
def __init__(self, inputs, outputs, name=None, **kwargs): def __init__(self, inputs, outputs, name=None, **kwargs):
if isinstance(inputs, dict): if isinstance(inputs, dict):
for k, v in inputs.items(): for k, v in inputs.items():
@@ -192,7 +186,7 @@ class Functional(Function, Model):
# Subclassed networks are not serializable # Subclassed networks are not serializable
# (unless serialization is implemented by # (unless serialization is implemented by
# the author of the subclassed network). # the author of the subclassed network).
return Model.get_config() return Model.get_config(self)
config = { config = {
"name": self.name, "name": self.name,
@@ -266,6 +260,55 @@ class Functional(Function, Model):
@classmethod
def from_config(cls, config, custom_objects=None):
    """Create a model from its config (inverse of `get_config()`).

    Decides between two revival paths:
    1. Revive as a graph-based `Functional` model when the config has the
       full functional-model structure AND the class's `__init__` is
       compatible with `Functional.__init__`.
    2. Otherwise, fall back to calling `cls(**config)` (custom-`__init__`
       subclasses), wrapping any `TypeError` with an actionable message.

    Args:
        config: Dict returned by `get_config()`.
        custom_objects: Optional dict mapping names to custom classes or
            functions, forwarded to the functional revival path.

    Returns:
        A model instance of type `cls`.

    Raises:
        TypeError: If the fallback `cls(**config)` call does not accept
            the items stored in `config`.
    """
    # Keys that a serialized Functional model is expected to carry; their
    # joint presence marks the config as structurally functional.
    functional_config_keys = [
        "name",
        "layers",
        "input_layers",
        "output_layers",
    ]
    is_functional_config = all(
        key in config for key in functional_config_keys
    )
    # Compare the subclass's __init__ signature against Functional's:
    # it is revivable as functional if it is Functional/Model itself,
    # matches Functional's positional args exactly, or is a transparent
    # (*args, **kwargs) pass-through.
    argspec = inspect.getfullargspec(cls.__init__)
    functional_init_args = inspect.getfullargspec(Functional.__init__).args[
        1:
    ]
    revivable_as_functional = (
        cls in {Functional, Model}
        or argspec.args[1:] == functional_init_args
        or (argspec.varargs == "args" and argspec.varkw == "kwargs")
    )
    if is_functional_config and revivable_as_functional:
        # Revive Functional model
        # (but not Functional subclasses with a custom __init__)
        return cls._from_config(config, custom_objects=custom_objects)

    # Either the model has a custom __init__, or the config
    # does not contain all the information necessary to
    # revive a Functional model. This happens when the user creates
    # subclassed models where `get_config()` is returning
    # insufficient information to be considered a Functional model.
    # In this case, we fall back to provide all config into the
    # constructor of the class.
    try:
        return cls(**config)
    except TypeError as e:
        raise TypeError(
            "Unable to revive model from config. When overriding "
            "the `get_config()` method, make sure that the "
            "returned config contains all items used as arguments "
            f"in the constructor to {cls}, "
            "which is the default behavior. "
            "You can override this default behavior by defining a "
            "`from_config(cls, config)` class method to specify "
            "how to create an "
            f"instance of {cls.__name__} from its config.\n\n"
            f"Received config={config}\n\n"
            f"Error encountered during deserialization: {e}"
        )
@classmethod
def _from_config(cls, config, custom_objects=None):
"""Instantiates a Model from its config (output of `get_config()`).""" """Instantiates a Model from its config (output of `get_config()`)."""
# Layer instances created during # Layer instances created during
# the graph reconstruction process # the graph reconstruction process

@@ -1,7 +1,6 @@
from keras_core import backend from keras_core import backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer from keras_core.layers.layer import Layer
from keras_core.utils import python_utils
from keras_core.utils import summary_utils from keras_core.utils import summary_utils
if backend.backend() == "tensorflow": if backend.backend() == "tensorflow":
@@ -32,8 +31,8 @@ class Model(Trainer, Layer):
""" """
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
# Signature detection # Signature detection for usage of `Model` as a `Functional`
if functional_init_arguments(args, kwargs): if functional_init_arguments(args, kwargs) and cls == Model:
from keras_core.models import functional from keras_core.models import functional
return functional.Functional(*args, **kwargs) return functional.Functional(*args, **kwargs)
@@ -43,9 +42,10 @@ class Model(Trainer, Layer):
Trainer.__init__(self) Trainer.__init__(self)
from keras_core.models import functional from keras_core.models import functional
if isinstance(self, functional.Functional) and python_utils.is_default( # Signature detection for usage of a `Model` subclass
self.__init__ # as a `Functional` subclass
): if functional_init_arguments(args, kwargs):
inject_functional_model_class(self.__class__)
functional.Functional.__init__(self, *args, **kwargs) functional.Functional.__init__(self, *args, **kwargs)
else: else:
Layer.__init__(self, *args, **kwargs) Layer.__init__(self, *args, **kwargs)
@@ -166,3 +166,24 @@ def functional_init_arguments(args, kwargs):
or (len(args) == 1 and "outputs" in kwargs) or (len(args) == 1 and "outputs" in kwargs)
or ("inputs" in kwargs and "outputs" in kwargs) or ("inputs" in kwargs and "outputs" in kwargs)
) )
def inject_functional_model_class(cls):
    """Inject `Functional` into the hierarchy of this class if needed.

    Recursively rewrites `cls.__bases__` so that any occurrence of `Model`
    in the MRO is replaced by `Functional`, letting a `Model` subclass that
    was constructed with functional-style arguments behave as a
    `Functional` subclass. Returns the (possibly substituted) class.
    """
    from keras_core.models import functional

    # Base case: `Model` itself is swapped for `Functional`.
    if cls == Model:
        return functional.Functional
    # In case there is any multiple inheritance, we stop injecting the
    # class if keras model is not in its class hierarchy.
    if cls == object:
        return object
    # NOTE(review): mutates the class object in place — every existing and
    # future instance of `cls` sees the new hierarchy.
    cls.__bases__ = tuple(
        inject_functional_model_class(base) for base in cls.__bases__
    )
    # Trigger any `__new__` class swapping that needed to happen on `Functional`
    # but did not because functional was not in the class hierarchy.
    cls.__new__(cls)
    return cls

@@ -210,29 +210,31 @@ class SerializationLibTest(testing.TestCase):
new_model.set_weights(model.get_weights()) new_model.set_weights(model.get_weights())
y2 = new_model(x) y2 = new_model(x)
self.assertAllClose(y1, y2, atol=1e-5) self.assertAllClose(y1, y2, atol=1e-5)
# TODO self.assertIsInstance(new_model, PlainFunctionalSubclass)
# self.assertIsInstance(new_model, PlainFunctionalSubclass)
# TODO class FunctionalSubclassWCustomInit(keras_core.Model):
# class FunctionalSubclassWCustomInit(keras_core.Model): def __init__(self, num_units=2):
# def __init__(self, num_units=1, **kwargs): inputs = keras_core.Input((2,), batch_size=3)
# inputs = keras_core.Input((2,), batch_size=3) outputs = keras_core.layers.Dense(num_units)(inputs)
# outputs = keras_core.layers.Dense(num_units)(inputs) super().__init__(inputs, outputs)
# super().__init__(inputs, outputs) self.num_units = num_units
# model = FunctionalSubclassWCustomInit(num_units=2) def get_config(self):
# x = ops.random.normal((2, 2)) return {"num_units": self.num_units}
# y1 = model(x)
# _, new_model, _ = self.roundtrip( model = FunctionalSubclassWCustomInit(num_units=3)
# model, x = ops.random.normal((2, 2))
# custom_objects={ y1 = model(x)
# "FunctionalSubclassWCustomInit": FunctionalSubclassWCustomInit _, new_model, _ = self.roundtrip(
# }, model,
# ) custom_objects={
# new_model.set_weights(model.get_weights()) "FunctionalSubclassWCustomInit": FunctionalSubclassWCustomInit
# y2 = new_model(x) },
# self.assertAllClose(y1, y2, atol=1e-5) )
# self.assertIsInstance(new_model, FunctionalSubclassWCustomInit) new_model.set_weights(model.get_weights())
y2 = new_model(x)
self.assertAllClose(y1, y2, atol=1e-5)
self.assertIsInstance(new_model, FunctionalSubclassWCustomInit)
def test_shared_object(self): def test_shared_object(self):
class MyLayer(keras_core.layers.Layer): class MyLayer(keras_core.layers.Layer):