keras/keras_core/callbacks/csv_logger_test.py
Ramesh Sampath cc89199f1e Add CSVLogger and TerminateOnNaN Callbacks (#95)
* Add CSV Logger and Terminate on Nan

* Add CSVLogger and Terminate on Nan tests

* Update CSV Logger docstring
2023-05-06 00:09:26 -05:00

177 lines
5.6 KiB
Python

import csv
import os
import re
import tempfile
import numpy as np
from keras_core import callbacks
from keras_core import initializers
from keras_core import layers
from keras_core import testing
from keras_core.models import Sequential
from keras_core.utils import numerical_utils
# Dataset/model dimensions shared by every test case in this file.
TRAIN_SAMPLES = 10
TEST_SAMPLES = 10
INPUT_DIM = 3
BATCH_SIZE = 4
class CSVLoggerTest(testing.TestCase):
    """Tests for the `CSVLogger` callback, alone and combined with
    `TerminateOnNaN`."""

    def test_CSVLogger(self):
        """CSVLogger creates, appends to, and reuses log files correctly."""
        OUTPUT_DIM = 1
        np.random.seed(1337)
        temp_dir = tempfile.TemporaryDirectory()
        # Guarantee removal of the temp dir even if an assertion fails.
        self.addCleanup(temp_dir.cleanup)
        filepath = os.path.join(temp_dir.name, "log.tsv")
        sep = "\t"
        x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM))
        y_train = np.random.random((TRAIN_SAMPLES, OUTPUT_DIM))
        x_test = np.random.random((TEST_SAMPLES, INPUT_DIM))
        y_test = np.random.random((TEST_SAMPLES, OUTPUT_DIM))

        def make_model():
            # Re-seed so every model starts from identical initial weights.
            np.random.seed(1337)
            model = Sequential(
                [
                    layers.Dense(2, activation="relu"),
                    layers.Dense(OUTPUT_DIM),
                ]
            )
            model.compile(
                loss="mse",
                optimizer="sgd",
                metrics=["mse"],
            )
            return model

        # Case 1: create a new file with a user-defined separator.
        model = make_model()
        cbks = [callbacks.CSVLogger(filepath, separator=sep)]
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=1,
            verbose=0,
        )
        self.assertTrue(os.path.exists(filepath))
        with open(filepath) as csvfile:
            dialect = csv.Sniffer().sniff(csvfile.read())
            self.assertEqual(dialect.delimiter, sep)
        del model
        del cbks

        # Case 2: append data to the existing file; the header must not be
        # written a second time.
        model = make_model()
        cbks = [callbacks.CSVLogger(filepath, separator=sep, append=True)]
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=1,
            verbose=0,
        )

        # Case 3: reuse of the same CSVLogger object across fit() calls.
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=2,
            verbose=0,
        )
        with open(filepath) as csvfile:
            list_lines = csvfile.readlines()
            for line in list_lines:
                # epoch, loss, mse, val_loss, val_mse -> 4 separators.
                self.assertEqual(line.count(sep), 4)
            # 1 header row + 4 epoch rows (1 + 1 + 2 epochs).
            self.assertEqual(len(list_lines), 5)
            output = " ".join(list_lines)
            # "epoch" appears once: appending did not rewrite the header.
            self.assertEqual(len(re.findall("epoch", output)), 1)
        os.remove(filepath)

        # Case 4: verify val. loss is also registered when validation_freq > 1
        # (epochs without validation must log "NA").
        model = make_model()
        cbks = [callbacks.CSVLogger(filepath, separator=sep)]
        hist = model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            validation_freq=3,
            callbacks=cbks,
            epochs=5,
            verbose=0,
        )
        self.assertTrue(os.path.exists(filepath))
        # Verify that validation loss is registered only at the val. freq.
        with open(filepath) as csvfile:
            rows = csv.DictReader(csvfile, delimiter=sep)
            for idx, row in enumerate(rows, 1):
                self.assertIn("val_loss", row)
                if idx == 3:
                    self.assertEqual(
                        row["val_loss"], str(hist.history["val_loss"][0])
                    )
                else:
                    self.assertEqual(row["val_loss"], "NA")

    def test_stop_training_csv(self):
        """Using CSVLogger together with TerminateOnNaN must still produce a
        valid CSV whose last row records the nan loss."""
        tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(tmpdir.cleanup)
        csv_logfile = os.path.join(tmpdir.name, "csv_logger.csv")
        NUM_CLASSES = 2
        np.random.seed(1337)
        x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM))
        y_train = np.random.choice(np.arange(NUM_CLASSES), size=TRAIN_SAMPLES)
        x_test = np.random.random((TEST_SAMPLES, INPUT_DIM))
        y_test = np.random.choice(np.arange(NUM_CLASSES), size=TEST_SAMPLES)
        y_test = numerical_utils.to_categorical(y_test)
        y_train = numerical_utils.to_categorical(y_train)
        model = Sequential()
        # Huge constant weights force the activations (and hence the loss)
        # to overflow to inf/nan within the first epoch.
        initializer = initializers.Constant(value=1e5)
        for _ in range(5):
            model.add(
                layers.Dense(
                    2,
                    activation="relu",
                    kernel_initializer=initializer,
                )
            )
        model.add(layers.Dense(NUM_CLASSES))
        model.compile(loss="mean_squared_error", optimizer="sgd")
        history = model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=[
                callbacks.TerminateOnNaN(),
                callbacks.CSVLogger(csv_logfile),
            ],
            epochs=20,
        )
        loss = history.history["loss"]
        # TerminateOnNaN must have stopped training after the first epoch.
        self.assertEqual(len(loss), 1)
        self.assertTrue(np.isnan(loss[0]) or np.isinf(loss[0]))
        values = []
        with open(csv_logfile) as f:
            # On Windows, due to \r\n line ends, we may end up reading empty
            # lines after each line. Skip empty lines.
            values = [x for x in csv.reader(f) if x]
        self.assertIn("nan", values[-1], "NaN not logged in CSV Logger.")