cc89199f1e
* Add CSVLogger and TerminateOnNaN callbacks * Add CSVLogger and TerminateOnNaN tests * Update CSVLogger docstring
21 lines
687 B
Python
21 lines
687 B
Python
import numpy as np
|
|
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.callbacks.callback import Callback
|
|
from keras_core.utils import io_utils
|
|
|
|
|
|
@keras_core_export("keras_core.callbacks.TerminateOnNaN")
class TerminateOnNaN(Callback):
    """Callback that terminates training when a NaN or infinite loss is seen.

    At the end of every batch, the `"loss"` entry of `logs` (if present) is
    checked. If it is NaN, `+inf` or `-inf`, a message is printed and
    `self.model.stop_training` is set to `True` to request that training stop.
    """

    def on_batch_end(self, batch, logs=None):
        """Request a training stop if the batch loss is not finite.

        Args:
            batch: Batch identifier passed in by the training loop; it is
                only used in the printed message.
            logs: Dict of results for the batch, or `None`. Only the
                `"loss"` key is read; if it is absent, nothing happens.
        """
        logs = logs or {}
        loss = logs.get("loss")
        if loss is None:
            return
        # `np.isfinite` is False for NaN, +inf and -inf alike, so a single
        # check covers every invalid loss value.
        # NOTE(review): assumes `loss` is a scalar (Python number or 0-d
        # array) — a multi-element array would make this test ambiguous.
        if not np.isfinite(loss):
            io_utils.print_msg(
                f"Batch {batch}: Invalid loss, terminating training"
            )
            self.model.stop_training = True
|