import collections
import csv

import numpy as np
from tensorflow.io import gfile

from keras_core.api_export import keras_core_export
from keras_core.callbacks.callback import Callback
from keras_core.utils import file_utils


@keras_core_export("keras_core.callbacks.CSVLogger")
class CSVLogger(Callback):
    """Callback that streams epoch results to a CSV file.

    Supports all values that can be represented as a string,
    including 1D iterables such as `np.ndarray`.

    Args:
        filename: Filename of the CSV file, e.g. `'run/log.csv'`.
        separator: String used to separate elements in the CSV file.
        append: Boolean. `True`: append if the file exists (useful for
            continuing training). `False`: overwrite any existing file.

    Example:

    ```python
    csv_logger = CSVLogger('training.log')
    model.fit(X_train, Y_train, callbacks=[csv_logger])
    ```
    """

    def __init__(self, filename, separator=",", append=False):
        super().__init__()
        self.sep = separator
        self.filename = file_utils.path_to_string(filename)
        self.append = append
        self.writer = None
        self.keys = None
        self.append_header = True

    def on_train_begin(self, logs=None):
        if self.append:
            if gfile.exists(self.filename):
                with gfile.GFile(self.filename, "r") as f:
                    # Only write a header if the existing file is empty.
                    self.append_header = not bool(len(f.readline()))
            mode = "a"
        else:
            mode = "w"
        self.csv_file = gfile.GFile(self.filename, mode)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        def handle_value(k):
            is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
            if isinstance(k, str):
                return k
            elif (
                isinstance(k, collections.abc.Iterable)
                and not is_zero_dim_ndarray
            ):
                # Render 1D iterables as a single quoted list so they
                # occupy one CSV cell.
                return f"\"[{', '.join(map(str, k))}]\""
            else:
                return k

        if self.keys is None:
            self.keys = sorted(logs.keys())
            # When validation_freq > 1, `val_` keys are not in the first
            # epoch's logs. Add them so they are part of the writer's
            # fieldnames.
            val_keys_found = False
            for key in self.keys:
                if key.startswith("val_"):
                    val_keys_found = True
                    break
            if not val_keys_found:
                self.keys.extend(["val_" + k for k in self.keys])

        if not self.writer:

            class CustomDialect(csv.excel):
                delimiter = self.sep

            fieldnames = ["epoch"] + self.keys

            self.writer = csv.DictWriter(
                self.csv_file, fieldnames=fieldnames, dialect=CustomDialect
            )
            if self.append_header:
                self.writer.writeheader()

        row_dict = collections.OrderedDict({"epoch": epoch})
        row_dict.update(
            (key, handle_value(logs.get(key, "NA"))) for key in self.keys
        )
        self.writer.writerow(row_dict)
        self.csv_file.flush()

    def on_train_end(self, logs=None):
        self.csv_file.close()
        self.writer = None