Add autoconfig to Operation class.

This commit is contained in:
parent a6b690b198
commit 17c51418fa
@@ -4,7 +4,7 @@ from keras_core import testing


class ActivationTest(testing.TestCase):
    def test_dense_basics(self):
    def test_activation_basics(self):
        self.run_layer_test(
            layers.Activation,
            init_kwargs={

@@ -24,15 +24,16 @@ class Functional(Function, Model):
    Symbolic add_loss
    """

    def __new__(cls, *args, **kwargs):
        # Skip Model.__new__.
        return Function.__new__(cls)

    @tracking.no_automatic_dependency_tracking
    def __init__(self, inputs, outputs, name=None, **kwargs):
        # This is used by the Model class, since we have some logic to swap
        # the class in the __new__ method, which will lead to __init__ getting
        # invoked twice. Use skip_init to skip one of the invocations of
        # __init__ and avoid any side effects.
        skip_init = kwargs.pop("skip_init", False)
        if skip_init:
            return
        if isinstance(inputs, dict):
            for k, v in inputs.items():
                if not isinstance(v, backend.KerasTensor):

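For context on the `skip_init` flag above: when `__new__` returns an already-constructed instance of a subclass, Python calls that instance's `__init__` a second time with the original arguments. A minimal standalone sketch of the double invocation and the short-circuit (not keras_core code; `Base` and `Sub` are hypothetical):

class Base:
    def __new__(cls, *args, **kwargs):
        if cls is Base:
            # Swap to the subclass; constructing it here already runs
            # Sub.__init__ once (skipped via the flag).
            return Sub(*args, **kwargs, skip_init=True)
        return super().__new__(cls)


class Sub(Base):
    def __init__(self, value, skip_init=False):
        if skip_init:
            return  # first invocation: do nothing
        self.value = value  # second invocation: real initialization


obj = Base(42)  # Python re-invokes Sub.__init__(obj, 42) after __new__ returns
print(type(obj).__name__, obj.value)  # Sub 42
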
@@ -1,6 +1,9 @@
import inspect

from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import python_utils
from keras_core.utils import summary_utils

if backend.backend() == "tensorflow":

@@ -32,11 +35,11 @@ class Model(Trainer, Layer):

    def __new__(cls, *args, **kwargs):
        # Signature detection
        if functional_init_arguments(args, kwargs) and cls == Model:
        if functional_init_arguments(args, kwargs):
            # Functional model
            from keras_core.models import functional

            return functional.Functional(*args, **kwargs, skip_init=True)
            return functional.Functional(*args, **kwargs)
        return Layer.__new__(cls)

    def __init__(self, trainable=True, name=None, dtype=None):

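A short sketch of the signature detection that decides when a plain `Model(...)` call should be swapped for a `Functional`: the predicate treats a call as functional when it carries both inputs and outputs. This mirrors the two clauses of `functional_init_arguments` visible further down in this diff; the positional-only clause here is an assumption, since that line is not shown:

def looks_functional(args, kwargs):
    # Assumed first clause: Model(inputs, outputs) passed positionally.
    return (
        len(args) == 2
        or (len(args) == 1 and "outputs" in kwargs)
        or ("inputs" in kwargs and "outputs" in kwargs)
    )


print(looks_functional((), {"inputs": "x", "outputs": "y"}))  # True  -> Functional path
print(looks_functional((), {}))                               # False -> subclassed Model path
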
@@ -152,6 +155,32 @@ class Model(Trainer, Layer):
    def export(self, filepath):
        raise NotImplementedError

    @python_utils.default
    def get_config(self):
        # Prepare base arguments
        config = {
            "name": self.name,
            "trainable": self.trainable,
        }
        # Check whether the class has a constructor compatible with a
        # Functional model, or whether it has a custom constructor.
        if functional_like_constructor(self.__class__):
            # Only return a Functional config if the constructor is the same
            # as that of a Functional model. This excludes subclassed
            # Functional models with a custom __init__.
            config = {**config, **get_functional_config(self)}
        else:
            # Try to autogenerate the config
            xtra_args = set(config.keys())
            if getattr(self, "_auto_get_config", False):
                config.update(self._auto_config.config)
            # Remove args not explicitly supported
            argspec = inspect.getfullargspec(self.__init__)
            if argspec.varkw != "kwargs":
                for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
                    config.pop(key, None)
        return config


def functional_init_arguments(args, kwargs):
    return (
@@ -159,3 +188,11 @@ def functional_init_arguments(args, kwargs):
        or (len(args) == 1 and "outputs" in kwargs)
        or ("inputs" in kwargs and "outputs" in kwargs)
    )


def functional_like_constructor(cls):
    raise NotImplementedError


def get_functional_config(model):
    raise NotImplementedError

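The pruning step in `Model.get_config` above (and again in `Operation.get_config` below) can be read in isolation: the base keys are merged with the auto-generated config, then dropped again unless the subclass's `__init__` explicitly accepts them or declares a `**kwargs` catch-all. A minimal sketch under those assumptions (`merge_and_prune` and `NoKwargs` are illustrative names, not keras_core API):

import inspect


def merge_and_prune(obj, base, auto):
    xtra_args = set(base.keys())          # keys added before the auto-config
    config = {**base, **auto}
    argspec = inspect.getfullargspec(obj.__init__)
    if argspec.varkw != "kwargs":         # no literal **kwargs catch-all
        for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
            config.pop(key, None)
    return config


class NoKwargs:
    def __init__(self, alpha):            # accepts neither "name" nor "trainable"
        self.alpha = alpha


print(merge_and_prune(NoKwargs(0.1), {"name": "m", "trainable": True}, {"alpha": 0.1}))
# {'alpha': 0.1}

Note that the check is against the literal parameter name "kwargs", so a constructor that names its catch-all differently (e.g. `**kw`) still has the base keys pruned.
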
@@ -1,12 +1,26 @@
import inspect
import textwrap

from tensorflow import nest

from keras_core import backend
from keras_core.backend.keras_tensor import any_symbolic_tensors
from keras_core.operations.node import Node
from keras_core.saving import serialization_lib
from keras_core.utils import python_utils
from keras_core.utils.naming import auto_name


class Operation:
    def __init__(self, name=None):
        self.name = name or auto_name(self.__class__.__name__)
        if name is None:
            name = auto_name(self.__class__.__name__)
        if not isinstance(name, str):
            raise ValueError(
                "Argument `name` should be a string. "
                f"Received instead: name={name} (of type {type(name)})"
            )
        self.name = name
        self._inbound_nodes = []
        self._outbound_nodes = []

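The rewritten constructor above replaces the one-liner `name or auto_name(...)` with an explicit default-then-validate sequence, so a truthy non-string no longer slips through silently. A standalone illustration of the difference (`gen_name` is a stand-in for keras_core's `auto_name`):

def gen_name(cls_name):
    return cls_name.lower()  # placeholder; auto_name() also uniquifies names


def resolve_name(name, cls_name="Operation"):
    if name is None:
        name = gen_name(cls_name)
    if not isinstance(name, str):
        raise ValueError(
            "Argument `name` should be a string. "
            f"Received instead: name={name} (of type {type(name)})"
        )
    return name


print(resolve_name(None))    # "operation"
print(resolve_name("relu"))  # "relu"
# resolve_name(123) now raises a ValueError; `123 or auto_name(...)` would
# have silently used the integer 123 as the name.
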
@@ -44,12 +58,123 @@ class Operation:
                f"Error encountered: {e}"
            )

    def __new__(cls, *args, **kwargs):
        """We override __new__ to save serializable constructor arguments.

        These arguments are used to auto-generate an object serialization
        config, which enables user-created subclasses to be serializable
        out of the box in most cases without forcing the user to manually
        implement `get_config()`.
        """
        # Generate a config to be returned by default by `get_config()`.
        arg_names = inspect.getfullargspec(cls.__init__).args
        kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
        instance = super(Operation, cls).__new__(cls)
        # For safety, we only rely on auto-configs for a small set of
        # serializable types.
        supported_types = (str, int, float, bool, type(None))
        try:
            flat_arg_values = nest.flatten(kwargs)
            auto_config = True
            for value in flat_arg_values:
                if not isinstance(value, supported_types):
                    auto_config = False
                    break
        except TypeError:
            auto_config = False
        try:
            if auto_config:
                instance._auto_config = serialization_lib.SerializableDict(
                    **kwargs
                )
            else:
                instance._auto_config = None
        except RecursionError:
            # Setting an instance attribute in __new__ has the potential
            # to trigger an infinite recursion if a subclass overrides
            # setattr in an unsafe way.
            pass
        return instance

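The first two lines of `__new__` do the heavy lifting for the auto-config: positional arguments are matched against the parameter names of the subclass's `__init__` (skipping `self`) and folded into `kwargs`. A standalone illustration (the `Example` class is hypothetical); the real code additionally flattens nested structures with `nest.flatten` and bails out if any value falls outside the supported scalar types:

import inspect


class Example:
    def __init__(self, alpha, mode="foo", name=None):
        pass


args = (0.2, "bar")
kwargs = {"name": "my_op"}
arg_names = inspect.getfullargspec(Example.__init__).args
print(arg_names)  # ['self', 'alpha', 'mode', 'name']

kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
print(kwargs)     # {'name': 'my_op', 'alpha': 0.2, 'mode': 'bar'}

supported_types = (str, int, float, bool, type(None))
print(all(isinstance(v, supported_types) for v in kwargs.values()))  # True
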
    @python_utils.default
    def get_config(self):
        return {"name": self.name}
        """Returns the config of the object.

        An object config is a Python dictionary (serializable)
        containing the information needed to re-instantiate it.

        Returns:
            Python dictionary.
        """
        config = {
            "name": self.name,
        }

        if not python_utils.is_default(self.get_config):
            # In this case the subclass implements get_config()
            return config

        # In this case the subclass doesn't implement get_config():
        # Let's see if we can autogenerate it.
        if getattr(self, "_auto_config", None) is not None:
            xtra_args = set(config.keys())
            config.update(self._auto_config.config)
            # Remove args not explicitly supported
            argspec = inspect.getfullargspec(self.__init__)
            if argspec.varkw != "kwargs":
                for key in xtra_args - xtra_args.intersection(argspec.args[1:]):
                    config.pop(key, None)
            return config
        else:
            raise NotImplementedError(
                textwrap.dedent(
                    f"""
                Object {self.__class__.__name__} was created by passing
                non-serializable argument values in `__init__()`,
                and therefore the object must override `get_config()` in
                order to be serializable. Please implement `get_config()`.

                Example:

                class CustomLayer(keras.layers.Layer):
                    def __init__(self, arg1, arg2, **kwargs):
                        super().__init__(**kwargs)
                        self.arg1 = arg1
                        self.arg2 = arg2

                    def get_config(self):
                        config = super().get_config()
                        config.update({{
                            "arg1": self.arg1,
                            "arg2": self.arg2,
                        }})
                        return config"""
                )
            )

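Taken together with `__new__`, this gives subclasses serialization for free whenever their constructor arguments are plain scalars. A hedged usage sketch (the `Scale` operation is hypothetical; the module path follows the imports elsewhere in this diff):

from keras_core.operations import operation


class Scale(operation.Operation):
    def __init__(self, factor=2.0):
        super().__init__()
        self.factor = factor

    def call(self, x):
        return x * self.factor


op = Scale(factor=3.0)
config = op.get_config()        # {'factor': 3.0}; the base "name" entry is pruned
                                # because __init__ accepts neither it nor **kwargs
revived = Scale.from_config(config)

Passing anything non-serializable (say, a callable) leaves `_auto_config` unset, which lands in the `NotImplementedError` branch above and asks the user to implement `get_config()` by hand.
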
    @classmethod
    def from_config(cls, config):
        """Creates a layer from its config.

        This method is the reverse of `get_config`,
        capable of instantiating the same layer from the config
        dictionary. It does not handle layer connectivity
        (handled by Network), nor weights (handled by `set_weights`).

        Args:
            config: A Python dictionary, typically the
                output of get_config.

        Returns:
            A layer instance.
        """
        try:
            return cls(**config)
        except Exception as e:
            raise TypeError(
                f"Error when deserializing class '{cls.__name__}' using "
                f"config={config}.\n\nException encountered: {e}"
            )

    def __repr__(self):
        return f"<Operation name={self.name}>"

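The `try/except` wrapper in `from_config` exists so that a stale or hand-edited config fails with an error naming the class and the config, rather than a bare argument error. A minimal standalone sketch of the same pattern (the `Thing` class is illustrative):

class Thing:
    def __init__(self, alpha):
        self.alpha = alpha

    @classmethod
    def from_config(cls, config):
        try:
            return cls(**config)
        except Exception as e:
            raise TypeError(
                f"Error when deserializing class '{cls.__name__}' using "
                f"config={config}.\n\nException encountered: {e}"
            )


print(Thing.from_config({"alpha": 1.0}).alpha)  # 1.0
# Thing.from_config({"unknown": 1.0}) raises a TypeError that names Thing
# and echoes the offending config.
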
@@ -26,6 +26,21 @@ class OpWithMultipleOutputs(operation.Operation):
        )


class OpWithCustomConstructor(operation.Operation):
    def __init__(self, alpha, mode="foo"):
        super().__init__()
        self.alpha = alpha
        self.mode = mode

    def call(self, x):
        if self.mode == "foo":
            return x
        return self.alpha * x

    def compute_output_spec(self, x):
        return keras_tensor.KerasTensor(x.shape, x.dtype)


class OperationTest(testing.TestCase):
    def test_symbolic_call(self):
        x = keras_tensor.KerasTensor(shape=(2, 3), name="x")

@@ -118,6 +133,13 @@ class OperationTest(testing.TestCase):
        op = OpWithMultipleOutputs.from_config(config)
        self.assertEqual(op.name, "test_op")

    def test_autoconfig(self):
        op = OpWithCustomConstructor(alpha=0.2, mode="bar")
        config = op.get_config()
        self.assertEqual(config, {"alpha": 0.2, "mode": "bar"})
        revived = OpWithCustomConstructor.from_config(config)
        self.assertEqual(revived.get_config(), config)

    def test_input_conversion(self):
        x = np.ones((2,))
        y = np.ones((2,))

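The expected config in `test_autoconfig` follows directly from the pruning rule: `OpWithCustomConstructor.__init__` accepts neither `name` nor `**kwargs`, so the base `"name"` entry is dropped, while both constructor arguments captured by `Operation.__new__` survive. A quick check of the signature the rule relies on (run inside the test module, where the class is defined):

import inspect

argspec = inspect.getfullargspec(OpWithCustomConstructor.__init__)
print(argspec.args[1:])  # ['alpha', 'mode']
print(argspec.varkw)     # None, so extra base keys like "name" get pruned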