Add discretization layer wrapper.

This commit is contained in:
Francois Chollet 2023-05-08 14:42:13 -07:00
parent 3d3c29bc78
commit 9f3428568e
6 changed files with 282 additions and 51 deletions

@ -53,6 +53,7 @@ from keras_core.layers.pooling.max_pooling1d import MaxPooling1D
from keras_core.layers.pooling.max_pooling2d import MaxPooling2D
from keras_core.layers.pooling.max_pooling3d import MaxPooling3D
from keras_core.layers.preprocessing.center_crop import CenterCrop
from keras_core.layers.preprocessing.discretization import Discretization
from keras_core.layers.preprocessing.normalization import Normalization
from keras_core.layers.preprocessing.rescaling import Rescaling
from keras_core.layers.preprocessing.resizing import Resizing

@ -0,0 +1,185 @@
import numpy as np
import tensorflow as tf
from keras_core import backend
from keras_core.layers.layer import Layer
class Discretization(Layer):
"""A preprocessing layer which buckets continuous features by ranges.
This layer will place each element of its input data into one of several
contiguous ranges and output an integer index indicating which range each
element was placed in.
**Note:** This layer wraps `tf.keras.layers.Discretization`. It cannot
be used as part of the compiled computation graph of a model with
any backend other than TensorFlow.
It can however be used with any backend when running eagerly.
It can also always be used as part of an input preprocessing pipeline
with any backend (outside the model itself), which is how we recommend
to use this layer.
Input shape:
Any array of dimension 2 or higher.
Output shape:
Same as input shape.
Arguments:
bin_boundaries: A list of bin boundaries.
The leftmost and rightmost bins
will always extend to `-inf` and `inf`,
so `bin_boundaries=[0., 1., 2.]`
generates bins `(-inf, 0.)`, `[0., 1.)`, `[1., 2.)`,
and `[2., +inf)`.
If this option is set, `adapt()` should not be called.
num_bins: The integer number of bins to compute.
If this option is set,
`adapt()` should be called to learn the bin boundaries.
epsilon: Error tolerance, typically a small fraction
close to zero (e.g. 0.01). Higher values of epsilon increase
the quantile approximation, and hence result in more
unequal buckets, but could improve performance
and resource consumption.
output_mode: Specification for the output of the layer.
Values can be `"int"`, `"one_hot"`, `"multi_hot"`, or
`"count"` configuring the layer as follows:
- `"int"`: Return the discretized bin indices directly.
- `"one_hot"`: Encodes each individual element in the
input into an array the same size as `num_bins`,
containing a 1 at the input's bin
index. If the last dimension is size 1, will encode on that
dimension. If the last dimension is not size 1,
will append a new dimension for the encoded output.
- `"multi_hot"`: Encodes each sample in the input into a
single array the same size as `num_bins`,
containing a 1 for each bin index
index present in the sample.
Treats the last dimension as the sample
dimension, if input shape is `(..., sample_length)`,
output shape will be `(..., num_tokens)`.
- `"count"`: As `"multi_hot"`, but the int array contains
a count of the number of times the bin index appeared
in the sample.
Defaults to `"int"`.
sparse: Boolean. Only applicable to `"one_hot"`, `"multi_hot"`,
and `"count"` output modes. Only supported with TensorFlow
backend. If `True`, returns a `SparseTensor` instead of
a dense `Tensor`. Defaults to `False`.
Examples:
Bucketize float values based on provided buckets.
>>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])
>>> layer = Discretization(bin_boundaries=[0., 1., 2.])
>>> layer(input)
array([[0, 2, 3, 1],
[1, 3, 2, 1]])
Bucketize float values based on a number of buckets to compute.
>>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])
>>> layer = Discretization(num_bins=4, epsilon=0.01)
>>> layer.adapt(input)
>>> layer(input)
array([[0, 2, 3, 2],
[1, 3, 3, 1]])
"""
def __init__(
self,
bin_boundaries=None,
num_bins=None,
epsilon=0.01,
output_mode="int",
sparse=False,
name=None,
):
super().__init__(name=name)
if sparse and backend.backend() != "tensorflow":
raise ValueError()
self.layer = tf.keras.layers.Discretization(
bin_boundaries=bin_boundaries,
num_bins=num_bins,
epsilon=epsilon,
output_mode=output_mode,
sparse=sparse,
name=name,
)
self.bin_boundaries = (
bin_boundaries if bin_boundaries is not None else []
)
self.num_bins = num_bins
self.epsilon = epsilon
self.output_mode = output_mode
self.sparse = sparse
def build(self, input_shape):
self.layer.build(input_shape)
self.built = True
# We override this method solely to generate a docstring.
def adapt(self, data, batch_size=None, steps=None):
"""Computes bin boundaries from quantiles in a input dataset.
Calling `adapt()` on a `Discretization` layer is an alternative to
passing in a `bin_boundaries` argument during construction. A
`Discretization` layer should always be either adapted over a dataset or
passed `bin_boundaries`.
During `adapt()`, the layer will estimate the quantile boundaries of the
input dataset. The number of quantiles can be controlled via the
`num_bins` argument, and the error tolerance for quantile boundaries can
be controlled via the `epsilon` argument.
Arguments:
data: The data to train on. It can be passed either as a
`tf.data.Dataset`, or as a numpy array.
batch_size: Integer or `None`.
Number of samples per state update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` if your data is in the
form of datasets, generators, or `keras.utils.Sequence` instances
(since they generate batches).
steps: Integer or `None`.
Total number of steps (batches of samples)
When training with input tensors such as
TensorFlow data tensors, the default `None` is equal to
the number of samples in your dataset divided by
the batch size, or 1 if that cannot be determined. If x is a
`tf.data.Dataset`, and `steps` is `None`, the epoch will run until
the input dataset is exhausted. When passing an infinitely
repeating dataset, you must specify the `steps` argument. This
argument is not supported with array inputs.
"""
self.layer.adapt(data, batch_size=batch_size, steps=steps)
def update_state(self, data):
self.layer.update_state(data)
def finalize_state(self):
self.layer.finalize_state()
def reset_state(self):
self.layer.reset_state()
def get_config(self):
return {
"bin_boundaries": self.bin_boundaries,
"num_bins": self.num_bins,
"epsilon": self.epsilon,
"output_mode": self.output_mode,
"sparse": self.sparse,
"name": self.name,
}
def compute_output_shape(self, input_shape):
return input_shape
def call(self, inputs):
if not isinstance(inputs, (tf.Tensor, np.ndarray)):
inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs)
if backend.backend() != "tensorflow":
outputs = backend.convert_to_tensor(outputs)
return outputs

