keras/keras_core/callbacks/csv_logger.py
Francois Chollet 9bd7a380de Fix import nit
2023-05-05 22:11:37 -07:00

102 lines
3.2 KiB
Python

import collections
import csv
import numpy as np
from tensorflow.io import gfile
from keras_core.api_export import keras_core_export
from keras_core.callbacks.callback import Callback
from keras_core.utils import file_utils
@keras_core_export("keras_core.callbacks.CSVLogger")
class CSVLogger(Callback):
    """Callback that streams epoch results to a CSV file.

    Supports all values that can be represented as a string,
    including 1D iterables such as `np.ndarray`.

    One row is written (and flushed) at the end of every epoch. The column
    set is fixed from the logs of the first epoch seen in a training run;
    keys missing from a later epoch's logs are written as `"NA"`.

    Args:
        filename: Filename of the CSV file, e.g. `'run/log.csv'`.
        separator: String used to separate elements in the CSV file.
        append: Boolean. True: append if file exists (useful for continuing
            training). False: overwrite existing file.

    Example:

    ```python
    csv_logger = CSVLogger('training.log')
    model.fit(X_train, Y_train, callbacks=[csv_logger])
    ```
    """

    def __init__(self, filename, separator=",", append=False):
        super().__init__()
        self.sep = separator
        self.filename = file_utils.path_to_string(filename)
        self.append = append
        self.writer = None
        self.keys = None
        self.append_header = True
        # Always defined so that `on_train_end` is safe to call even when
        # `on_train_begin` never ran.
        self.csv_file = None

    def _close_csv_file(self):
        """Close the currently open log file, if any, and forget it."""
        csv_file = getattr(self, "csv_file", None)
        if csv_file is not None:
            csv_file.close()
        self.csv_file = None

    @staticmethod
    def _handle_value(value):
        """Convert a single log value into something the CSV writer accepts.

        Strings and non-iterable values (including zero-dimensional
        `np.ndarray`s) are returned unchanged. Any other iterable is
        rendered as a quoted, bracketed, comma-separated list.
        """
        if isinstance(value, str):
            return value
        is_zero_dim_ndarray = isinstance(value, np.ndarray) and value.ndim == 0
        if isinstance(value, collections.abc.Iterable) and (
            not is_zero_dim_ndarray
        ):
            return f"\"[{', '.join(map(str, value))}]\""
        return value

    def on_train_begin(self, logs=None):
        """Open the log file and reset all per-run state.

        In append mode the header is only written when the existing file is
        missing or its first line is empty; otherwise the file is
        overwritten and a header is always written.
        """
        # A previous run may have ended without `on_train_end` (e.g. an
        # exception during training): release its handle and drop the writer
        # that is still bound to it, so rows go to the newly opened file.
        self._close_csv_file()
        self.writer = None
        self.keys = None
        self.append_header = True

        if self.append:
            if gfile.exists(self.filename):
                with gfile.GFile(self.filename, "r") as f:
                    # Non-empty first line: assume the header is present.
                    self.append_header = not bool(len(f.readline()))
            mode = "a"
        else:
            mode = "w"
        self.csv_file = gfile.GFile(self.filename, mode)

    def on_epoch_end(self, epoch, logs=None):
        """Append one row holding `epoch` and the values found in `logs`."""
        logs = logs or {}

        if self.keys is None:
            self.keys = sorted(logs.keys())
            # When validation_freq > 1, `val_` keys are not in first epoch
            # logs. Add the `val_` keys so that they are part of the
            # fieldnames of the writer.
            if not any(key.startswith("val_") for key in self.keys):
                self.keys.extend(["val_" + k for k in self.keys])

        if not self.writer:

            class CustomDialect(csv.excel):
                delimiter = self.sep

            fieldnames = ["epoch"] + self.keys
            self.writer = csv.DictWriter(
                self.csv_file, fieldnames=fieldnames, dialect=CustomDialect
            )
            if self.append_header:
                self.writer.writeheader()

        row_dict = collections.OrderedDict({"epoch": epoch})
        row_dict.update(
            (key, self._handle_value(logs.get(key, "NA")))
            for key in self.keys
        )
        self.writer.writerow(row_dict)
        # Flush every epoch so the log survives an interrupted run.
        self.csv_file.flush()

    def on_train_end(self, logs=None):
        """Close the log file and drop the writer bound to it."""
        self._close_csv_file()
        self.writer = None