Finally make functional subclassing work
This commit is contained in:
parent
295f9a5f5f
commit
7fc2f3320e
@ -11,7 +11,6 @@ from keras_core.models.model import Model
|
||||
from keras_core.operations.function import Function
|
||||
from keras_core.operations.function import make_node_key
|
||||
from keras_core.saving import serialization_lib
|
||||
from keras_core.utils import python_utils
|
||||
from keras_core.utils import tracking
|
||||
|
||||
|
||||
@ -29,12 +28,7 @@ class Functional(Function, Model):
|
||||
Symbolic add_loss
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# Skip Model.__new__.
|
||||
return Function.__new__(cls)
|
||||
|
||||
@tracking.no_automatic_dependency_tracking
|
||||
@python_utils.default
|
||||
def __init__(self, inputs, outputs, name=None, **kwargs):
|
||||
if isinstance(inputs, dict):
|
||||
for k, v in inputs.items():
|
||||
@ -192,7 +186,7 @@ class Functional(Function, Model):
|
||||
# Subclassed networks are not serializable
|
||||
# (unless serialization is implemented by
|
||||
# the author of the subclassed network).
|
||||
return Model.get_config()
|
||||
return Model.get_config(self)
|
||||
|
||||
config = {
|
||||
"name": self.name,
|
||||
@ -266,6 +260,55 @@ class Functional(Function, Model):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config, custom_objects=None):
|
||||
functional_config_keys = [
|
||||
"name",
|
||||
"layers",
|
||||
"input_layers",
|
||||
"output_layers",
|
||||
]
|
||||
is_functional_config = all(
|
||||
key in config for key in functional_config_keys
|
||||
)
|
||||
argspec = inspect.getfullargspec(cls.__init__)
|
||||
functional_init_args = inspect.getfullargspec(Functional.__init__).args[
|
||||
1:
|
||||
]
|
||||
revivable_as_functional = (
|
||||
cls in {Functional, Model}
|
||||
or argspec.args[1:] == functional_init_args
|
||||
or (argspec.varargs == "args" and argspec.varkw == "kwargs")
|
||||
)
|
||||
if is_functional_config and revivable_as_functional:
|
||||
# Revive Functional model
|
||||
# (but not Functional subclasses with a custom __init__)
|
||||
return cls._from_config(config, custom_objects=custom_objects)
|
||||
|
||||
# Either the model has a custom __init__, or the config
|
||||
# does not contain all the information necessary to
|
||||
# revive a Functional model. This happens when the user creates
|
||||
# subclassed models where `get_config()` is returning
|
||||
# insufficient information to be considered a Functional model.
|
||||
# In this case, we fall back to provide all config into the
|
||||
# constructor of the class.
|
||||
try:
|
||||
return cls(**config)
|
||||
except TypeError as e:
|
||||
raise TypeError(
|
||||
"Unable to revive model from config. When overriding "
|
||||
"the `get_config()` method, make sure that the "
|
||||
"returned config contains all items used as arguments "
|
||||
f"in the constructor to {cls}, "
|
||||
"which is the default behavior. "
|
||||
"You can override this default behavior by defining a "
|
||||
"`from_config(cls, config)` class method to specify "
|
||||
"how to create an "
|
||||
f"instance of {cls.__name__} from its config.\n\n"
|
||||
f"Received config={config}\n\n"
|
||||
f"Error encountered during deserialization: {e}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, custom_objects=None):
|
||||
"""Instantiates a Model from its config (output of `get_config()`)."""
|
||||
# Layer instances created during
|
||||
# the graph reconstruction process
|
||||
|
@ -1,7 +1,6 @@
|
||||
from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import python_utils
|
||||
from keras_core.utils import summary_utils
|
||||
|
||||
if backend.backend() == "tensorflow":
|
||||
@ -32,8 +31,8 @@ class Model(Trainer, Layer):
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# Signature detection
|
||||
if functional_init_arguments(args, kwargs):
|
||||
# Signature detection for usage of `Model` as a `Functional`
|
||||
if functional_init_arguments(args, kwargs) and cls == Model:
|
||||
from keras_core.models import functional
|
||||
|
||||
return functional.Functional(*args, **kwargs)
|
||||
@ -43,9 +42,10 @@ class Model(Trainer, Layer):
|
||||
Trainer.__init__(self)
|
||||
from keras_core.models import functional
|
||||
|
||||
if isinstance(self, functional.Functional) and python_utils.is_default(
|
||||
self.__init__
|
||||
):
|
||||
# Signature detection for usage of a `Model` subclass
|
||||
# as a `Functional` subclass
|
||||
if functional_init_arguments(args, kwargs):
|
||||
inject_functional_model_class(self.__class__)
|
||||
functional.Functional.__init__(self, *args, **kwargs)
|
||||
else:
|
||||
Layer.__init__(self, *args, **kwargs)
|
||||
@ -166,3 +166,24 @@ def functional_init_arguments(args, kwargs):
|
||||
or (len(args) == 1 and "outputs" in kwargs)
|
||||
or ("inputs" in kwargs and "outputs" in kwargs)
|
||||
)
|
||||
|
||||
|
||||
def inject_functional_model_class(cls):
|
||||
"""Inject `Functional` into the hierarchy of this class if needed."""
|
||||
from keras_core.models import functional
|
||||
|
||||
if cls == Model:
|
||||
return functional.Functional
|
||||
# In case there is any multiple inheritance, we stop injecting the
|
||||
# class if keras model is not in its class hierarchy.
|
||||
if cls == object:
|
||||
return object
|
||||
|
||||
cls.__bases__ = tuple(
|
||||
inject_functional_model_class(base) for base in cls.__bases__
|
||||
)
|
||||
# Trigger any `__new__` class swapping that needed to happen on `Functional`
|
||||
# but did not because functional was not in the class hierarchy.
|
||||
cls.__new__(cls)
|
||||
|
||||
return cls
|
||||
|
@ -210,29 +210,31 @@ class SerializationLibTest(testing.TestCase):
|
||||
new_model.set_weights(model.get_weights())
|
||||
y2 = new_model(x)
|
||||
self.assertAllClose(y1, y2, atol=1e-5)
|
||||
# TODO
|
||||
# self.assertIsInstance(new_model, PlainFunctionalSubclass)
|
||||
self.assertIsInstance(new_model, PlainFunctionalSubclass)
|
||||
|
||||
# TODO
|
||||
# class FunctionalSubclassWCustomInit(keras_core.Model):
|
||||
# def __init__(self, num_units=1, **kwargs):
|
||||
# inputs = keras_core.Input((2,), batch_size=3)
|
||||
# outputs = keras_core.layers.Dense(num_units)(inputs)
|
||||
# super().__init__(inputs, outputs)
|
||||
class FunctionalSubclassWCustomInit(keras_core.Model):
|
||||
def __init__(self, num_units=2):
|
||||
inputs = keras_core.Input((2,), batch_size=3)
|
||||
outputs = keras_core.layers.Dense(num_units)(inputs)
|
||||
super().__init__(inputs, outputs)
|
||||
self.num_units = num_units
|
||||
|
||||
# model = FunctionalSubclassWCustomInit(num_units=2)
|
||||
# x = ops.random.normal((2, 2))
|
||||
# y1 = model(x)
|
||||
# _, new_model, _ = self.roundtrip(
|
||||
# model,
|
||||
# custom_objects={
|
||||
# "FunctionalSubclassWCustomInit": FunctionalSubclassWCustomInit
|
||||
# },
|
||||
# )
|
||||
# new_model.set_weights(model.get_weights())
|
||||
# y2 = new_model(x)
|
||||
# self.assertAllClose(y1, y2, atol=1e-5)
|
||||
# self.assertIsInstance(new_model, FunctionalSubclassWCustomInit)
|
||||
def get_config(self):
|
||||
return {"num_units": self.num_units}
|
||||
|
||||
model = FunctionalSubclassWCustomInit(num_units=3)
|
||||
x = ops.random.normal((2, 2))
|
||||
y1 = model(x)
|
||||
_, new_model, _ = self.roundtrip(
|
||||
model,
|
||||
custom_objects={
|
||||
"FunctionalSubclassWCustomInit": FunctionalSubclassWCustomInit
|
||||
},
|
||||
)
|
||||
new_model.set_weights(model.get_weights())
|
||||
y2 = new_model(x)
|
||||
self.assertAllClose(y1, y2, atol=1e-5)
|
||||
self.assertIsInstance(new_model, FunctionalSubclassWCustomInit)
|
||||
|
||||
def test_shared_object(self):
|
||||
class MyLayer(keras_core.layers.Layer):
|
||||
|
Loading…
Reference in New Issue
Block a user