@ -0,0 +1,37 @@
import numpy as np
from keras_core import backend
from keras_core import layers
from keras_core import testing
class DicretizationTest(testing.TestCase):
def test_discretization_basics(self):
self.run_layer_test(
layers.Discretization,
init_kwargs={
"bin_boundaries": [0.0, 0.5, 1.0],
},
input_shape=(2, 3),
expected_output_shape=(2, 3),
expected_num_trainable_weights=0,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=False,
)
def test_adapt_flow(self):
layer = layers.Discretization(num_bins=4)
layer.adapt(
np.random.random((32, 3)),
batch_size=8,
)
output = layer(np.array([[0.0, 0.1, 0.3]]))
self.assertTrue(output.dtype, "int32")
def test_correctness(self):
layer = layers.Discretization(bin_boundaries=[0.0, 0.5, 1.0])
output = layer(np.array([[0.0, 0.1, 0.8]]))
self.assertTrue(backend.is_tensor(output))
self.assertAllClose(output, np.array([[1, 1, 2]]))

@ -1,16 +0,0 @@
from keras_core.optimizers.schedules.learning_rate_schedule import CosineDecay
from keras_core.optimizers.schedules.learning_rate_schedule import (
CosineDecayRestarts,
)
from keras_core.optimizers.schedules.learning_rate_schedule import (
ExponentialDecay,
)
from keras_core.optimizers.schedules.learning_rate_schedule import (
InverseTimeDecay,
)
from keras_core.optimizers.schedules.learning_rate_schedule import (
PiecewiseConstantDecay,
)
from keras_core.optimizers.schedules.learning_rate_schedule import (
PolynomialDecay,
)

