Add Sequential model.
This commit is contained in:
parent
c1ca10160f
commit
4d04065907
@ -1,4 +1,6 @@
|
|||||||
from keras_core.layers.core.dense import Dense
|
from keras_core.layers.core.dense import Dense
|
||||||
|
from keras_core.layers.core.input_layer import Input
|
||||||
|
from keras_core.layers.core.input_layer import InputLayer
|
||||||
from keras_core.layers.layer import Layer
|
from keras_core.layers.layer import Layer
|
||||||
|
|
||||||
# from keras_core.layers.regularization.dropout import Dropout
|
# from keras_core.layers.regularization.dropout import Dropout
|
||||||
|
@ -5,13 +5,33 @@ from keras_core.operations.node import Node
|
|||||||
|
|
||||||
class InputLayer(Layer):
|
class InputLayer(Layer):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, shape, batch_size=None, dtype=None, input_tensor=None, name=None
|
self,
|
||||||
|
shape=None,
|
||||||
|
batch_size=None,
|
||||||
|
dtype=None,
|
||||||
|
batch_shape=None,
|
||||||
|
input_tensor=None,
|
||||||
|
name=None,
|
||||||
):
|
):
|
||||||
# TODO: support for sparse, ragged.
|
# TODO: support for sparse, ragged.
|
||||||
super().__init__(name=name)
|
super().__init__(name=name)
|
||||||
self.shape = backend.standardize_shape(shape)
|
if shape is not None and batch_shape is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot pass both `shape` and `batch_shape` at the same time."
|
||||||
|
)
|
||||||
|
if batch_size is not None and batch_shape is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot pass both `batch_size` and `batch_shape` at the same time."
|
||||||
|
)
|
||||||
|
if shape is None and batch_shape is None:
|
||||||
|
raise ValueError("You must pass a `shape` argument.")
|
||||||
|
|
||||||
|
if shape:
|
||||||
|
shape = backend.standardize_shape(shape)
|
||||||
|
batch_shape = (batch_size,) + shape
|
||||||
|
self.batch_shape = batch_shape
|
||||||
self._dtype = backend.standardize_dtype(dtype)
|
self._dtype = backend.standardize_dtype(dtype)
|
||||||
self.batch_size = batch_size
|
|
||||||
if input_tensor is not None:
|
if input_tensor is not None:
|
||||||
if not isinstance(input_tensor, backend.KerasTensor):
|
if not isinstance(input_tensor, backend.KerasTensor):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -20,7 +40,7 @@ class InputLayer(Layer):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
input_tensor = backend.KerasTensor(
|
input_tensor = backend.KerasTensor(
|
||||||
shape=(batch_size,) + shape, dtype=dtype, name=name
|
shape=batch_shape, dtype=dtype, name=name
|
||||||
)
|
)
|
||||||
self._input_tensor = input_tensor
|
self._input_tensor = input_tensor
|
||||||
Node(operation=self, call_args=(), call_kwargs={}, outputs=input_tensor)
|
Node(operation=self, call_args=(), call_kwargs={}, outputs=input_tensor)
|
||||||
@ -35,15 +55,18 @@ class InputLayer(Layer):
|
|||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return {
|
return {
|
||||||
"shape": self.shape,
|
"batch_shape": self.batch_shape,
|
||||||
"batch_size": self.batch_size,
|
|
||||||
"dtype": self.dtype,
|
"dtype": self.dtype,
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def Input(shape=None, batch_size=None, dtype=None, name=None):
|
def Input(shape=None, batch_size=None, dtype=None, batch_shape=None, name=None):
|
||||||
layer = InputLayer(
|
layer = InputLayer(
|
||||||
shape=shape, batch_size=batch_size, dtype=dtype, name=name
|
shape=shape,
|
||||||
|
batch_size=batch_size,
|
||||||
|
dtype=dtype,
|
||||||
|
batch_shape=batch_shape,
|
||||||
|
name=name,
|
||||||
)
|
)
|
||||||
return layer.output
|
return layer.output
|
||||||
|
@ -55,6 +55,9 @@ class Layer(Operation):
|
|||||||
self._non_trainable_variables = []
|
self._non_trainable_variables = []
|
||||||
self._supports_masking = not utils.is_default(self.compute_mask)
|
self._supports_masking = not utils.is_default(self.compute_mask)
|
||||||
self._build_shapes_dict = None
|
self._build_shapes_dict = None
|
||||||
|
self._call_signature_parameters = [
|
||||||
|
p.name for p in inspect.signature(self.call).parameters.values()
|
||||||
|
]
|
||||||
|
|
||||||
self._tracker = Tracker(
|
self._tracker = Tracker(
|
||||||
{
|
{
|
||||||
@ -503,9 +506,10 @@ class Layer(Operation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _call_has_training_arg(self):
|
def _call_has_training_arg(self):
|
||||||
return "training" in [
|
return "training" in self._call_signature_parameters
|
||||||
p.name for p in inspect.signature(self.call).parameters.values()
|
|
||||||
]
|
def _call_has_mask_arg(self):
|
||||||
|
return "mask" in self._call_signature_parameters
|
||||||
|
|
||||||
def _get_call_context(self):
|
def _get_call_context(self):
|
||||||
"""Returns currently active `CallContext`."""
|
"""Returns currently active `CallContext`."""
|
||||||
|
@ -60,7 +60,7 @@ class Functional(Function, Model):
|
|||||||
layers.append(operation)
|
layers.append(operation)
|
||||||
return layers
|
return layers
|
||||||
|
|
||||||
def call(self, inputs, training=False, mask=None):
|
def call(self, inputs, training=None, mask=None):
|
||||||
# Add support for traning, masking
|
# Add support for traning, masking
|
||||||
inputs = self._standardize_inputs(inputs)
|
inputs = self._standardize_inputs(inputs)
|
||||||
if mask is None:
|
if mask is None:
|
||||||
@ -73,7 +73,8 @@ class Functional(Function, Model):
|
|||||||
inputs, operation_fn=lambda op: operation_fn(op, training=training)
|
inputs, operation_fn=lambda op: operation_fn(op, training=training)
|
||||||
)
|
)
|
||||||
|
|
||||||
def compute_output_spec(self, inputs, training=False, mask=None):
|
def compute_output_spec(self, inputs, training=None, mask=None):
|
||||||
|
# From Function
|
||||||
return super().compute_output_spec(inputs)
|
return super().compute_output_spec(inputs)
|
||||||
|
|
||||||
def _assert_input_compatibility(self, *args):
|
def _assert_input_compatibility(self, *args):
|
||||||
|
@ -1,11 +1,145 @@
|
|||||||
|
from tensorflow import nest
|
||||||
|
|
||||||
|
from keras_core import backend
|
||||||
from keras_core.api_export import keras_core_export
|
from keras_core.api_export import keras_core_export
|
||||||
|
from keras_core.layers.core.input_layer import InputLayer
|
||||||
|
from keras_core.models.functional import Functional
|
||||||
from keras_core.models.model import Model
|
from keras_core.models.model import Model
|
||||||
|
from keras_core.utils import tracking
|
||||||
|
|
||||||
|
|
||||||
@keras_core_export(["keras_core.Sequential", "keras_core.models.Sequential"])
|
@keras_core_export(["keras_core.Sequential", "keras_core.models.Sequential"])
|
||||||
class Sequential(Model):
|
class Sequential(Model):
|
||||||
def __init__(self, layers, trainable=True, name=None):
|
@tracking.no_automatic_dependency_tracking
|
||||||
pass
|
def __init__(self, layers=None, trainable=True, name=None):
|
||||||
|
super().__init__(trainable=trainable, name=name)
|
||||||
|
self._functional = None
|
||||||
|
self._layers = []
|
||||||
|
if layers:
|
||||||
|
for layer in layers:
|
||||||
|
self.add(layer)
|
||||||
|
|
||||||
def call(self, inputs):
|
def add(self, layer):
|
||||||
pass
|
# If we are passed a Keras tensor created by keras.Input(), we
|
||||||
|
# extract the input layer from its keras history and use that.
|
||||||
|
if hasattr(layer, "_keras_history"):
|
||||||
|
origin_layer = layer._keras_history[0]
|
||||||
|
if isinstance(origin_layer, InputLayer):
|
||||||
|
layer = origin_layer
|
||||||
|
if not self._is_layer_name_unique(layer):
|
||||||
|
raise ValueError(
|
||||||
|
"All layers added to a Sequential model "
|
||||||
|
f"should have unique names. Name '{layer.name}' is already "
|
||||||
|
"the name of a layer in this model. Update the `name` argument "
|
||||||
|
"to pass a unique name."
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
isinstance(layer, InputLayer)
|
||||||
|
and self._layers
|
||||||
|
and isinstance(self._layers[0], InputLayer)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Sequential model '{self.name}' has already been configured to "
|
||||||
|
f"use input shape {self._layers[0].batch_input_shape}. You cannot add "
|
||||||
|
f"a different Input layer to it."
|
||||||
|
)
|
||||||
|
|
||||||
|
self._layers.append(layer)
|
||||||
|
self.built = False
|
||||||
|
self._functional = None
|
||||||
|
|
||||||
|
def pop(self):
|
||||||
|
layer = self._layers.pop()
|
||||||
|
self.built = False
|
||||||
|
self._functional = None
|
||||||
|
return layer
|
||||||
|
|
||||||
|
def build(self, input_shape=None):
|
||||||
|
if not self._layers:
|
||||||
|
raise ValueError(
|
||||||
|
f"Sequential model {self.name} cannot be built because it has no layers. "
|
||||||
|
"Call `model.add(layer)`."
|
||||||
|
)
|
||||||
|
if isinstance(self._layers[0], InputLayer):
|
||||||
|
if self._layers[0].batch_shape != input_shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"Sequential model '{self.name}' has already been configured to "
|
||||||
|
f"use input shape {self._layers[0].batch_shape}. You cannot build it "
|
||||||
|
f"with input_shape {input_shape}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._layers = [InputLayer(batch_shape=input_shape)] + self._layers
|
||||||
|
|
||||||
|
# Build functional model
|
||||||
|
inputs = self._layers[0].output
|
||||||
|
x = inputs
|
||||||
|
for layer in self._layers[1:]:
|
||||||
|
x = layer(x)
|
||||||
|
outputs = x
|
||||||
|
self._functional = Functional(inputs=inputs, outputs=outputs)
|
||||||
|
self.built = True
|
||||||
|
|
||||||
|
def call(self, inputs, training=None, mask=None):
|
||||||
|
if self._functional:
|
||||||
|
return self._functional(inputs, training=training, mask=mask)
|
||||||
|
# Else, check if we can build a Functional model
|
||||||
|
if isinstance(inputs, backend.KerasTensor) or backend.is_tensor(inputs):
|
||||||
|
self.build(inputs.shape)
|
||||||
|
return self._functional(inputs, training=training, mask=mask)
|
||||||
|
|
||||||
|
# No functional model can be built -- Just apply the layer sequence.
|
||||||
|
# This typically happens if `inputs` is a nested struct.
|
||||||
|
for layer in self.layers:
|
||||||
|
# During each iteration, `inputs` are the inputs to `layer`, and
|
||||||
|
# `outputs` are the outputs of `layer` applied to `inputs`. At the
|
||||||
|
# end of each iteration `inputs` is set to `outputs` to prepare for
|
||||||
|
# the next layer.
|
||||||
|
kwargs = {}
|
||||||
|
if layer._call_has_mask_arg:
|
||||||
|
kwargs["mask"] = mask
|
||||||
|
if layer._call_has_training_arg:
|
||||||
|
kwargs["training"] = training
|
||||||
|
outputs = layer(inputs, **kwargs)
|
||||||
|
inputs = outputs
|
||||||
|
|
||||||
|
def _get_mask_from_keras_tensor(kt):
|
||||||
|
return getattr(kt, "_keras_mask", None)
|
||||||
|
|
||||||
|
mask = nest.map_structure(_get_mask_from_keras_tensor, outputs)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layers(self):
|
||||||
|
# Historically, `sequential.layers` only returns layers that were added
|
||||||
|
# via `add`, and omits the auto-generated `InputLayer` that comes at the
|
||||||
|
# bottom of the stack.
|
||||||
|
layers = self._layers
|
||||||
|
if layers and isinstance(layers[0], InputLayer):
|
||||||
|
return layers[1:]
|
||||||
|
return layers[:]
|
||||||
|
|
||||||
|
def compute_output_spec(self, inputs, training=None, mask=None):
|
||||||
|
if self._functional:
|
||||||
|
return self._functional.compute_output_spec(
|
||||||
|
inputs, training=training, mask=mask
|
||||||
|
)
|
||||||
|
# Direct application
|
||||||
|
for layer in self.layers:
|
||||||
|
outputs = layer.compute_output_spec(
|
||||||
|
inputs, training=training
|
||||||
|
) # Ignore mask
|
||||||
|
inputs = outputs
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def _is_layer_name_unique(self, layer):
|
||||||
|
for ref_layer in self._layers:
|
||||||
|
if layer.name == ref_layer.name and ref_layer is not layer:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config):
|
||||||
|
raise NotImplementedError
|
||||||
|
98
keras_core/models/sequential_test.py
Normal file
98
keras_core/models/sequential_test.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from keras_core import backend
|
||||||
|
from keras_core import layers
|
||||||
|
from keras_core import testing
|
||||||
|
from keras_core.layers.core.input_layer import Input
|
||||||
|
from keras_core.models.functional import Functional
|
||||||
|
from keras_core.models.sequential import Sequential
|
||||||
|
|
||||||
|
|
||||||
|
class SequentialTest(testing.TestCase):
|
||||||
|
def test_basic_flow_with_input(self):
|
||||||
|
model = Sequential(name="seq")
|
||||||
|
model.add(Input(shape=(2,), batch_size=3))
|
||||||
|
model.add(layers.Dense(4))
|
||||||
|
model.add(layers.Dense(5))
|
||||||
|
|
||||||
|
self.assertEqual(len(model.layers), 2)
|
||||||
|
|
||||||
|
# Test eager call
|
||||||
|
x = np.random.random((3, 2))
|
||||||
|
y = model(x)
|
||||||
|
self.assertTrue(model.built)
|
||||||
|
self.assertEqual(type(model._functional), Functional)
|
||||||
|
self.assertEqual(y.shape, (3, 5))
|
||||||
|
|
||||||
|
# Test symbolic call
|
||||||
|
x = backend.KerasTensor((3, 2))
|
||||||
|
y = model(x)
|
||||||
|
self.assertEqual(y.shape, (3, 5))
|
||||||
|
|
||||||
|
# Test `layers` constructor arg
|
||||||
|
model = Sequential(
|
||||||
|
layers=[
|
||||||
|
Input(shape=(2,), batch_size=3),
|
||||||
|
layers.Dense(4),
|
||||||
|
layers.Dense(5),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
x = np.random.random((3, 2))
|
||||||
|
y = model(x)
|
||||||
|
self.assertEqual(y.shape, (3, 5))
|
||||||
|
|
||||||
|
# Test pop
|
||||||
|
model.pop()
|
||||||
|
self.assertFalse(model.built)
|
||||||
|
self.assertEqual(model._functional, None)
|
||||||
|
x = np.random.random((3, 2))
|
||||||
|
y = model(x)
|
||||||
|
self.assertTrue(model.built)
|
||||||
|
self.assertEqual(type(model._functional), Functional)
|
||||||
|
self.assertEqual(y.shape, (3, 4))
|
||||||
|
|
||||||
|
def test_basic_flow_deferred(self):
|
||||||
|
model = Sequential(name="seq")
|
||||||
|
model.add(layers.Dense(4))
|
||||||
|
model.add(layers.Dense(5))
|
||||||
|
|
||||||
|
self.assertEqual(len(model.layers), 2)
|
||||||
|
|
||||||
|
# Test eager call
|
||||||
|
x = np.random.random((3, 2))
|
||||||
|
y = model(x)
|
||||||
|
self.assertTrue(model.built)
|
||||||
|
self.assertEqual(type(model._functional), Functional)
|
||||||
|
self.assertEqual(y.shape, (3, 5))
|
||||||
|
|
||||||
|
# Test symbolic call
|
||||||
|
x = backend.KerasTensor((3, 2))
|
||||||
|
y = model(x)
|
||||||
|
self.assertEqual(y.shape, (3, 5))
|
||||||
|
|
||||||
|
# Test `layers` constructor arg
|
||||||
|
model = Sequential(
|
||||||
|
layers=[
|
||||||
|
layers.Dense(4),
|
||||||
|
layers.Dense(5),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
x = np.random.random((3, 2))
|
||||||
|
y = model(x)
|
||||||
|
self.assertEqual(y.shape, (3, 5))
|
||||||
|
|
||||||
|
# Test pop
|
||||||
|
model.pop()
|
||||||
|
self.assertFalse(model.built)
|
||||||
|
self.assertEqual(model._functional, None)
|
||||||
|
x = np.random.random((3, 2))
|
||||||
|
y = model(x)
|
||||||
|
self.assertTrue(model.built)
|
||||||
|
self.assertEqual(type(model._functional), Functional)
|
||||||
|
self.assertEqual(y.shape, (3, 4))
|
||||||
|
|
||||||
|
def test_dict_inputs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_serialization(self):
|
||||||
|
pass
|
Loading…
Reference in New Issue
Block a user