Finally make functional subclassing work
This commit is contained in:
parent
295f9a5f5f
commit
7fc2f3320e
@ -11,7 +11,6 @@ from keras_core.models.model import Model
|
|||||||
from keras_core.operations.function import Function
|
from keras_core.operations.function import Function
|
||||||
from keras_core.operations.function import make_node_key
|
from keras_core.operations.function import make_node_key
|
||||||
from keras_core.saving import serialization_lib
|
from keras_core.saving import serialization_lib
|
||||||
from keras_core.utils import python_utils
|
|
||||||
from keras_core.utils import tracking
|
from keras_core.utils import tracking
|
||||||
|
|
||||||
|
|
||||||
@ -29,12 +28,7 @@ class Functional(Function, Model):
|
|||||||
Symbolic add_loss
|
Symbolic add_loss
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
|
||||||
# Skip Model.__new__.
|
|
||||||
return Function.__new__(cls)
|
|
||||||
|
|
||||||
@tracking.no_automatic_dependency_tracking
|
@tracking.no_automatic_dependency_tracking
|
||||||
@python_utils.default
|
|
||||||
def __init__(self, inputs, outputs, name=None, **kwargs):
|
def __init__(self, inputs, outputs, name=None, **kwargs):
|
||||||
if isinstance(inputs, dict):
|
if isinstance(inputs, dict):
|
||||||
for k, v in inputs.items():
|
for k, v in inputs.items():
|
||||||
@ -192,7 +186,7 @@ class Functional(Function, Model):
|
|||||||
# Subclassed networks are not serializable
|
# Subclassed networks are not serializable
|
||||||
# (unless serialization is implemented by
|
# (unless serialization is implemented by
|
||||||
# the author of the subclassed network).
|
# the author of the subclassed network).
|
||||||
return Model.get_config()
|
return Model.get_config(self)
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
@ -266,6 +260,55 @@ class Functional(Function, Model):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config, custom_objects=None):
|
def from_config(cls, config, custom_objects=None):
|
||||||
|
functional_config_keys = [
|
||||||
|
"name",
|
||||||
|
"layers",
|
||||||
|
"input_layers",
|
||||||
|
"output_layers",
|
||||||
|
]
|
||||||
|
is_functional_config = all(
|
||||||
|
key in config for key in functional_config_keys
|
||||||
|
)
|
||||||
|
argspec = inspect.getfullargspec(cls.__init__)
|
||||||
|
functional_init_args = inspect.getfullargspec(Functional.__init__).args[
|
||||||
|
1:
|
||||||
|
]
|
||||||
|
revivable_as_functional = (
|
||||||
|
cls in {Functional, Model}
|
||||||
|
or argspec.args[1:] == functional_init_args
|
||||||
|
or (argspec.varargs == "args" and argspec.varkw == "kwargs")
|
||||||
|
)
|
||||||
|
if is_functional_config and revivable_as_functional:
|
||||||
|
# Revive Functional model
|
||||||
|
# (but not Functional subclasses with a custom __init__)
|
||||||
|
return cls._from_config(config, custom_objects=custom_objects)
|
||||||
|
|
||||||
|
# Either the model has a custom __init__, or the config
|
||||||
|
# does not contain all the information necessary to
|
||||||
|
# revive a Functional model. This happens when the user creates
|
||||||
|
# subclassed models where `get_config()` is returning
|
||||||
|
# insufficient information to be considered a Functional model.
|
||||||
|
# In this case, we fall back to provide all config into the
|
||||||
|
# constructor of the class.
|
||||||
|
try:
|
||||||
|
return cls(**config)
|
||||||
|
except TypeError as e:
|
||||||
|
raise TypeError(
|
||||||
|
"Unable to revive model from config. When overriding "
|
||||||
|
"the `get_config()` method, make sure that the "
|
||||||
|
"returned config contains all items used as arguments "
|
||||||
|
f"in the constructor to {cls}, "
|
||||||
|
"which is the default behavior. "
|
||||||
|
"You can override this default behavior by defining a "
|
||||||
|
"`from_config(cls, config)` class method to specify "
|
||||||
|
"how to create an "
|
||||||
|
f"instance of {cls.__name__} from its config.\n\n"
|
||||||
|
f"Received config={config}\n\n"
|
||||||
|
f"Error encountered during deserialization: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_config(cls, config, custom_objects=None):
|
||||||
"""Instantiates a Model from its config (output of `get_config()`)."""
|
"""Instantiates a Model from its config (output of `get_config()`)."""
|
||||||
# Layer instances created during
|
# Layer instances created during
|
||||||
# the graph reconstruction process
|
# the graph reconstruction process
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from keras_core import backend
|
from keras_core import backend
|
||||||
from keras_core.api_export import keras_core_export
|
from keras_core.api_export import keras_core_export
|
||||||
from keras_core.layers.layer import Layer
|
from keras_core.layers.layer import Layer
|
||||||
from keras_core.utils import python_utils
|
|
||||||
from keras_core.utils import summary_utils
|
from keras_core.utils import summary_utils
|
||||||
|
|
||||||
if backend.backend() == "tensorflow":
|
if backend.backend() == "tensorflow":
|
||||||
@ -32,8 +31,8 @@ class Model(Trainer, Layer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
# Signature detection
|
# Signature detection for usage of `Model` as a `Functional`
|
||||||
if functional_init_arguments(args, kwargs):
|
if functional_init_arguments(args, kwargs) and cls == Model:
|
||||||
from keras_core.models import functional
|
from keras_core.models import functional
|
||||||
|
|
||||||
return functional.Functional(*args, **kwargs)
|
return functional.Functional(*args, **kwargs)
|
||||||
@ -43,9 +42,10 @@ class Model(Trainer, Layer):
|
|||||||
Trainer.__init__(self)
|
Trainer.__init__(self)
|
||||||
from keras_core.models import functional
|
from keras_core.models import functional
|
||||||
|
|
||||||
if isinstance(self, functional.Functional) and python_utils.is_default(
|
# Signature detection for usage of a `Model` subclass
|
||||||
self.__init__
|
# as a `Functional` subclass
|
||||||
):
|
if functional_init_arguments(args, kwargs):
|
||||||
|
inject_functional_model_class(self.__class__)
|
||||||
functional.Functional.__init__(self, *args, **kwargs)
|
functional.Functional.__init__(self, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
Layer.__init__(self, *args, **kwargs)
|
Layer.__init__(self, *args, **kwargs)
|
||||||
@ -166,3 +166,24 @@ def functional_init_arguments(args, kwargs):
|
|||||||
or (len(args) == 1 and "outputs" in kwargs)
|
or (len(args) == 1 and "outputs" in kwargs)
|
||||||
or ("inputs" in kwargs and "outputs" in kwargs)
|
or ("inputs" in kwargs and "outputs" in kwargs)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def inject_functional_model_class(cls):
|
||||||
|
"""Inject `Functional` into the hierarchy of this class if needed."""
|
||||||
|
from keras_core.models import functional
|
||||||
|
|
||||||
|
if cls == Model:
|
||||||
|
return functional.Functional
|
||||||
|
# In case there is any multiple inheritance, we stop injecting the
|
||||||
|
# class if keras model is not in its class hierarchy.
|
||||||
|
if cls == object:
|
||||||
|
return object
|
||||||
|
|
||||||
|
cls.__bases__ = tuple(
|
||||||
|
inject_functional_model_class(base) for base in cls.__bases__
|
||||||
|
)
|
||||||
|
# Trigger any `__new__` class swapping that needed to happen on `Functional`
|
||||||
|
# but did not because functional was not in the class hierarchy.
|
||||||
|
cls.__new__(cls)
|
||||||
|
|
||||||
|
return cls
|
||||||
|
@ -210,29 +210,31 @@ class SerializationLibTest(testing.TestCase):
|
|||||||
new_model.set_weights(model.get_weights())
|
new_model.set_weights(model.get_weights())
|
||||||
y2 = new_model(x)
|
y2 = new_model(x)
|
||||||
self.assertAllClose(y1, y2, atol=1e-5)
|
self.assertAllClose(y1, y2, atol=1e-5)
|
||||||
# TODO
|
self.assertIsInstance(new_model, PlainFunctionalSubclass)
|
||||||
# self.assertIsInstance(new_model, PlainFunctionalSubclass)
|
|
||||||
|
|
||||||
# TODO
|
class FunctionalSubclassWCustomInit(keras_core.Model):
|
||||||
# class FunctionalSubclassWCustomInit(keras_core.Model):
|
def __init__(self, num_units=2):
|
||||||
# def __init__(self, num_units=1, **kwargs):
|
inputs = keras_core.Input((2,), batch_size=3)
|
||||||
# inputs = keras_core.Input((2,), batch_size=3)
|
outputs = keras_core.layers.Dense(num_units)(inputs)
|
||||||
# outputs = keras_core.layers.Dense(num_units)(inputs)
|
super().__init__(inputs, outputs)
|
||||||
# super().__init__(inputs, outputs)
|
self.num_units = num_units
|
||||||
|
|
||||||
# model = FunctionalSubclassWCustomInit(num_units=2)
|
def get_config(self):
|
||||||
# x = ops.random.normal((2, 2))
|
return {"num_units": self.num_units}
|
||||||
# y1 = model(x)
|
|
||||||
# _, new_model, _ = self.roundtrip(
|
model = FunctionalSubclassWCustomInit(num_units=3)
|
||||||
# model,
|
x = ops.random.normal((2, 2))
|
||||||
# custom_objects={
|
y1 = model(x)
|
||||||
# "FunctionalSubclassWCustomInit": FunctionalSubclassWCustomInit
|
_, new_model, _ = self.roundtrip(
|
||||||
# },
|
model,
|
||||||
# )
|
custom_objects={
|
||||||
# new_model.set_weights(model.get_weights())
|
"FunctionalSubclassWCustomInit": FunctionalSubclassWCustomInit
|
||||||
# y2 = new_model(x)
|
},
|
||||||
# self.assertAllClose(y1, y2, atol=1e-5)
|
)
|
||||||
# self.assertIsInstance(new_model, FunctionalSubclassWCustomInit)
|
new_model.set_weights(model.get_weights())
|
||||||
|
y2 = new_model(x)
|
||||||
|
self.assertAllClose(y1, y2, atol=1e-5)
|
||||||
|
self.assertIsInstance(new_model, FunctionalSubclassWCustomInit)
|
||||||
|
|
||||||
def test_shared_object(self):
|
def test_shared_object(self):
|
||||||
class MyLayer(keras_core.layers.Layer):
|
class MyLayer(keras_core.layers.Layer):
|
||||||
|
Loading…
Reference in New Issue
Block a user