@ -6,13 +6,13 @@ import numpy as np
from keras_core import backend
from keras_core import testing
from keras_core.optimizers import schedules
from keras_core.optimizers.schedules import learning_rate_schedule
class ExponentialDecayTest(testing.TestCase):
def test_config(self):
self.run_class_serialization_test(
schedules.ExponentialDecay(
learning_rate_schedule.ExponentialDecay(
initial_learning_rate=0.05,
decay_steps=10,
decay_rate=0.96,
@ -23,13 +23,15 @@ class ExponentialDecayTest(testing.TestCase):
def test_continuous(self):
step = 5
decayed_lr = schedules.ExponentialDecay(0.05, 10, 0.96)
decayed_lr = learning_rate_schedule.ExponentialDecay(0.05, 10, 0.96)
expected = 0.05 * 0.96 ** (5.0 / 10.0)
self.assertAllClose(decayed_lr(step), expected, 1e-6)
def test_staircase(self):
step = backend.Variable(1)
decayed_lr = schedules.ExponentialDecay(0.1, 3, 0.96, staircase=True)
decayed_lr = learning_rate_schedule.ExponentialDecay(
0.1, 3, 0.96, staircase=True
)
# No change to learning rate due to staircase
expected = 0.1
@ -46,7 +48,9 @@ class ExponentialDecayTest(testing.TestCase):
def test_variables(self):
step = backend.Variable(1)
decayed_lr = schedules.ExponentialDecay(0.1, 3, 0.96, staircase=True)
decayed_lr = learning_rate_schedule.ExponentialDecay(
0.1, 3, 0.96, staircase=True
)
# No change to learning rate
step.assign(1)
@ -62,14 +66,14 @@ class ExponentialDecayTest(testing.TestCase):
class PiecewiseConstantDecayTest(testing.TestCase):
def test_config(self):
self.run_class_serialization_test(
schedules.PiecewiseConstantDecay(
learning_rate_schedule.PiecewiseConstantDecay(
boundaries=[10, 20], values=[1, 2, 3], name="my_pcd"
)
)
def test_piecewise_values(self):
x = backend.Variable(-999)
decayed_lr = schedules.PiecewiseConstantDecay(
decayed_lr = learning_rate_schedule.PiecewiseConstantDecay(
[100, 110, 120], [1.0, 0.1, 0.01, 0.001]
)
@ -89,7 +93,9 @@ class PiecewiseConstantDecayTest(testing.TestCase):
# Test casting boundaries from int32 to int64.
x_int64 = backend.Variable(0, dtype="int64")
boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7]
decayed_lr = schedules.PiecewiseConstantDecay(boundaries, values)
decayed_lr = learning_rate_schedule.PiecewiseConstantDecay(
boundaries, values
)
self.assertAllClose(decayed_lr(x_int64), 0.4, 1e-6)
x_int64.assign(1)
@ -105,7 +111,7 @@ class PiecewiseConstantDecayTest(testing.TestCase):
class LinearDecayTest(testing.TestCase):
def test_config(self):
self.run_class_serialization_test(
schedules.PolynomialDecay(
learning_rate_schedule.PolynomialDecay(
initial_learning_rate=0.1,
decay_steps=100,
end_learning_rate=0.005,
@ -119,7 +125,7 @@ class LinearDecayTest(testing.TestCase):
step = 5
lr = 0.05
end_lr = 0.0
decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr)
decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr)
expected = lr * 0.5
self.assertAllClose(decayed_lr(step), expected, 1e-6)
@ -127,7 +133,7 @@ class LinearDecayTest(testing.TestCase):
step = 10
lr = 0.05
end_lr = 0.001
decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr)
decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr)
expected = end_lr
self.assertAllClose(decayed_lr(step), expected, 1e-6)
@ -135,7 +141,7 @@ class LinearDecayTest(testing.TestCase):
step = 5
lr = 0.05
end_lr = 0.001
decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr)
decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr)
expected = (lr + end_lr) * 0.5
self.assertAllClose(decayed_lr(step), expected, 1e-6)
@ -143,7 +149,7 @@ class LinearDecayTest(testing.TestCase):
step = 15
lr = 0.05
end_lr = 0.001
decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr)
decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr)
expected = end_lr
self.assertAllClose(decayed_lr(step), expected, 1e-6)
@ -151,7 +157,9 @@ class LinearDecayTest(testing.TestCase):
step = 15
lr = 0.05
end_lr = 0.001
decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, cycle=True)
decayed_lr = learning_rate_schedule.PolynomialDecay(
lr, 10, end_lr, cycle=True
)
expected = (lr - end_lr) * 0.25 + end_lr
self.assertAllClose(decayed_lr(step), expected, 1e-6)
@ -162,7 +170,9 @@ class SqrtDecayTest(testing.TestCase):
lr = 0.05
end_lr = 0.0
power = 0.5
decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, power=power)
decayed_lr = learning_rate_schedule.PolynomialDecay(
lr, 10, end_lr, power=power
)
expected = lr * 0.5**power
self.assertAllClose(decayed_lr(step), expected, 1e-6)
@ -171,7 +181,9 @@ class SqrtDecayTest(testing.TestCase):
lr = 0.05
end_lr = 0.001
power = 0.5
decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, power=power)
decayed_lr = learning_rate_schedule.PolynomialDecay(
lr, 10, end_lr, power=power
)
expected = end_lr
self.assertAllClose(decayed_lr(step), expected, 1e-6)
@ -180,7 +192,9 @@ class SqrtDecayTest(testing.TestCase):
lr = 0.05
end_lr = 0.001
power = 0.5
decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, power=power)
decayed_lr = learning_rate_schedule.PolynomialDecay(
lr, 10, end_lr, power=power
)
expected = (lr - end_lr) * 0.5**power + end_lr
self.assertAllClose(decayed_lr(step), expected, 1e-6)
@ -189,7 +203,9 @@ class SqrtDecayTest(testing.TestCase):
lr = 0.05
end_lr = 0.001
power = 0.5
decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, power=power)
decayed_lr = learning_rate_schedule.PolynomialDecay(
lr, 10, end_lr, power=power
)
expected = end_lr
self.assertAllClose(decayed_lr(step), expected, 1e-6)
@ -198,7 +214,7 @@ class SqrtDecayTest(testing.TestCase):
lr = 0.05
end_lr = 0.001
power = 0.5
decayed_lr = schedules.PolynomialDecay(
decayed_lr = learning_rate_schedule.PolynomialDecay(
lr, 10, end_lr, power=power, cycle=True
)
expected = (lr - end_lr) * 0.25**power + end_lr
@ -208,7 +224,9 @@ class SqrtDecayTest(testing.TestCase):
lr = 0.001
decay_steps = 10
step = 0
decayed_lr = schedules.PolynomialDecay(lr, decay_steps, cycle=True)
decayed_lr = learning_rate_schedule.PolynomialDecay(
lr, decay_steps, cycle=True
)
expected = lr
self.assertAllClose(decayed_lr(step), expected, 1e-6)
@ -216,7 +234,7 @@ class SqrtDecayTest(testing.TestCase):
class InverseTimeDecayTest(testing.TestCase):
def test_config(self):
self.run_class_serialization_test(
schedules.InverseTimeDecay(
learning_rate_schedule.InverseTimeDecay(
initial_learning_rate=0.05,
decay_steps=10,
decay_rate=0.96,
@ -230,7 +248,9 @@ class InverseTimeDecayTest(testing.TestCase):
k = 10
decay_rate = 0.96
step = backend.Variable(0)
decayed_lr = schedules.InverseTimeDecay(initial_lr, k, decay_rate)
decayed_lr = learning_rate_schedule.InverseTimeDecay(
initial_lr, k, decay_rate
)
for i in range(k + 1):
expected = initial_lr / (1 + i / k * decay_rate)
@ -242,7 +262,7 @@ class InverseTimeDecayTest(testing.TestCase):
k = 10
decay_rate = 0.96
step = backend.Variable(0)
decayed_lr = schedules.InverseTimeDecay(
decayed_lr = learning_rate_schedule.InverseTimeDecay(
initial_lr, k, decay_rate, staircase=True
)
@ -255,7 +275,7 @@ class InverseTimeDecayTest(testing.TestCase):
class CosineDecayTest(testing.TestCase):
def test_config(self):
self.run_class_serialization_test(
schedules.CosineDecay(
learning_rate_schedule.CosineDecay(
initial_learning_rate=0.05,
decay_steps=10,
alpha=0.1,
@ -275,7 +295,9 @@ class CosineDecayTest(testing.TestCase):
num_training_steps = 1000
initial_lr = 1.0
for step in range(0, 1500, 250):
decayed_lr = schedules.CosineDecay(initial_lr, num_training_steps)
decayed_lr = learning_rate_schedule.CosineDecay(
initial_lr, num_training_steps
)
expected = self.np_cosine_decay(step, num_training_steps)
self.assertAllClose(decayed_lr(step), expected, 1e-6)
@ -289,7 +311,7 @@ class CosineDecayTest(testing.TestCase):
initial_lr = 0.0
target_lr = 10.0
for step in range(0, 1500, 250):
lr = schedules.CosineDecay(
lr = learning_rate_schedule.CosineDecay(
initial_lr,
0,
warmup_target=target_lr,
@ -305,7 +327,7 @@ class CosineDecayTest(testing.TestCase):
initial_lr = 1.0
alpha = 0.1
for step in range(0, 1500, 250):
decayed_lr = schedules.CosineDecay(
decayed_lr = learning_rate_schedule.CosineDecay(
initial_lr, num_training_steps, alpha
)
expected = self.np_cosine_decay(step, num_training_steps, alpha)
@ -315,7 +337,9 @@ class CosineDecayTest(testing.TestCase):
num_training_steps = 1000
initial_lr = np.float64(1.0)
for step in range(0, 1500, 250):
decayed_lr = schedules.CosineDecay(initial_lr, num_training_steps)
decayed_lr = learning_rate_schedule.CosineDecay(
initial_lr, num_training_steps
)
expected = self.np_cosine_decay(step, num_training_steps)
self.assertAllClose(decayed_lr(step), expected, 1e-6)
@ -325,7 +349,7 @@ class CosineDecayTest(testing.TestCase):
initial_lr = 0.0
target_lr = 10.0
for step in range(0, 3000, 250):
lr = schedules.CosineDecay(
lr = learning_rate_schedule.CosineDecay(
initial_lr,
decay_steps,
warmup_target=target_lr,
@ -345,7 +369,7 @@ class CosineDecayTest(testing.TestCase):
class CosineDecayRestartsTest(testing.TestCase):
def test_config(self):
self.run_class_serialization_test(
schedules.CosineDecayRestarts(
learning_rate_schedule.CosineDecayRestarts(
initial_learning_rate=0.05,
first_decay_steps=10,
alpha=0.1,
@ -372,7 +396,7 @@ class CosineDecayRestartsTest(testing.TestCase):
num_training_steps = 1000
initial_lr = 1.0
for step in range(0, 1500, 250):
decayed_lr = schedules.CosineDecayRestarts(
decayed_lr = learning_rate_schedule.CosineDecayRestarts(
initial_lr, num_training_steps
)
expected = self.np_cosine_decay_restarts(step, num_training_steps)
@ -382,7 +406,7 @@ class CosineDecayRestartsTest(testing.TestCase):
num_training_steps = 1000
initial_lr = np.float64(1.0)
for step in range(0, 1500, 250):
decayed_lr = schedules.CosineDecayRestarts(
decayed_lr = learning_rate_schedule.CosineDecayRestarts(
initial_lr, num_training_steps
)
expected = self.np_cosine_decay_restarts(step, num_training_steps)
@ -393,7 +417,7 @@ class CosineDecayRestartsTest(testing.TestCase):
initial_lr = 1.0
alpha = 0.1
for step in range(0, 1500, 250):
decayed_lr = schedules.CosineDecayRestarts(
decayed_lr = learning_rate_schedule.CosineDecayRestarts(
initial_lr, num_training_steps, alpha=alpha
)
expected = self.np_cosine_decay_restarts(
@ -406,7 +430,7 @@ class CosineDecayRestartsTest(testing.TestCase):
initial_lr = 1.0
m_mul = 0.9
for step in range(0, 1500, 250):
decayed_lr = schedules.CosineDecayRestarts(
decayed_lr = learning_rate_schedule.CosineDecayRestarts(
initial_lr, num_training_steps, m_mul=m_mul
)
expected = self.np_cosine_decay_restarts(
@ -419,7 +443,7 @@ class CosineDecayRestartsTest(testing.TestCase):
initial_lr = 1.0
t_mul = 1.0
for step in range(0, 1500, 250):
decayed_lr = schedules.CosineDecayRestarts(
decayed_lr = learning_rate_schedule.CosineDecayRestarts(
initial_lr, num_training_steps, t_mul=t_mul
)
expected = self.np_cosine_decay_restarts(