keras/keras_core/callbacks/terminate_on_nan.py
Ramesh Sampath cc89199f1e Add CSVLogger and TerminateOnNaN Callbacks (#95)
* Add CSV Logger and Terminate on Nan

* Add CSVLogger and Terminate on Nan tests

* Update CSV Logger docstring
2023-05-06 00:09:26 -05:00

21 lines
687 B
Python

import numpy as np
from keras_core.api_export import keras_core_export
from keras_core.callbacks.callback import Callback
from keras_core.utils import io_utils
@keras_core_export("keras_core.callbacks.TerminateOnNaN")
class TerminateOnNaN(Callback):
    """Callback that terminates training when a non-finite loss is encountered.

    After every batch, the `"loss"` entry of `logs` is inspected. If it is
    NaN or positive/negative infinity, a message is printed and
    `self.model.stop_training` is set to `True`. Batches whose logs carry no
    `"loss"` entry are ignored.
    """

    def on_batch_end(self, batch, logs=None):
        """Request a training stop if the loss of this batch is NaN or inf.

        Args:
            batch: Index of the batch that just finished; only used in the
                printed message.
            logs: Mapping of results for this batch, or `None`. Only the
                `"loss"` key is read.
        """
        logs = logs or {}
        loss = logs.get("loss")
        if loss is None:
            # Nothing to check for this batch.
            return
        # `np.isfinite` is False for NaN, +inf and -inf, so this single call
        # is equivalent to `np.isnan(loss) or np.isinf(loss)`.
        if not np.isfinite(loss):
            io_utils.print_msg(
                f"Batch {batch}: Invalid loss, terminating training"
            )
            self.model.stop_training = True