Add the EarlyStopping callback (#44)
* add earlystopping callback * addressing comments * address comments * addressing comments * remove unused imports
This commit is contained in:
parent
1a273b1de6
commit
038d7bb200
@ -1,5 +1,6 @@
|
|||||||
from keras_core.callbacks.callback import Callback
|
from keras_core.callbacks.callback import Callback
|
||||||
from keras_core.callbacks.callback_list import CallbackList
|
from keras_core.callbacks.callback_list import CallbackList
|
||||||
|
from keras_core.callbacks.early_stopping import EarlyStopping
|
||||||
from keras_core.callbacks.history import History
|
from keras_core.callbacks.history import History
|
||||||
from keras_core.callbacks.lambda_callback import LambdaCallback
|
from keras_core.callbacks.lambda_callback import LambdaCallback
|
||||||
from keras_core.callbacks.progbar_logger import ProgbarLogger
|
from keras_core.callbacks.progbar_logger import ProgbarLogger
|
||||||
|
185
keras_core/callbacks/early_stopping.py
Normal file
185
keras_core/callbacks/early_stopping.py
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
|
from keras_core import operations as ops
|
||||||
|
from keras_core.api_export import keras_core_export
|
||||||
|
from keras_core.callbacks.callback import Callback
|
||||||
|
from keras_core.utils import io_utils
|
||||||
|
|
||||||
|
|
||||||
|
@keras_core_export("keras_core.callbacks.EarlyStopping")
|
||||||
|
class EarlyStopping(Callback):
|
||||||
|
"""Stop training when a monitored metric has stopped improving.
|
||||||
|
|
||||||
|
Assuming the goal of a training is to minimize the loss. With this, the
|
||||||
|
metric to be monitored would be `'loss'`, and mode would be `'min'`. A
|
||||||
|
`model.fit()` training loop will check at end of every epoch whether
|
||||||
|
the loss is no longer decreasing, considering the `min_delta` and
|
||||||
|
`patience` if applicable. Once it's found no longer decreasing,
|
||||||
|
`model.stop_training` is marked True and the training terminates.
|
||||||
|
|
||||||
|
The quantity to be monitored needs to be available in `logs` dict.
|
||||||
|
To make it so, pass the loss or metrics at `model.compile()`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
monitor: Quantity to be monitored. Defaults to `"val_loss"`.
|
||||||
|
min_delta: Minimum change in the monitored quantity to qualify as an
|
||||||
|
improvement, i.e. an absolute change of less than min_delta, will
|
||||||
|
count as no improvement. Defaults to `0`.
|
||||||
|
patience: Number of epochs with no improvement after which training will
|
||||||
|
be stopped. Defaults to `0`.
|
||||||
|
verbose: Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1 displays
|
||||||
|
messages when the callback takes an action. Defaults to `0`.
|
||||||
|
mode: One of `{"auto", "min", "max"}`. In `min` mode, training will stop
|
||||||
|
when the quantity monitored has stopped decreasing; in `"max"` mode
|
||||||
|
it will stop when the quantity monitored has stopped increasing; in
|
||||||
|
`"auto"` mode, the direction is automatically inferred from the name
|
||||||
|
of the monitored quantity. Defaults to `"auto"`.
|
||||||
|
baseline: Baseline value for the monitored quantity. If not `None`,
|
||||||
|
training will stop if the model doesn't show improvement over the
|
||||||
|
baseline. Defaults to `None`.
|
||||||
|
restore_best_weights: Whether to restore model weights from the epoch
|
||||||
|
with the best value of the monitored quantity. If `False`, the model
|
||||||
|
weights obtained at the last step of training are used. An epoch
|
||||||
|
will be restored regardless of the performance relative to the
|
||||||
|
`baseline`. If no epoch improves on `baseline`, training will run
|
||||||
|
for `patience` epochs and restore weights from the best epoch in
|
||||||
|
that set. Defaults to `False`.
|
||||||
|
start_from_epoch: Number of epochs to wait before starting to monitor
|
||||||
|
improvement. This allows for a warm-up period in which no
|
||||||
|
improvement is expected and thus training will not be stopped.
|
||||||
|
Defaults to `0`.
|
||||||
|
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
>>> callback = keras_core.callbacks.EarlyStopping(monitor='loss',
|
||||||
|
... patience=3)
|
||||||
|
>>> # This callback will stop the training when there is no improvement in
|
||||||
|
>>> # the loss for three consecutive epochs.
|
||||||
|
>>> model = keras_core.models.Sequential([keras_core.layers.Dense(10)])
|
||||||
|
>>> model.compile(keras_core.optimizers.SGD(), loss='mse')
|
||||||
|
>>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
|
||||||
|
... epochs=10, batch_size=1, callbacks=[callback],
|
||||||
|
... verbose=0)
|
||||||
|
>>> len(history.history['loss']) # Only 4 epochs are run.
|
||||||
|
4
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
monitor="val_loss",
|
||||||
|
min_delta=0,
|
||||||
|
patience=0,
|
||||||
|
verbose=0,
|
||||||
|
mode="auto",
|
||||||
|
baseline=None,
|
||||||
|
restore_best_weights=False,
|
||||||
|
start_from_epoch=0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.monitor = monitor
|
||||||
|
self.patience = patience
|
||||||
|
self.verbose = verbose
|
||||||
|
self.baseline = baseline
|
||||||
|
self.min_delta = abs(min_delta)
|
||||||
|
self.wait = 0
|
||||||
|
self.stopped_epoch = 0
|
||||||
|
self.restore_best_weights = restore_best_weights
|
||||||
|
self.best_weights = None
|
||||||
|
self.start_from_epoch = start_from_epoch
|
||||||
|
|
||||||
|
if mode not in ["auto", "min", "max"]:
|
||||||
|
warnings.warn(
|
||||||
|
f"EarlyStopping mode {mode} is unknown, fallback to auto mode.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
mode = "auto"
|
||||||
|
|
||||||
|
if mode == "min":
|
||||||
|
self.monitor_op = ops.less
|
||||||
|
elif mode == "max":
|
||||||
|
self.monitor_op = ops.greater
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
self.monitor.endswith("acc")
|
||||||
|
or self.monitor.endswith("accuracy")
|
||||||
|
or self.monitor.endswith("auc")
|
||||||
|
):
|
||||||
|
self.monitor_op = ops.greater
|
||||||
|
else:
|
||||||
|
self.monitor_op = ops.less
|
||||||
|
|
||||||
|
if self.monitor_op == ops.greater:
|
||||||
|
self.min_delta *= 1
|
||||||
|
else:
|
||||||
|
self.min_delta *= -1
|
||||||
|
|
||||||
|
def on_train_begin(self, logs=None):
|
||||||
|
# Allow instances to be re-used
|
||||||
|
self.wait = 0
|
||||||
|
self.stopped_epoch = 0
|
||||||
|
self.best = (
|
||||||
|
float("inf") if self.monitor_op == ops.less else -float("inf")
|
||||||
|
)
|
||||||
|
self.best_weights = None
|
||||||
|
self.best_epoch = 0
|
||||||
|
|
||||||
|
def on_epoch_end(self, epoch, logs=None):
|
||||||
|
current = self.get_monitor_value(logs)
|
||||||
|
if current is None or epoch < self.start_from_epoch:
|
||||||
|
# If no monitor value exists or still in initial warm-up stage.
|
||||||
|
return
|
||||||
|
if self.restore_best_weights and self.best_weights is None:
|
||||||
|
# Restore the weights after first epoch if no progress is ever made.
|
||||||
|
self.best_weights = self.model.get_weights()
|
||||||
|
|
||||||
|
self.wait += 1
|
||||||
|
if self._is_improvement(current, self.best):
|
||||||
|
self.best = current
|
||||||
|
self.best_epoch = epoch
|
||||||
|
if self.restore_best_weights:
|
||||||
|
self.best_weights = self.model.get_weights()
|
||||||
|
# Only restart wait if we beat both the baseline and our previous
|
||||||
|
# best.
|
||||||
|
if self.baseline is None or self._is_improvement(
|
||||||
|
current, self.baseline
|
||||||
|
):
|
||||||
|
self.wait = 0
|
||||||
|
return
|
||||||
|
|
||||||
|
# Only check after the first epoch.
|
||||||
|
if self.wait >= self.patience and epoch > 0:
|
||||||
|
self.stopped_epoch = epoch
|
||||||
|
self.model.stop_training = True
|
||||||
|
if self.restore_best_weights and self.best_weights is not None:
|
||||||
|
if self.verbose > 0:
|
||||||
|
io_utils.print_msg(
|
||||||
|
"Restoring model weights from "
|
||||||
|
"the end of the best epoch: "
|
||||||
|
f"{self.best_epoch + 1}."
|
||||||
|
)
|
||||||
|
self.model.set_weights(self.best_weights)
|
||||||
|
|
||||||
|
def on_train_end(self, logs=None):
|
||||||
|
if self.stopped_epoch > 0 and self.verbose > 0:
|
||||||
|
io_utils.print_msg(
|
||||||
|
f"Epoch {self.stopped_epoch + 1}: early stopping"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_monitor_value(self, logs):
|
||||||
|
logs = logs or {}
|
||||||
|
monitor_value = logs.get(self.monitor)
|
||||||
|
if monitor_value is None:
|
||||||
|
warnings.warn(
|
||||||
|
(
|
||||||
|
f"Early stopping conditioned on metric `{self.monitor}` "
|
||||||
|
"which is not available. "
|
||||||
|
f"Available metrics are: {','.join(list(logs.keys()))}"
|
||||||
|
),
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
return monitor_value
|
||||||
|
|
||||||
|
def _is_improvement(self, monitor_value, reference_value):
|
||||||
|
return self.monitor_op(monitor_value - self.min_delta, reference_value)
|
207
keras_core/callbacks/early_stopping_test.py
Normal file
207
keras_core/callbacks/early_stopping_test.py
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from keras_core import callbacks
|
||||||
|
from keras_core import layers
|
||||||
|
from keras_core import models
|
||||||
|
from keras_core import testing
|
||||||
|
|
||||||
|
|
||||||
|
class EarlyStoppingTest(testing.TestCase):
|
||||||
|
def test_early_stopping(self):
|
||||||
|
x_train = np.random.random((10, 5))
|
||||||
|
y_train = np.random.random((10, 1))
|
||||||
|
x_test = np.random.random((10, 5))
|
||||||
|
y_test = np.random.random((10, 1))
|
||||||
|
model = models.Sequential(
|
||||||
|
(
|
||||||
|
layers.Dense(1, activation="relu"),
|
||||||
|
layers.Dense(1, activation="relu"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model.compile(
|
||||||
|
loss="mae",
|
||||||
|
optimizer="adam",
|
||||||
|
metrics=["mse"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cases = [
|
||||||
|
("max", "val_mse"),
|
||||||
|
("min", "val_loss"),
|
||||||
|
("auto", "val_mse"),
|
||||||
|
("auto", "loss"),
|
||||||
|
("unknown", "unknown"),
|
||||||
|
]
|
||||||
|
for mode, monitor in cases:
|
||||||
|
patience = 0
|
||||||
|
cbks = [
|
||||||
|
callbacks.EarlyStopping(
|
||||||
|
patience=patience, monitor=monitor, mode=mode
|
||||||
|
)
|
||||||
|
]
|
||||||
|
model.fit(
|
||||||
|
x_train,
|
||||||
|
y_train,
|
||||||
|
batch_size=5,
|
||||||
|
validation_data=(x_test, y_test),
|
||||||
|
callbacks=cbks,
|
||||||
|
epochs=5,
|
||||||
|
verbose=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_early_stopping_patience(self):
|
||||||
|
cases = [0, 1, 2, 3]
|
||||||
|
losses = [10.0, 9.0, 8.0, 9.0, 8.9, 8.8, 8.7, 8.6, 8.5]
|
||||||
|
|
||||||
|
for patience in cases:
|
||||||
|
stopper = callbacks.EarlyStopping(monitor="loss", patience=patience)
|
||||||
|
stopper.model = models.Sequential()
|
||||||
|
stopper.model.compile(loss="mse", optimizer="sgd")
|
||||||
|
stopper.on_train_begin()
|
||||||
|
|
||||||
|
for epoch, loss in enumerate(losses):
|
||||||
|
stopper.on_epoch_end(epoch=epoch, logs={"loss": loss})
|
||||||
|
if stopper.model.stop_training:
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertEqual(stopper.stopped_epoch, max(patience, 1) + 2)
|
||||||
|
|
||||||
|
def test_early_stopping_reuse(self):
|
||||||
|
patience = 3
|
||||||
|
data = np.random.random((100, 1))
|
||||||
|
labels = np.where(data > 0.5, 1, 0)
|
||||||
|
model = models.Sequential(
|
||||||
|
(
|
||||||
|
layers.Dense(1, activation="relu"),
|
||||||
|
layers.Dense(1, activation="relu"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model.compile(
|
||||||
|
optimizer="sgd",
|
||||||
|
loss="mae",
|
||||||
|
metrics=["mse"],
|
||||||
|
)
|
||||||
|
weights = model.get_weights()
|
||||||
|
|
||||||
|
# This should allow training to go for at least `patience` epochs
|
||||||
|
model.set_weights(weights)
|
||||||
|
|
||||||
|
stopper = callbacks.EarlyStopping(monitor="mse", patience=patience)
|
||||||
|
hist = model.fit(
|
||||||
|
data, labels, callbacks=[stopper], verbose=0, epochs=20
|
||||||
|
)
|
||||||
|
assert len(hist.epoch) >= patience
|
||||||
|
|
||||||
|
def test_early_stopping_with_baseline(self):
|
||||||
|
baseline = 0.6
|
||||||
|
x_train = np.random.random((10, 5))
|
||||||
|
y_train = np.random.random((10, 1))
|
||||||
|
model = models.Sequential(
|
||||||
|
(
|
||||||
|
layers.Dense(1, activation="relu"),
|
||||||
|
layers.Dense(1, activation="relu"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model.compile(optimizer="sgd", loss="mae", metrics=["mse"])
|
||||||
|
|
||||||
|
patience = 3
|
||||||
|
stopper = callbacks.EarlyStopping(
|
||||||
|
monitor="mse", patience=patience, baseline=baseline
|
||||||
|
)
|
||||||
|
hist = model.fit(
|
||||||
|
x_train, y_train, callbacks=[stopper], verbose=0, epochs=20
|
||||||
|
)
|
||||||
|
assert len(hist.epoch) >= patience
|
||||||
|
|
||||||
|
def test_early_stopping_final_weights_when_restoring_model_weights(self):
|
||||||
|
class DummyModel:
|
||||||
|
def __init__(self):
|
||||||
|
self.stop_training = False
|
||||||
|
self.weights = -1
|
||||||
|
|
||||||
|
def get_weights(self):
|
||||||
|
return self.weights
|
||||||
|
|
||||||
|
def set_weights(self, weights):
|
||||||
|
self.weights = weights
|
||||||
|
|
||||||
|
def set_weight_to_epoch(self, epoch):
|
||||||
|
self.weights = epoch
|
||||||
|
|
||||||
|
early_stop = callbacks.EarlyStopping(
|
||||||
|
monitor="val_loss", patience=2, restore_best_weights=True
|
||||||
|
)
|
||||||
|
early_stop.model = DummyModel()
|
||||||
|
losses = [0.2, 0.15, 0.1, 0.11, 0.12]
|
||||||
|
# The best configuration is in the epoch 2 (loss = 0.1000).
|
||||||
|
epochs_trained = 0
|
||||||
|
early_stop.on_train_begin()
|
||||||
|
for epoch in range(len(losses)):
|
||||||
|
epochs_trained += 1
|
||||||
|
early_stop.model.set_weight_to_epoch(epoch=epoch)
|
||||||
|
early_stop.on_epoch_end(epoch, logs={"val_loss": losses[epoch]})
|
||||||
|
if early_stop.model.stop_training:
|
||||||
|
break
|
||||||
|
# The best configuration is in epoch 2 (loss = 0.1000),
|
||||||
|
# and while patience = 2, we're restoring the best weights,
|
||||||
|
# so we end up at the epoch with the best weights, i.e. epoch 2
|
||||||
|
self.assertEqual(early_stop.model.get_weights(), 2)
|
||||||
|
|
||||||
|
# Check early stopping when no model beats the baseline.
|
||||||
|
early_stop = callbacks.EarlyStopping(
|
||||||
|
monitor="val_loss",
|
||||||
|
patience=5,
|
||||||
|
baseline=0.5,
|
||||||
|
restore_best_weights=True,
|
||||||
|
)
|
||||||
|
early_stop.model = DummyModel()
|
||||||
|
losses = [0.9, 0.8, 0.7, 0.71, 0.72, 0.73]
|
||||||
|
# The best configuration is in the epoch 2 (loss = 0.7000).
|
||||||
|
epochs_trained = 0
|
||||||
|
early_stop.on_train_begin()
|
||||||
|
for epoch in range(len(losses)):
|
||||||
|
epochs_trained += 1
|
||||||
|
early_stop.model.set_weight_to_epoch(epoch=epoch)
|
||||||
|
early_stop.on_epoch_end(epoch, logs={"val_loss": losses[epoch]})
|
||||||
|
if early_stop.model.stop_training:
|
||||||
|
break
|
||||||
|
# No epoch improves on the baseline, so we should train for only 5
|
||||||
|
# epochs, and restore the second model.
|
||||||
|
self.assertEqual(epochs_trained, 5)
|
||||||
|
self.assertEqual(early_stop.model.get_weights(), 2)
|
||||||
|
|
||||||
|
def test_early_stopping_with_start_from_epoch(self):
|
||||||
|
x_train = np.random.random((10, 5))
|
||||||
|
y_train = np.random.random((10, 1))
|
||||||
|
model = models.Sequential(
|
||||||
|
(
|
||||||
|
layers.Dense(1, activation="relu"),
|
||||||
|
layers.Dense(1, activation="relu"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model.compile(optimizer="sgd", loss="mae", metrics=["mse"])
|
||||||
|
start_from_epoch = 2
|
||||||
|
patience = 3
|
||||||
|
stopper = callbacks.EarlyStopping(
|
||||||
|
monitor="mse",
|
||||||
|
patience=patience,
|
||||||
|
start_from_epoch=start_from_epoch,
|
||||||
|
)
|
||||||
|
history = model.fit(
|
||||||
|
x_train, y_train, callbacks=[stopper], verbose=0, epochs=20
|
||||||
|
)
|
||||||
|
# Test 'patience' argument functions correctly when used
|
||||||
|
# in conjunction with 'start_from_epoch'.
|
||||||
|
self.assertGreaterEqual(len(history.epoch), patience + start_from_epoch)
|
||||||
|
|
||||||
|
start_from_epoch = 2
|
||||||
|
patience = 0
|
||||||
|
stopper = callbacks.EarlyStopping(
|
||||||
|
monitor="mse",
|
||||||
|
patience=patience,
|
||||||
|
start_from_epoch=start_from_epoch,
|
||||||
|
)
|
||||||
|
history = model.fit(
|
||||||
|
x_train, y_train, callbacks=[stopper], verbose=0, epochs=20
|
||||||
|
)
|
||||||
|
# Test for boundary condition when 'patience' = 0.
|
||||||
|
self.assertGreaterEqual(len(history.epoch), start_from_epoch)
|
@ -1,5 +1,8 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
from keras_core import backend
|
from keras_core import backend
|
||||||
from keras_core import operations as ops
|
from keras_core import operations as ops
|
||||||
|
from keras_core.api_export import keras_core_export
|
||||||
|
|
||||||
|
|
||||||
def l2_normalize(x, axis=0):
|
def l2_normalize(x, axis=0):
|
||||||
@ -7,3 +10,61 @@ def l2_normalize(x, axis=0):
|
|||||||
square_sum = ops.sum(ops.square(x), axis=axis, keepdims=True)
|
square_sum = ops.sum(ops.square(x), axis=axis, keepdims=True)
|
||||||
l2_norm = ops.reciprocal(ops.sqrt(ops.maximum(square_sum, epsilon)))
|
l2_norm = ops.reciprocal(ops.sqrt(ops.maximum(square_sum, epsilon)))
|
||||||
return ops.multiply(x, l2_norm)
|
return ops.multiply(x, l2_norm)
|
||||||
|
|
||||||
|
|
||||||
|
@keras_core_export("keras_core.utils.to_categorical")
|
||||||
|
def to_categorical(x, num_classes=None):
|
||||||
|
"""Converts a class vector (integers) to binary class matrix.
|
||||||
|
|
||||||
|
E.g. for use with `categorical_crossentropy`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Array-like with class values to be converted into a matrix
|
||||||
|
(integers from 0 to `num_classes - 1`).
|
||||||
|
num_classes: Total number of classes. If `None`, this would be inferred
|
||||||
|
as `max(x) + 1`. Defaults to `None`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A binary matrix representation of the input as a NumPy array. The class
|
||||||
|
axis is placed last.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
>>> a = keras_core.utils.to_categorical([0, 1, 2, 3], num_classes=4)
|
||||||
|
>>> print(a)
|
||||||
|
[[1. 0. 0. 0.]
|
||||||
|
[0. 1. 0. 0.]
|
||||||
|
[0. 0. 1. 0.]
|
||||||
|
[0. 0. 0. 1.]]
|
||||||
|
|
||||||
|
>>> b = np.array([.9, .04, .03, .03,
|
||||||
|
... .3, .45, .15, .13,
|
||||||
|
... .04, .01, .94, .05,
|
||||||
|
... .12, .21, .5, .17],
|
||||||
|
... shape=[4, 4])
|
||||||
|
>>> loss = keras_core.backend.categorical_crossentropy(a, b)
|
||||||
|
>>> print(np.around(loss, 5))
|
||||||
|
[0.10536 0.82807 0.1011 1.77196]
|
||||||
|
|
||||||
|
>>> loss = keras_core.backend.categorical_crossentropy(a, a)
|
||||||
|
>>> print(np.around(loss, 5))
|
||||||
|
[0. 0. 0. 0.]
|
||||||
|
"""
|
||||||
|
if backend.is_tensor(x):
|
||||||
|
return backend.nn.one_hot(x, num_classes)
|
||||||
|
x = np.array(x, dtype="int64")
|
||||||
|
input_shape = x.shape
|
||||||
|
|
||||||
|
# Shrink the last dimension if the shape is (..., 1).
|
||||||
|
if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
|
||||||
|
input_shape = tuple(input_shape[:-1])
|
||||||
|
|
||||||
|
x = x.reshape(-1)
|
||||||
|
if not num_classes:
|
||||||
|
num_classes = np.max(x) + 1
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
categorical = np.zeros((batch_size, num_classes))
|
||||||
|
categorical[np.arange(batch_size), x] = 1
|
||||||
|
output_shape = input_shape + (num_classes,)
|
||||||
|
categorical = np.reshape(categorical, output_shape)
|
||||||
|
return categorical
|
||||||
|
54
keras_core/utils/numerical_utils_test.py
Normal file
54
keras_core/utils/numerical_utils_test.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
import numpy as np
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from keras_core import backend
|
||||||
|
from keras_core import testing
|
||||||
|
from keras_core.utils import numerical_utils
|
||||||
|
|
||||||
|
NUM_CLASSES = 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestNumericalUtils(testing.TestCase, parameterized.TestCase):
|
||||||
|
@parameterized.parameters(
|
||||||
|
[
|
||||||
|
((1,), (1, NUM_CLASSES)),
|
||||||
|
((3,), (3, NUM_CLASSES)),
|
||||||
|
((4, 3), (4, 3, NUM_CLASSES)),
|
||||||
|
((5, 4, 3), (5, 4, 3, NUM_CLASSES)),
|
||||||
|
((3, 1), (3, NUM_CLASSES)),
|
||||||
|
((3, 2, 1), (3, 2, NUM_CLASSES)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def test_to_categorical(self, shape, expected_shape):
|
||||||
|
label = np.random.randint(0, NUM_CLASSES, shape)
|
||||||
|
one_hot = numerical_utils.to_categorical(label, NUM_CLASSES)
|
||||||
|
# Check shape
|
||||||
|
self.assertEqual(one_hot.shape, expected_shape)
|
||||||
|
# Make sure there is only one 1 in a row
|
||||||
|
self.assertTrue(np.all(one_hot.sum(axis=-1) == 1))
|
||||||
|
# Get original labels back from one hots
|
||||||
|
self.assertTrue(
|
||||||
|
np.all(np.argmax(one_hot, -1).reshape(label.shape) == label)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_to_categorial_without_num_classes(self):
|
||||||
|
label = [0, 2, 5]
|
||||||
|
one_hot = numerical_utils.to_categorical(label)
|
||||||
|
self.assertEqual(one_hot.shape, (3, 5 + 1))
|
||||||
|
|
||||||
|
def test_to_categorical_with_backend_tensor(self):
|
||||||
|
label = backend.convert_to_tensor(np.array([0, 2, 1, 3, 4]))
|
||||||
|
expected = backend.convert_to_tensor(
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
[1, 0, 0, 0, 0],
|
||||||
|
[0, 0, 1, 0, 0],
|
||||||
|
[0, 1, 0, 0, 0],
|
||||||
|
[0, 0, 0, 1, 0],
|
||||||
|
[0, 0, 0, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
one_hot = numerical_utils.to_categorical(label, NUM_CLASSES)
|
||||||
|
assert backend.is_tensor(one_hot)
|
||||||
|
self.assertAllClose(one_hot, expected)
|
@ -62,4 +62,5 @@ per-file-ignores =
|
|||||||
keras_core/models/functional.py:E501
|
keras_core/models/functional.py:E501
|
||||||
keras_core/layers/layer.py:E501
|
keras_core/layers/layer.py:E501
|
||||||
keras_core/initializers/random_initializers.py:E501
|
keras_core/initializers/random_initializers.py:E501
|
||||||
|
keras_core/saving/saving_lib_test.py:E501
|
||||||
max-line-length = 80
|
max-line-length = 80
|
||||||
|
Loading…
Reference in New Issue
Block a user