102 lines
3.2 KiB
Python
102 lines
3.2 KiB
Python
import collections
|
|
import csv
|
|
|
|
import numpy as np
|
|
from tensorflow.io import gfile
|
|
|
|
from keras_core.api_export import keras_core_export
|
|
from keras_core.callbacks.callback import Callback
|
|
from keras_core.utils import file_utils
|
|
|
|
|
|
@keras_core_export("keras_core.callbacks.CSVLogger")
class CSVLogger(Callback):
    """Callback that streams epoch results to a CSV file.

    Supports all values that can be represented as a string,
    including 1D iterables such as `np.ndarray`.

    One row is written (and flushed) at the end of every epoch. The
    column set is `epoch` followed by the sorted keys of the first epoch's
    logs; keys missing from a later epoch's logs are written as `"NA"`.
    The file is opened in `on_train_begin` and closed in `on_train_end`.

    Args:
        filename: Filename of the CSV file, e.g. `'run/log.csv'`.
        separator: String used to separate elements in the CSV file.
        append: Boolean. True: append if file exists (useful for continuing
            training). False: overwrite existing file.

    Example:

    ```python
    csv_logger = CSVLogger('training.log')
    model.fit(X_train, Y_train, callbacks=[csv_logger])
    ```
    """

    def __init__(self, filename, separator=",", append=False):
        super().__init__()
        self.sep = separator
        self.filename = file_utils.path_to_string(filename)
        self.append = append
        self.writer = None
        self.keys = None
        self.append_header = True
        # Opened in `on_train_begin`. Initialised here so that
        # `on_train_end` is safe even when training never started.
        self.csv_file = None

    def _close_file(self):
        """Close the currently held file handle, if any, and drop it."""
        if self.csv_file is not None:
            self.csv_file.close()
            self.csv_file = None

    def on_train_begin(self, logs=None):
        """Open the log file and reset the per-run writer state.

        Args:
            logs: Unused.
        """
        # A previous run may have been interrupted before `on_train_end`;
        # release its handle before inspecting or reopening the file.
        self._close_file()

        # Decide about the header afresh on every run instead of reusing
        # the answer computed for an earlier run.
        self.append_header = True
        if self.append:
            if gfile.exists(self.filename):
                with gfile.GFile(self.filename, "r") as f:
                    # A non-empty first line means the file already has
                    # content, so no header must be written again.
                    self.append_header = not f.readline()
            mode = "a"
        else:
            mode = "w"
        self.csv_file = gfile.GFile(self.filename, mode)

        # The writer is bound to a specific file object and the keys to a
        # specific run's logs, so both are rebuilt on the first epoch.
        self.writer = None
        self.keys = None

    def on_epoch_end(self, epoch, logs=None):
        """Write one CSV row for the finished epoch and flush the file.

        Args:
            epoch: Value written to the `epoch` column.
            logs: Dict of values for this epoch. Keys that are in the
                column set but absent here are written as `"NA"`.
        """
        logs = logs or {}

        def handle_value(k):
            # Zero-dim arrays are iterable by type but hold one scalar,
            # so they are written as-is rather than as a list.
            is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
            if isinstance(k, str):
                return k
            elif (
                isinstance(k, collections.abc.Iterable)
                and not is_zero_dim_ndarray
            ):
                # Iterables are rendered as a quoted, bracketed list.
                return f"\"[{', '.join(map(str, k))}]\""
            else:
                return k

        if self.keys is None:
            self.keys = sorted(logs.keys())
            # When validation_freq > 1, `val_` keys are not in first epoch
            # logs. Add the `val_` keys so that they are part of the
            # fieldnames of the writer.
            if not any(key.startswith("val_") for key in self.keys):
                self.keys.extend(["val_" + k for k in self.keys])

        if not self.writer:

            class CustomDialect(csv.excel):
                delimiter = self.sep

            fieldnames = ["epoch"] + self.keys

            self.writer = csv.DictWriter(
                self.csv_file, fieldnames=fieldnames, dialect=CustomDialect
            )
            if self.append_header:
                self.writer.writeheader()

        row_dict = collections.OrderedDict({"epoch": epoch})
        row_dict.update(
            (key, handle_value(logs.get(key, "NA"))) for key in self.keys
        )
        self.writer.writerow(row_dict)
        # Flush every epoch so the log is readable while training runs.
        self.csv_file.flush()

    def on_train_end(self, logs=None):
        """Close the log file and discard the writer bound to it.

        Safe to call even if `on_train_begin` never opened a file.

        Args:
            logs: Unused.
        """
        self._close_file()
        self.writer = None
|