Add autoconfig to Operation class.

This commit is contained in:
Francois Chollet 2023-04-23 11:05:04 -07:00
parent a6b690b198
commit 17c51418fa
5 changed files with 194 additions and 9 deletions

@ -4,7 +4,7 @@ from keras_core import testing
class ActivationTest(testing.TestCase):
def test_dense_basics(self):
def test_activation_basics(self):
self.run_layer_test(
layers.Activation,
init_kwargs={

@ -24,15 +24,16 @@ class Functional(Function, Model):
Symbolic add_loss
"""
def __new__(cls, *args, **kwargs):
# Skip Model.__new__.
return Function.__new__(cls)
@tracking.no_automatic_dependency_tracking
def __init__(self, inputs, outputs, name=None, **kwargs):
# This is used by the Model class, since we have some logic to swap the
# class in the __new__ method, which will lead to __init__ get invoked
# twice. Using the skip_init to skip one of the invocation of __init__
# to avoid any side effects
skip_init = kwargs.pop("skip_init", False)
if skip_init:
return
if isinstance(inputs, dict):
for k, v in inputs.items():
if not isinstance(v, backend.KerasTensor):

@ -1,6 +1,9 @@
import inspect
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import python_utils
from keras_core.utils import summary_utils
if backend.backend() == "tensorflow":
@ -32,11 +35,11 @@ class Model(Trainer, Layer):
def __new__(cls, *args, **kwargs):
# Signature detection
if functional_init_arguments(args, kwargs) and cls == Model:
if functional_init_arguments(args, kwargs):
# Functional model
from keras_core.models import functional
return functional.Functional(*args, **kwargs, skip_init=True)
return functional.Functional(*args, **kwargs)
return Layer.__new__(cls)
def __init__(self, trainable=True, name=None, dtype=None):
@ -152,6 +155,32 @@ class Model(Trainer, Layer):
def export(self, filepath):
raise NotImplementedError
@python_utils.default
def get_config(self):
# Prepare base arguments
config = {
"name": self.name,
"trainable": self.trainable,
}
# Check whether the class has a constructor compatible with a Functional
# model or if it has a custom constructor.
if functional_like_constructor(self.__class__):
# Only return a Functional config if the constructor is the same
# as that of a Functional model. This excludes subclassed Functional
# models with a custom __init__.
config = {**config, **get_functional_config(self)}
else:
# Try to autogenerate config
xtra_args = set(config.keys())
if getattr(self, "_auto_get_config", False):
config.update(self._auto_config.config)
# Remove args non explicitly supported
argspec = inspect.getfullargspec(self.__init__)
if argspec.varkw != "kwargs":
for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
config.pop(key, None)
return config
def functional_init_arguments(args, kwargs):
return (
@ -159,3 +188,11 @@ def functional_init_arguments(args, kwargs):
or (len(args) == 1 and "outputs" in kwargs)
or ("inputs" in kwargs and "outputs" in kwargs)
)
def functional_like_constructor(cls):
raise NotImplementedError
def get_functional_config(model):
raise NotImplementedError

@ -1,12 +1,26 @@
import inspect
import textwrap
from tensorflow import nest
from keras_core import backend
from keras_core.backend.keras_tensor import any_symbolic_tensors
from keras_core.operations.node import Node
from keras_core.saving import serialization_lib
from keras_core.utils import python_utils
from keras_core.utils.naming import auto_name
class Operation:
def __init__(self, name=None):
self.name = name or auto_name(self.__class__.__name__)
if name is None:
name = auto_name(self.__class__.__name__)
if not isinstance(name, str):
raise ValueError(
"Argument `name` should be a string. "
f"Received instead: name={name} (of type {type(name)})"
)
self.name = name
self._inbound_nodes = []
self._outbound_nodes = []
@ -44,12 +58,123 @@ class Operation:
f"Error encountered: {e}"
)
def __new__(cls, *args, **kwargs):
"""We override __new__ to saving serializable constructor arguments.
These arguments are used to auto-generate an object serialization config,
which enables user-created subclasses to be serializable out of the box
in most cases without forcing the user to manually implement
`get_config()`.
"""
# Generate a config to be returned by default by `get_config()`.
arg_names = inspect.getfullargspec(cls.__init__).args
kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
instance = super(Operation, cls).__new__(cls)
# For safety, we only rely on auto-configs for a small set of
# serializable types.
supported_types = (str, int, float, bool, type(None))
try:
flat_arg_values = nest.flatten(kwargs)
auto_config = True
for value in flat_arg_values:
if not isinstance(value, supported_types):
auto_config = False
break
except TypeError:
auto_config = False
try:
if auto_config:
instance._auto_config = serialization_lib.SerializableDict(
**kwargs
)
else:
instance._auto_config = None
except RecursionError:
# Setting an instance attribute in __new__ has the potential
# to trigger an infinite recursion if a subclass overrides
# setattr in an unsafe way.
pass
return instance
@python_utils.default
def get_config(self):
return {"name": self.name}
"""Returns the config of the object.
An object config is a Python dictionary (serializable)
containing the information needed to re-insstantiate it.
Returns:
Python dictionary.
"""
config = {
"name": self.name,
}
if not python_utils.is_default(self.get_config):
# In this case the subclass implements get_config()
return config
# In this case the subclass doesn't implement get_config():
# Let's see if we can autogenerate it.
if getattr(self, "_auto_config", None) is not None:
xtra_args = set(config.keys())
config.update(self._auto_config.config)
# Remove args non explicitly supported
argspec = inspect.getfullargspec(self.__init__)
if argspec.varkw != "kwargs":
for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
config.pop(key, None)
return config
else:
raise NotImplementedError(
textwrap.dedent(
f"""
Object {self.__class__.__name__} was created by passing
non-serializable argument values in `__init__()`,
and therefore the object must override `get_config()` in
order to be serializable. Please implement `get_config()`.
Example:
class CustomLayer(keras.layers.Layer):
def __init__(self, arg1, arg2, **kwargs):
super().__init__(**kwargs)
self.arg1 = arg1
self.arg2 = arg2
def get_config(self):
config = super().get_config()
config.update({
"arg1": self.arg1,
"arg2": self.arg2,
})
return config"""
)
)
@classmethod
def from_config(cls, config):
"""Creates a layer from its config.
This method is the reverse of `get_config`,
capable of instantiating the same layer from the config
dictionary. It does not handle layer connectivity
(handled by Network), nor weights (handled by `set_weights`).
Args:
config: A Python dictionary, typically the
output of get_config.
Returns:
A layer instance.
"""
try:
return cls(**config)
except Exception as e:
raise TypeError(
f"Error when deserializing class '{cls.__name__}' using "
f"config={config}.\n\nException encountered: {e}"
)
def __repr__(self):
return f"<Operation name={self.name}>"

@ -26,6 +26,21 @@ class OpWithMultipleOutputs(operation.Operation):
)
class OpWithCustomConstructor(operation.Operation):
def __init__(self, alpha, mode="foo"):
super().__init__()
self.alpha = alpha
self.mode = mode
def call(self, x):
if self.mode == "foo":
return x
return self.alpha * x
def compute_output_spec(self, x):
return keras_tensor.KerasTensor(x.shape, x.dtype)
class OperationTest(testing.TestCase):
def test_symbolic_call(self):
x = keras_tensor.KerasTensor(shape=(2, 3), name="x")
@ -118,6 +133,13 @@ class OperationTest(testing.TestCase):
op = OpWithMultipleOutputs.from_config(config)
self.assertEqual(op.name, "test_op")
def test_autoconfig(self):
op = OpWithCustomConstructor(alpha=0.2, mode="bar")
config = op.get_config()
self.assertEqual(config, {"alpha": 0.2, "mode": "bar"})
revived = OpWithCustomConstructor.from_config(config)
self.assertEqual(revived.get_config(), config)
def test_input_conversion(self):
x = np.ones((2,))
y = np.ones((2,))