Tensorboard callback (#211)
* Begin adding testing for tensorboard callback * Begin protobuf areas * Sharing WIP PR * Rework init * Add start/stop profiling to backend * Add image/histogram summaries * Model tracing * Full tf working * Enable jax graphs * Fix cifar100 dataset * Finish jax integration checkpoint without tracing * Formatting * All tests passing locally * Model path protected, switch to _model * Remove non-deterministic variable-n tag test * Formatting * Remove profiling jax * Remove tensorboard ops * Remove profiling * File utils fix? * Revert "File utils fix?" This reverts commit a48ba39a9c51c1e39f4509f795fdc2c316a1da3b.
This commit is contained in:
parent
ab98350de8
commit
699e4c3174
@ -11,6 +11,12 @@ from keras_core.trainers.epoch_iterator import EpochIterator
|
||||
|
||||
|
||||
class JAXTrainer(base_trainer.Trainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.train_function = None
|
||||
self.test_function = None
|
||||
self.predict_function = None
|
||||
|
||||
def compute_loss_and_updates(
|
||||
self,
|
||||
trainable_variables,
|
||||
@ -192,14 +198,16 @@ class JAXTrainer(base_trainer.Trainer):
|
||||
else:
|
||||
_train_function = _train_step
|
||||
|
||||
self.train_function = _train_function
|
||||
|
||||
if not self.run_eagerly and self.jit_compile:
|
||||
|
||||
@jax.jit
|
||||
def train_step(state, data):
|
||||
return _train_function(state, data)
|
||||
return self.train_function(state, data)
|
||||
|
||||
else:
|
||||
train_step = _train_function
|
||||
train_step = self.train_function
|
||||
|
||||
self.stop_training = False
|
||||
callbacks.on_train_begin()
|
||||
|
@ -9,4 +9,5 @@ from keras_core.callbacks.model_checkpoint import ModelCheckpoint
|
||||
from keras_core.callbacks.progbar_logger import ProgbarLogger
|
||||
from keras_core.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
|
||||
from keras_core.callbacks.remote_monitor import RemoteMonitor
|
||||
from keras_core.callbacks.tensorboard import TensorBoard
|
||||
from keras_core.callbacks.terminate_on_nan import TerminateOnNaN
|
||||
|
663
keras_core/callbacks/tensorboard.py
Normal file
663
keras_core/callbacks/tensorboard.py
Normal file
@ -0,0 +1,663 @@
|
||||
import warnings
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import tensorflow.summary as summary
|
||||
from tensorflow import nest
|
||||
from tensorflow.compat.v1 import SummaryMetadata
|
||||
from tensorflow.io import gfile
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import operations as ops
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.callbacks.callback import Callback
|
||||
from keras_core.layers import Embedding
|
||||
from keras_core.optimizers import Optimizer
|
||||
from keras_core.optimizers.schedules import learning_rate_schedule
|
||||
from keras_core.utils import file_utils
|
||||
|
||||
|
||||
@keras_core_export("keras_core.callbacks.TensorBoard")
|
||||
class TensorBoard(Callback):
|
||||
|
||||
"""Enable visualizations for TensorBoard.
|
||||
|
||||
TensorBoard is a visualization tool provided with TensorFlow. A TensorFlow
|
||||
installation is required to use this callback.
|
||||
|
||||
This callback logs events for TensorBoard, including:
|
||||
|
||||
* Metrics summary plots
|
||||
* Training graph visualization
|
||||
* Weight histograms
|
||||
* Sampled profiling
|
||||
|
||||
When used in `Model.evaluate` or regular validation
|
||||
in addition to epoch summaries, there will be a summary that records
|
||||
evaluation metrics vs `Model.optimizer.iterations` written. The metric names
|
||||
will be prepended with `evaluation`, with `Model.optimizer.iterations` being
|
||||
the step in the visualized TensorBoard.
|
||||
|
||||
If you have installed TensorFlow with pip, you should be able
|
||||
to launch TensorBoard from the command line:
|
||||
|
||||
```
|
||||
tensorboard --logdir=path_to_your_logs
|
||||
```
|
||||
|
||||
You can find more information about TensorBoard
|
||||
[here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
|
||||
|
||||
Args:
|
||||
log_dir: the path of the directory where to save the log files to be
|
||||
parsed by TensorBoard. e.g. log_dir = os.path.join(working_dir,
|
||||
'logs') This directory should not be reused by any other callbacks.
|
||||
histogram_freq: frequency (in epochs) at which to compute
|
||||
weight histograms for the layers of the model. If set to 0,
|
||||
histograms won't be computed. Validation data (or split) must be
|
||||
specified for histogram visualizations.
|
||||
write_graph: TODO: still not supported.
|
||||
whether to visualize the graph in TensorBoard. The log file
|
||||
can become quite large when write_graph is set to True.
|
||||
write_images: whether to write model weights to visualize as image in
|
||||
TensorBoard.
|
||||
write_steps_per_second: whether to log the training steps per second
|
||||
into TensorBoard. This supports both epoch and batch frequency
|
||||
logging.
|
||||
update_freq: `'batch'` or `'epoch'` or integer. When using `'epoch'`,
|
||||
writes the losses and metrics to TensorBoard after every epoch.
|
||||
If using an integer, let's say `1000`, all metrics and losses
|
||||
(including custom ones added by `Model.compile`) will be logged to
|
||||
TensorBoard every 1000 batches. `'batch'` is a synonym for `1`,
|
||||
meaning that they will be written every batch.
|
||||
Note however that writing too frequently to TensorBoard can slow
|
||||
down your training, especially when used with distribution
|
||||
strategies as it will incur additional synchronization overhead.
|
||||
Batch-level summary writing is also available via `train_step`
|
||||
override. Please see
|
||||
[TensorBoard Scalars tutorial](https://www.tensorflow.org/tensorboard/scalars_and_keras#batch-level_logging) # noqa: E501
|
||||
for more details.
|
||||
profile_batch: TODO: still not supported.
|
||||
Profile the batch(es) to sample compute characteristics.
|
||||
profile_batch must be a non-negative integer or a tuple of integers.
|
||||
A pair of positive integers signify a range of batches to profile.
|
||||
By default, profiling is disabled.
|
||||
embeddings_freq: frequency (in epochs) at which embedding layers will be
|
||||
visualized. If set to 0, embeddings won't be visualized.
|
||||
embeddings_metadata: Dictionary which maps embedding layer names to the
|
||||
filename of a file in which to save metadata for the embedding layer.
|
||||
In case the same metadata file is to be
|
||||
used for all embedding layers, a single filename can be passed.
|
||||
|
||||
Examples:
|
||||
|
||||
Basic usage:
|
||||
|
||||
```python
|
||||
tensorboard_callback = keras_core.callbacks.TensorBoard(log_dir="./logs")
|
||||
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
||||
# Then run the tensorboard command to view the visualizations.
|
||||
```
|
||||
|
||||
Custom batch-level summaries in a subclassed Model:
|
||||
|
||||
```python
|
||||
class MyModel(keras_core.Model):
|
||||
|
||||
def build(self, _):
|
||||
self.dense = keras_core.layers.Dense(10)
|
||||
|
||||
def call(self, x):
|
||||
outputs = self.dense(x)
|
||||
tf.summary.histogram('outputs', outputs)
|
||||
return outputs
|
||||
|
||||
model = MyModel()
|
||||
model.compile('sgd', 'mse')
|
||||
|
||||
# Make sure to set `update_freq=N` to log a batch-level summary every N
|
||||
# batches. In addition to any `tf.summary` contained in `Model.call`,
|
||||
# metrics added in `Model.compile` will be logged every N batches.
|
||||
tb_callback = keras_core.callbacks.TensorBoard('./logs', update_freq=1)
|
||||
model.fit(x_train, y_train, callbacks=[tb_callback])
|
||||
```
|
||||
|
||||
Custom batch-level summaries in a Functional API Model:
|
||||
|
||||
```python
|
||||
def my_summary(x):
|
||||
tf.summary.histogram('x', x)
|
||||
return x
|
||||
|
||||
inputs = keras_core.Input(10)
|
||||
x = keras_core.layers.Dense(10)(inputs)
|
||||
outputs = keras_core.layers.Lambda(my_summary)(x)
|
||||
model = keras_core.Model(inputs, outputs)
|
||||
model.compile('sgd', 'mse')
|
||||
|
||||
# Make sure to set `update_freq=N` to log a batch-level summary every N
|
||||
# batches. In addition to any `tf.summary` contained in `Model.call`,
|
||||
# metrics added in `Model.compile` will be logged every N batches.
|
||||
tb_callback = keras_core.callbacks.TensorBoard('./logs', update_freq=1)
|
||||
model.fit(x_train, y_train, callbacks=[tb_callback])
|
||||
```
|
||||
|
||||
Profiling:
|
||||
|
||||
```python
|
||||
# Profile a single batch, e.g. the 5th batch.
|
||||
tensorboard_callback = keras_core.callbacks.TensorBoard(
|
||||
log_dir='./logs', profile_batch=5)
|
||||
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
||||
|
||||
# Profile a range of batches, e.g. from 10 to 20.
|
||||
tensorboard_callback = keras_core.callbacks.TensorBoard(
|
||||
log_dir='./logs', profile_batch=(10,20))
|
||||
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    log_dir="logs",
    histogram_freq=0,
    write_graph=True,
    write_images=False,
    write_steps_per_second=False,
    update_freq="epoch",
    profile_batch=0,
    embeddings_freq=0,
    embeddings_metadata=None,
    **kwargs,
):
    """Store the logging configuration and reset internal counters.

    `profile_batch` and `**kwargs` are accepted but not used here:
    profiling is always initialised as disabled.
    """
    super().__init__()

    # User-facing configuration.
    self.log_dir = str(log_dir)
    self.histogram_freq = histogram_freq
    self.write_graph = write_graph
    self.write_images = write_images
    self.write_steps_per_second = write_steps_per_second
    # `'batch'` is a synonym for logging every single batch.
    if update_freq == "batch":
        self.update_freq = 1
    else:
        self.update_freq = update_freq
    self.embeddings_freq = embeddings_freq
    self.embeddings_metadata = embeddings_metadata

    self._init_profile_batch(0)  # TODO: profiling not available in JAX

    # Bookkeeping for batch counting and steps-per-second timing.
    self._global_train_batch = 0
    self._previous_epoch_iterations = 0
    self._train_accumulated_time = 0
    self._batch_start_time = 0

    # Writers are created on first use so that no event file is produced
    # unless something is actually logged.
    self._writers = {}

    # Stack of (writer context, record_if context) pairs that must be
    # exited once training / evaluation ends.
    self._prev_summary_state = []
|
||||
|
||||
def set_model(self, model):
    """Attach `model`, derive the log directories and reset the writers.

    Writes the Keras model summary when `write_graph` is set and configures
    the embeddings projector when `embeddings_freq` is non-zero.
    """
    self._model = model

    # Train and validation events live in sibling sub-directories.
    self._log_write_dir = self.log_dir
    self._train_dir = os.path.join(self._log_write_dir, "train")
    self._val_dir = os.path.join(self._log_write_dir, "validation")
    self._train_step = 0
    self._val_step = 0

    # Drop writers left over from a previous model.
    self._writers = {}

    self._should_write_train_graph = False
    if self.write_graph:
        self._write_keras_model_summary()
        self._should_write_train_graph = True

    if self.embeddings_freq:
        self._configure_embeddings()
|
||||
|
||||
@property
def _train_writer(self):
    """File writer for the train directory, created on first access."""
    writers = self._writers
    if "train" in writers:
        return writers["train"]
    writers["train"] = summary.create_file_writer(self._train_dir)
    return writers["train"]
|
||||
|
||||
@property
def _val_writer(self):
    """File writer for the validation directory, created on first access."""
    writers = self._writers
    if "val" in writers:
        return writers["val"]
    writers["val"] = summary.create_file_writer(self._val_dir)
    return writers["val"]
|
||||
|
||||
def _write_keras_model_train_graph(self):
    """Writes Keras model train_function graph to TensorBoard."""
    with self._train_writer.as_default(), summary.record_if(True):
        train_fn = self.model.train_function
        # Only a `tf.function` exposes `function_spec`; anything else has
        # no graph that could be exported.
        if not hasattr(train_fn, "function_spec"):
            return
        # TODO(b/243822285): Use _variable_creation_fn directly.
        if hasattr(train_fn, "_concrete_stateful_fn"):
            concrete_fn = train_fn._concrete_stateful_fn
        else:
            concrete_fn = train_fn._concrete_variable_creation_fn
        summary.graph(concrete_fn.graph)
|
||||
|
||||
def _write_keras_model_summary(self):
    """Writes Keras graph network summary to TensorBoard."""
    with self._train_writer.as_default(), summary.record_if(True):
        # Only models whose class is named Functional or Sequential are
        # written; any other model class is skipped.
        model_kind = self.model.__class__.__name__
        if model_kind in ("Functional", "Sequential"):
            keras_model_summary("keras", self.model, step=0)
|
||||
|
||||
def _configure_embeddings(self):
    """Configure the Projector for embeddings.

    Builds a `ProjectorConfig` with one entry per `Embedding` layer of the
    model and writes it to `projector_config.pbtxt` in the log directory.

    `embeddings_metadata` may be a single filename (used for every
    embedding layer) or a mapping from layer name to filename. The mapping
    supplied by the user is never modified, so the callback can be attached
    to a model more than once.

    Raises:
        ValueError: if `embeddings_metadata` is a mapping that contains
            names which do not match any `Embedding` layer of the model.
    """
    from google.protobuf import text_format
    from tensorboard.plugins import projector

    metadata = self.embeddings_metadata
    shared_metadata_file = isinstance(metadata, str)
    # Track the not-yet-matched layer names on a private copy instead of
    # popping from the caller's dict.
    if metadata is None or shared_metadata_file:
        unmatched = {}
    else:
        unmatched = dict(metadata)

    config = projector.ProjectorConfig()
    for layer in self.model.layers:
        if not isinstance(layer, Embedding):
            continue
        embedding = config.embeddings.add()
        # Embeddings are always the first layer, so this naming should
        # be consistent in any keras models checkpoints.
        embedding.tensor_name = (
            "layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE"
        )

        if shared_metadata_file:
            embedding.metadata_path = metadata
        elif layer.name in unmatched:
            embedding.metadata_path = unmatched.pop(layer.name)

    if unmatched:
        raise ValueError(
            "Unrecognized `Embedding` layer names passed to "
            "`keras_core.callbacks.TensorBoard` `embeddings_metadata` "
            f"argument: {unmatched.keys()}"
        )

    config_pbtxt = text_format.MessageToString(config)
    path = os.path.join(self._log_write_dir, "projector_config.pbtxt")
    with gfile.GFile(path, "w") as f:
        f.write(config_pbtxt)
|
||||
|
||||
def _push_writer(self, writer, step):
    """Sets the default writer for custom batch-level summaries.

    Does nothing when `update_freq` is `"epoch"`. Otherwise enters
    `writer.as_default(step)` and a `record_if` context that records only
    when `step` is a multiple of `update_freq`, and remembers both so that
    `_pop_writer` can exit them.
    """
    if self.update_freq == "epoch":
        return

    def should_record():
        return step % self.update_freq == 0

    default_ctx = writer.as_default(step)
    record_ctx = summary.record_if(should_record)
    # Store the pair before entering so it can be unwound later.
    self._prev_summary_state.append((default_ctx, record_ctx))
    default_ctx.__enter__()
    record_ctx.__enter__()
|
||||
|
||||
def _pop_writer(self):
    """Pops the current writer.

    Exits the two contexts entered by `_push_writer`, in reverse order of
    entry. Does nothing when `update_freq` is `"epoch"`.
    """
    if self.update_freq == "epoch":
        return

    # The stored entry is the (writer context, record_if context) pair.
    default_ctx, record_ctx = self._prev_summary_state.pop()
    record_ctx.__exit__(*sys.exc_info())
    default_ctx.__exit__(*sys.exc_info())
|
||||
|
||||
def _close_writers(self):
    """Close every writer currently held in `self._writers`."""
    for open_writer in self._writers.values():
        open_writer.close()
|
||||
|
||||
def _init_profile_batch(self, profile_batch):
    """Validate profile_batch value and set the range of batches to profile.

    Sets values of _start_batch and _stop_batch attributes,
    specifying the start and stop batch to profile.
    Setting `profile_batch=0` disables profiling.

    Args:
        profile_batch: The range of batches to profile. Should be a
            non-negative integer, a pair (tuple or list) of integers, or
            the legacy string form `"start"` / `"start,stop"`. A pair
            signifies a range of batches to profile.

    Raises:
        ValueError: If profile_batch is not an integer or a pair of
            integers (or a string encoding one of those), if the start
            batch is negative, or if the stop batch precedes the start
            batch.
    """
    profile_batch_error_message = (
        "profile_batch must be a non-negative integer or "
        "2-tuple of positive "
        "integers. A pair of positive integers "
        "signifies a range of batches "
        f"to profile. Found: {profile_batch}"
    )

    # Support legacy way of specifying "start,stop" or "start" as str.
    if isinstance(profile_batch, str):
        try:
            profile_batch = [int(part) for part in profile_batch.split(",")]
        except ValueError as err:
            # Report non-numeric input with the same message as every
            # other invalid value.
            raise ValueError(profile_batch_error_message) from err
        # A lone "start" must behave like the integer `start`; without
        # this a one-element list would be rejected below.
        if len(profile_batch) == 1:
            profile_batch = profile_batch[0]

    if isinstance(profile_batch, int):
        self._start_batch = profile_batch
        self._stop_batch = profile_batch
    elif isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2:
        self._start_batch, self._stop_batch = profile_batch
    else:
        raise ValueError(profile_batch_error_message)

    if self._start_batch < 0 or self._stop_batch < self._start_batch:
        raise ValueError(profile_batch_error_message)

    # True when the profiler was successfully started by this callback.
    # We track the status here to make sure callbacks do not interfere with
    # each other. The callback will only stop the profiler it started.
    self._profiler_started = False
    if self._start_batch > 0:
        # Warm up and improve the profiling accuracy.
        self._start_profiler(logdir="")
        self._stop_profiler(save=False)
    # True when a trace is running.
    self._is_tracing = False

    # Setting `profile_batch=0` disables profiling.
    self._should_trace = not (
        self._start_batch == 0 and self._stop_batch == 0
    )
|
||||
|
||||
def on_train_begin(self, logs=None):
    """Reset the batch counters and push the train writer."""
    self._previous_epoch_iterations = 0
    self._global_train_batch = 0
    self._push_writer(self._train_writer, self._train_step)
|
||||
|
||||
def on_train_end(self, logs=None):
    """Pop the train writer, stop any running trace and close writers."""
    self._pop_writer()

    # A trace that is still running is exported before the writers close.
    still_tracing = self._is_tracing
    if still_tracing:
        self._stop_trace()

    self._close_writers()
|
||||
|
||||
def on_test_begin(self, logs=None):
    """Push the validation writer as the default summary writer."""
    val_writer = self._val_writer
    self._push_writer(val_writer, self._val_step)
|
||||
|
||||
def on_test_end(self, logs=None):
    """Write evaluation metrics against optimizer iterations, then pop.

    Each entry of `logs` is written to the validation writer as
    `evaluation_<name>_vs_iterations`, with `model.optimizer.iterations`
    as the step. Nothing is written (and no validation writer is created)
    when `logs` is `None` or empty, or when the model has no optimizer
    exposing `iterations`.

    Args:
        logs: Optional dict mapping metric names to values.
    """
    try:
        optimizer = self.model.optimizer
        # `logs` defaults to None, so it must be checked before iterating.
        if logs and optimizer and hasattr(optimizer, "iterations"):
            with summary.record_if(True), self._val_writer.as_default():
                for name, value in logs.items():
                    summary.scalar(
                        "evaluation_" + name + "_vs_iterations",
                        value,
                        step=optimizer.iterations,
                    )
    finally:
        # Always undo the `_push_writer` from `on_test_begin`, even when
        # writing a summary fails.
        self._pop_writer()
|
||||
|
||||
def _implements_train_batch_hooks(self):
    """Report whether the per-batch train hooks have any work to do."""
    # Batch hooks matter only for tracing or steps-per-second logging.
    tracing = self._should_trace
    return tracing if tracing else self.write_steps_per_second
|
||||
|
||||
def on_train_batch_begin(self, batch, logs=None):
    """Count the batch, start its timer and begin tracing if it is due."""
    self._global_train_batch += 1

    if self.write_steps_per_second:
        self._batch_start_time = time.time()

    # Tracing starts exactly when the global counter reaches the start
    # batch.
    if self._should_trace and self._global_train_batch == self._start_batch:
        self._start_trace()
|
||||
|
||||
def on_train_batch_end(self, batch, logs=None):
    """Write the train graph once, log batch-level scalars, stop tracing.

    Batch-level scalars use `self._global_train_batch` as their step so
    that consecutive batches produce distinct points; `self._train_step`
    is never advanced by this callback and would place every point at
    step 0.

    Args:
        batch: Index of the batch within the current epoch (unused).
        logs: Optional dict of metric values for this batch; anything that
            is not a dict is ignored.
    """
    if self._should_write_train_graph:
        self._write_keras_model_train_graph()
        self._should_write_train_graph = False

    step = self._global_train_batch

    if self.write_steps_per_second:
        batch_run_time = time.time() - self._batch_start_time
        # A coarse clock can report zero elapsed time for a very fast
        # batch; skip the point rather than divide by zero.
        if batch_run_time > 0:
            summary.scalar(
                "batch_steps_per_second",
                1.0 / batch_run_time,
                step=step,
            )

    # `logs` isn't necessarily always a dict
    if isinstance(logs, dict):
        for name, value in logs.items():
            summary.scalar("batch_" + name, value, step=step)

    if not self._should_trace:
        return

    if self._is_tracing and self._global_train_batch >= self._stop_batch:
        self._stop_trace()
|
||||
|
||||
def on_epoch_begin(self, epoch, logs=None):
    """Record the iteration count and wall-clock time at epoch start."""
    # Only needed for the steps-per-second computation at epoch end.
    if not self.write_steps_per_second:
        return
    self._previous_epoch_iterations = self.model.optimizer.iterations
    self._epoch_start_time = time.time()
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
    """Runs metrics and histogram summaries at epoch end."""
    self._log_epoch_metrics(epoch, logs)

    # A frequency of 0 disables the corresponding output entirely.
    histogram_every = self.histogram_freq
    if histogram_every and epoch % histogram_every == 0:
        self._log_weights(epoch)

    embeddings_every = self.embeddings_freq
    if embeddings_every and epoch % embeddings_every == 0:
        self._log_embeddings(epoch)
|
||||
|
||||
def _start_trace(self):
    """Turn on graph tracing, start the profiler and mark tracing active."""
    summary.trace_on(graph=True, profiler=False)
    profiler_logdir = self.log_dir
    self._start_profiler(logdir=profiler_logdir)
    self._is_tracing = True
|
||||
|
||||
def _stop_trace(self, batch=None):
    """Logs the trace graph to TensorBoard.

    Args:
        batch: Step under which the trace is exported. Falls back to
            `self._stop_batch` when `None`.
    """
    export_step = self._stop_batch if batch is None else batch
    with self._train_writer.as_default(), summary.record_if(True):
        # TODO(b/126388999): Remove step info in the summary name.
        summary.trace_export(name="batch_%d" % export_step, step=export_step)
    self._stop_profiler()
    self._is_tracing = False
|
||||
|
||||
def _collect_learning_rate(self, logs):
    """Add the scheduled learning rate to `logs` when a schedule is used.

    The schedule is read from `_learning_rate` for `Optimizer` instances
    and from `lr` for any other optimizer object. `logs` is modified in
    place and also returned.
    """
    optimizer = self.model.optimizer
    attr_name = "_learning_rate" if isinstance(optimizer, Optimizer) else "lr"
    lr_schedule = getattr(optimizer, attr_name, None)
    if isinstance(lr_schedule, learning_rate_schedule.LearningRateSchedule):
        logs["learning_rate"] = lr_schedule(self.model.optimizer.iterations)
    return logs
|
||||
|
||||
def _compute_steps_per_second(self):
    """Return optimizer iterations per second since the epoch began.

    As a side effect, `self._previous_epoch_iterations` is replaced by its
    float32 tensor conversion.
    """
    iterations_now = self.model.optimizer.iterations
    elapsed = time.time() - self._epoch_start_time

    # Everything is converted to float32 tensors before the arithmetic.
    iterations_now = ops.convert_to_tensor(iterations_now, "float32")
    self._previous_epoch_iterations = ops.convert_to_tensor(
        self._previous_epoch_iterations, "float32"
    )
    elapsed = ops.convert_to_tensor(elapsed, "float32")

    iterations_done = iterations_now - self._previous_epoch_iterations
    return iterations_done / elapsed
|
||||
|
||||
def _log_epoch_metrics(self, epoch, logs):
    """Writes epoch metrics out as scalar summaries.

    Keys starting with `val_` go to the validation writer (with the prefix
    removed); all other keys go to the train writer. Every tag is prefixed
    with `epoch_`.

    Args:
        epoch: Int. The global step to use for TensorBoard.
        logs: Dict. Keys are scalar summary names, values are scalars.
    """
    if not logs:
        return

    # Split the logs by destination writer in a single pass.
    train_logs, val_logs = {}, {}
    for key, value in logs.items():
        bucket = val_logs if key.startswith("val_") else train_logs
        bucket[key] = value

    train_logs = self._collect_learning_rate(train_logs)
    if self.write_steps_per_second:
        train_logs["steps_per_second"] = self._compute_steps_per_second()

    with summary.record_if(True):
        if train_logs:
            with self._train_writer.as_default():
                for key, value in train_logs.items():
                    summary.scalar("epoch_" + key, value, step=epoch)
        if val_logs:
            with self._val_writer.as_default():
                for key, value in val_logs.items():
                    # Remove 'val_' prefix.
                    summary.scalar("epoch_" + key[4:], value, step=epoch)
|
||||
|
||||
def _log_weights(self, epoch):
    """Logs the weights of the Model to TensorBoard."""
    writer = self._train_writer
    with writer.as_default(), summary.record_if(True):
        for layer in self.model.layers:
            for weight in layer.weights:
                base_name = weight.name.replace(":", "_")
                # The "/histogram" and "/image" suffixes prevent summary
                # tag name collisions.
                summary.histogram(
                    base_name + "/histogram", weight, step=epoch
                )
                if self.write_images:
                    self._log_weight_as_image(
                        weight, base_name + "/image", epoch
                    )
        writer.flush()
|
||||
|
||||
def _log_weight_as_image(self, weight, weight_name, epoch):
    """Logs a weight as a TensorBoard image.

    The squeezed weight is reshaped to rank 4 when it has rank 1, 2 or 3.
    An image is written only if the result has rank 4 and a last dimension
    of 1, 3 or 4.
    """
    w_img = ops.squeeze(weight)
    shape = w_img.shape
    rank = len(shape)

    if rank == 1:  # Bias case
        w_img = ops.reshape(w_img, [1, shape[0], 1, 1])
    elif rank == 2:  # Dense layer kernel case
        # Transpose when the first dimension is the larger one.
        if shape[0] > shape[1]:
            w_img = ops.transpose(w_img)
            shape = w_img.shape
        w_img = ops.reshape(w_img, [1, shape[0], shape[1], 1])
    elif rank == 3:  # ConvNet case
        if backend.image_data_format() == "channels_last":
            # Switch to channels_first to display every kernel as a separate
            # image.
            w_img = ops.transpose(w_img, [2, 0, 1])
            shape = w_img.shape
        w_img = ops.reshape(w_img, [shape[0], shape[1], shape[2], 1])

    final_shape = w_img.shape
    # Not possible to handle 3D convnets etc.
    if len(final_shape) == 4 and final_shape[-1] in [1, 3, 4]:
        summary.image(weight_name, w_img, step=epoch)
|
||||
|
||||
def _log_embeddings(self, epoch):
    """Save the model weights to a per-epoch file under the train dir."""
    ckpt_name = f"keras_embedding.ckpt-{epoch}.weights.h5"
    embeddings_ckpt = os.path.join(self._log_write_dir, "train", ckpt_name)
    self.model.save_weights(embeddings_ckpt)
|
||||
|
||||
def _start_profiler(self, logdir):
    """Starts the profiler if currently inactive.

    Failures are logged and otherwise ignored.

    Args:
        logdir: Directory where profiler results will be saved.
    """
    if not self._profiler_started:
        try:
            backend.tensorboard.start_trace(logdir)
            self._profiler_started = True
        except Exception as e:
            # Profiler errors should not be fatal.
            logging.error("Failed to start profiler: %s", e)
|
||||
|
||||
def _stop_profiler(self, save=True):
    """Stops the profiler if currently active.

    Failures are logged and otherwise ignored; the started flag is cleared
    either way.

    Args:
        save: Whether to save the profiler results to TensorBoard.
    """
    if self._profiler_started:
        try:
            backend.tensorboard.stop_trace(save=save)
        except Exception as e:
            # Profiler errors should not be fatal.
            logging.error("Failed to stop profiler: %s", e)
        finally:
            self._profiler_started = False
|
||||
|
||||
|
||||
def keras_model_summary(name, data, step=None):
    """Writes a Keras model, serialized as JSON, as a Summary.

    Writing the Keras model configuration allows the TensorBoard graph plugin
    to render a conceptual graph, as opposed to graph of ops. If the model
    fails to serialize as JSON, a warning is issued and False is returned.

    Args:
        name: A name for this summary. The summary tag used for TensorBoard
            will be this name prefixed by any active name scopes.
        data: A Keras Model to write.
        step: Explicit `int64`-castable monotonic step value for this
            summary. If omitted, this defaults to
            `tf.summary.experimental.get_step()`, which must not be None.

    Returns:
        True on success, or False if no summary was written because no
        default summary writer was available.

    Raises:
        ValueError: if a default writer exists, but no step was provided and
            `tf.summary.experimental.get_step()` is None.
    """
    try:
        model_json = data.to_json()
    except Exception as exc:
        # An exception should not break a model code.
        warnings.warn(f"Model failed to serialize as JSON. Ignoring... {exc}")
        return False

    metadata = SummaryMetadata()
    # Hard coding a plugin name. Please refer to go/tb-plugin-name-hardcode
    # for the rationale.
    metadata.plugin_data.plugin_name = "graph_keras_model"
    # version number = 1
    metadata.plugin_data.content = b"1"

    scope = summary.experimental.summary_scope(
        name, "graph_keras_model", [data, step]
    )
    with scope as (tag, _):
        json_tensor = ops.convert_to_tensor(model_json, dtype="string")
        return summary.write(
            tag=tag, tensor=json_tensor, step=step, metadata=metadata
        )
|
791
keras_core/callbacks/tensorboard_test.py
Normal file
791
keras_core/callbacks/tensorboard_test.py
Normal file
@ -0,0 +1,791 @@
|
||||
import collections
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.summary as summary
|
||||
from tensorflow.compat.v1 import SummaryMetadata
|
||||
from tensorflow.core.util import event_pb2
|
||||
from tensorflow.python.lib.io import tf_record
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import callbacks
|
||||
from keras_core import losses
|
||||
from keras_core import models
|
||||
from keras_core import operations as ops
|
||||
from keras_core import optimizers
|
||||
from keras_core import testing
|
||||
from keras_core.layers import Conv2D
|
||||
from keras_core.layers import Dense
|
||||
from keras_core.layers import Embedding
|
||||
from keras_core.layers import Flatten
|
||||
from keras_core.layers import Input
|
||||
from keras_core.layers import Layer
|
||||
from keras_core.optimizers import schedules
|
||||
|
||||
# Note: this file and tensorboard in general has a dependency on tensorflow
|
||||
|
||||
# A summary that was emitted during a test. Fields:
|
||||
# logdir: str. The logdir of the FileWriter to which the summary was
|
||||
# written.
|
||||
# tag: str. The name of the summary.
|
||||
_ObservedSummary = collections.namedtuple("_ObservedSummary", ("logdir", "tag"))
|
||||
|
||||
|
||||
class _SummaryIterator(object):
    """Yields `Event` protocol buffers from a given path."""

    def __init__(self, path):
        # Underlying iterator over the raw records stored at `path`.
        self._tf_record_iterator = tf_record.tf_record_iterator(path)

    def __iter__(self):
        return self

    def __next__(self):
        # Parse the next raw record into an `Event` message.
        raw_record = next(self._tf_record_iterator)
        return event_pb2.Event.FromString(raw_record)

    next = __next__
|
||||
|
||||
|
||||
class _SummaryFile:
    """A record of summary tags and the files to which they were written.

    Fields `scalars`, `images`, `histograms`, and `tensors` are sets
    containing `_ObservedSummary` values.
    """

    def __init__(self):
        # One independent, initially empty set per summary kind.
        for field in ("scalars", "images", "histograms", "tensors"):
            setattr(self, field, set())
        self.graph_defs = []
        self.convert_from_v2_summary_proto = False
|
||||
|
||||
|
||||
def get_model_from_layers(model_layers, input_shape, name=None):
    """Build a `Sequential` model: a float32 `Input`, then `model_layers`."""
    model = models.Sequential(name=name)
    input_layer = Input(input_shape, dtype="float32")
    model.add(input_layer)
    for hidden_layer in model_layers:
        model.add(hidden_layer)
    return model
|
||||
|
||||
|
||||
def list_summaries(logdir):
    """Read all summaries under the logdir into a `_SummaryFile`.

    Args:
        logdir: A path to a directory that contains zero or more event
            files, either as direct children or in transitive
            subdirectories. Summaries in these events must only contain
            old-style scalars, images, and histograms, or tensors.
            `graph_def`s are collected separately in `graph_defs`.

    Returns:
        A `_SummaryFile` object reflecting all summaries written to any
        event files in the logdir or any of its descendant directories.

    Raises:
        ValueError: If an event file contains a summary of unexpected kind.
    """
    # Which `_SummaryFile` field receives each old-style value kind.
    field_for_kind = {
        "simple_value": "scalars",
        "image": "images",
        "histo": "histograms",
        "tensor": "tensors",
    }
    # Which field receives a tf2 tensor summary, keyed by plugin name.
    field_for_plugin = {
        "images": "images",
        "histograms": "histograms",
        "scalars": "scalars",
    }

    result = _SummaryFile()
    for dirpath, _, filenames in os.walk(logdir):
        event_files = [
            fname for fname in filenames if fname.startswith("events.out.")
        ]
        for fname in event_files:
            path = os.path.join(dirpath, fname)
            for event in _SummaryIterator(path):
                if event.graph_def:
                    result.graph_defs.append(event.graph_def)
                if not event.summary:  # (e.g., it's a `graph_def` event)
                    continue
                for value in event.summary.value:
                    tag = value.tag
                    # Case on the `value` rather than the summary metadata
                    # because the Keras callback uses `summary_ops_v2` to
                    # emit old-style summaries. See b/124535134.
                    kind = value.WhichOneof("value")
                    field = field_for_kind.get(kind)
                    if field is None:
                        raise ValueError(
                            "Unexpected summary kind %r in event file %s:\n%r"
                            % (kind, path, event)
                        )
                    container = getattr(result, field)
                    if kind == "tensor" and tag != "keras":
                        # Convert the tf2 summary proto to old style for
                        # type checking.
                        plugin_name = value.metadata.plugin_data.plugin_name
                        v2_field = field_for_plugin.get(plugin_name)
                        if v2_field is not None:
                            container = getattr(result, v2_field)
                            result.convert_from_v2_summary_proto = True
                        else:
                            container = result.tensors
                    container.add(_ObservedSummary(logdir=dirpath, tag=tag))
    return result
|
||||
|
||||
|
||||
class TestTensorBoardV2(testing.TestCase):
    """Tests for the summaries written by `callbacks.TensorBoard`.

    Most tests fit a small model with the callback attached and then compare
    the summaries found by `list_summaries` against the expected tags.
    """

    def setUp(self):
        """Set the log directory paths shared by the tests."""
        super().setUp()
        self.logdir = os.path.join(self.get_temp_dir(), "tb")
        self.train_dir = os.path.join(self.logdir, "train")
        self.validation_dir = os.path.join(self.logdir, "validation")

    def _get_model(self, compile_model=True):
        """Build a Conv2D -> Flatten -> Dense(1) model for (10, 10, 1) inputs.

        Args:
            compile_model: If true, compile the model with SGD
                (learning rate 0.001) and "mse" loss before returning it.

        Returns:
            The model built by `get_model_from_layers`.
        """
        layers = [
            Conv2D(8, (3, 3)),
            Flatten(),
            Dense(1),
        ]
        model = get_model_from_layers(layers, input_shape=(10, 10, 1))

        if compile_model:
            opt = optimizers.SGD(learning_rate=0.001)
            model.compile(opt, "mse")
        return model

    def test_TensorBoard_default_logdir(self):
        """Regression test for cross-platform pathsep in default logdir."""
        temp_dir = self.get_temp_dir()
        # Restore the working directory when the test finishes so the
        # `chdir` below does not leak into other tests. Registered after
        # `get_temp_dir()` so it runs before any cleanup that call added.
        self.addCleanup(os.chdir, os.getcwd())
        os.chdir(temp_dir)

        model = self._get_model()
        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
        tb_cbk = callbacks.TensorBoard()  # no logdir specified

        model.fit(
            x,
            y,
            batch_size=2,
            epochs=2,
            validation_data=(x, y),
            callbacks=[tb_cbk],
        )

        summary_file = list_summaries(logdir=".")
        train_dir = os.path.join(".", "logs", "train")
        validation_dir = os.path.join(".", "logs", "validation")
        self.assertEqual(
            summary_file.scalars,
            {
                _ObservedSummary(logdir=train_dir, tag="epoch_loss"),
                _ObservedSummary(logdir=validation_dir, tag="epoch_loss"),
                _ObservedSummary(
                    logdir=validation_dir, tag="evaluation_loss_vs_iterations"
                ),
            },
        )

    def test_TensorBoard_basic(self):
        """Default settings write epoch-level train/validation scalars."""
        model = self._get_model()
        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
        tb_cbk = callbacks.TensorBoard(self.logdir)

        model.fit(
            x,
            y,
            batch_size=2,
            epochs=2,
            validation_data=(x, y),
            callbacks=[tb_cbk],
        )

        summary_file = list_summaries(self.logdir)
        self.assertEqual(
            summary_file.scalars,
            {
                _ObservedSummary(logdir=self.train_dir, tag="epoch_loss"),
                _ObservedSummary(logdir=self.validation_dir, tag="epoch_loss"),
                _ObservedSummary(
                    logdir=self.validation_dir,
                    tag="evaluation_loss_vs_iterations",
                ),
            },
        )

    def test_TensorBoard_across_invocations(self):
        """Regression test for summary writer resource use-after-free.

        See: <https://github.com/tensorflow/tensorflow/issues/25707>
        """
        model = self._get_model()
        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
        tb_cbk = callbacks.TensorBoard(self.logdir)

        # Reuse the same callback instance for two consecutive `fit` calls.
        for _ in (1, 2):
            model.fit(
                x,
                y,
                batch_size=2,
                epochs=2,
                validation_data=(x, y),
                callbacks=[tb_cbk],
            )

        summary_file = list_summaries(self.logdir)
        self.assertEqual(
            summary_file.scalars,
            {
                _ObservedSummary(logdir=self.train_dir, tag="epoch_loss"),
                _ObservedSummary(logdir=self.validation_dir, tag="epoch_loss"),
                _ObservedSummary(
                    logdir=self.validation_dir,
                    tag="evaluation_loss_vs_iterations",
                ),
            },
        )

    def test_TensorBoard_no_spurious_event_files(self):
        """Event files under the train dir live only in "train" itself."""
        model = self._get_model()
        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
        tb_cbk = callbacks.TensorBoard(self.logdir)

        model.fit(x, y, batch_size=2, epochs=2, callbacks=[tb_cbk])

        events_file_run_basenames = set()
        for dirpath, _, filenames in os.walk(self.train_dir):
            if any(fn.startswith("events.out.") for fn in filenames):
                events_file_run_basenames.add(os.path.basename(dirpath))
        self.assertEqual(events_file_run_basenames, {"train"})

    def test_TensorBoard_batch_metrics(self):
        """`update_freq=1` additionally writes a `batch_loss` scalar."""
        model = self._get_model()
        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
        tb_cbk = callbacks.TensorBoard(self.logdir, update_freq=1)

        model.fit(
            x,
            y,
            batch_size=2,
            epochs=2,
            validation_data=(x, y),
            callbacks=[tb_cbk],
        )

        summary_file = list_summaries(self.logdir)
        self.assertEqual(
            summary_file.scalars,
            {
                _ObservedSummary(logdir=self.train_dir, tag="batch_loss"),
                _ObservedSummary(logdir=self.train_dir, tag="epoch_loss"),
                _ObservedSummary(logdir=self.validation_dir, tag="epoch_loss"),
                _ObservedSummary(
                    logdir=self.validation_dir,
                    tag="evaluation_loss_vs_iterations",
                ),
            },
        )

    def test_TensorBoard_learning_rate_schedules(self):
        """A learning rate schedule adds an `epoch_learning_rate` scalar."""
        model = self._get_model(compile_model=False)
        opt = optimizers.SGD(schedules.CosineDecay(0.01, 1))
        model.compile(opt, "mse")

        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))

        model.fit(
            x,
            y,
            batch_size=2,
            epochs=2,
            callbacks=[callbacks.TensorBoard(self.logdir)],
        )

        summary_file = list_summaries(self.logdir)
        self.assertEqual(
            summary_file.scalars,
            {
                _ObservedSummary(logdir=self.train_dir, tag="epoch_loss"),
                _ObservedSummary(
                    logdir=self.train_dir, tag="epoch_learning_rate"
                ),
            },
        )

    def test_TensorBoard_global_step(self):
        """`write_steps_per_second` adds batch and epoch throughput tags."""
        model = self._get_model(compile_model=False)
        opt = optimizers.SGD(schedules.CosineDecay(0.01, 1))
        model.compile(opt, "mse")

        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))

        model.fit(
            x,
            y,
            batch_size=2,
            epochs=2,
            verbose=0,
            callbacks=[
                callbacks.TensorBoard(
                    self.logdir,
                    update_freq=1,
                    profile_batch=0,
                    write_steps_per_second=True,
                )
            ],
        )

        summary_file = list_summaries(self.logdir)
        self.assertEqual(
            summary_file.scalars,
            {
                _ObservedSummary(logdir=self.train_dir, tag="batch_loss"),
                _ObservedSummary(logdir=self.train_dir, tag="epoch_loss"),
                _ObservedSummary(
                    logdir=self.train_dir, tag="epoch_learning_rate"
                ),
                _ObservedSummary(
                    logdir=self.train_dir, tag="epoch_steps_per_second"
                ),
                _ObservedSummary(
                    logdir=self.train_dir, tag="batch_steps_per_second"
                ),
            },
        )

    def test_TensorBoard_weight_histograms(self):
        """`histogram_freq=1` writes histogram summaries to the train dir."""
        model = self._get_model()
        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
        tb_cbk = callbacks.TensorBoard(self.logdir, histogram_freq=1)

        model.fit(
            x,
            y,
            batch_size=2,
            epochs=2,
            validation_data=(x, y),
            callbacks=[tb_cbk],
        )
        summary_file = list_summaries(self.logdir)

        self.assertEqual(
            summary_file.scalars,
            {
                _ObservedSummary(logdir=self.train_dir, tag="epoch_loss"),
                _ObservedSummary(logdir=self.validation_dir, tag="epoch_loss"),
                _ObservedSummary(
                    logdir=self.validation_dir,
                    tag="evaluation_loss_vs_iterations",
                ),
            },
        )
        self.assertEqual(
            self._strip_layer_names(summary_file.histograms, "sequential"),
            {_ObservedSummary(logdir=self.train_dir, tag="histogram")},
        )

    def test_TensorBoard_weight_images(self):
        """`write_images=True` adds image summaries next to histograms."""
        model = self._get_model()
        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
        tb_cbk = callbacks.TensorBoard(
            self.logdir, histogram_freq=1, write_images=True
        )
        model_type = "sequential"

        model.fit(
            x,
            y,
            batch_size=2,
            epochs=2,
            validation_data=(x, y),
            callbacks=[tb_cbk],
        )
        summary_file = list_summaries(self.logdir)

        self.assertEqual(
            summary_file.scalars,
            {
                _ObservedSummary(logdir=self.train_dir, tag="epoch_loss"),
                _ObservedSummary(logdir=self.validation_dir, tag="epoch_loss"),
                _ObservedSummary(
                    logdir=self.validation_dir,
                    tag="evaluation_loss_vs_iterations",
                ),
            },
        )
        self.assertEqual(
            self._strip_layer_names(summary_file.histograms, model_type),
            {
                _ObservedSummary(logdir=self.train_dir, tag="histogram"),
            },
        )
        expected_image_summaries = {
            _ObservedSummary(logdir=self.train_dir, tag="image"),
            _ObservedSummary(logdir=self.train_dir, tag="bias/image"),
            _ObservedSummary(logdir=self.train_dir, tag="kernel/image"),
        }
        self.assertEqual(
            self._strip_variable_names(summary_file.images),
            expected_image_summaries,
        )

    def test_TensorBoard_projector_callback(self):
        """`embeddings_freq=1` writes a `projector_config.pbtxt` file."""
        layers = [
            Embedding(10, 10, name="test_embedding"),
            Dense(10, activation="relu"),
            Dense(1, activation="sigmoid"),
        ]
        model = get_model_from_layers(layers, input_shape=(10,))
        model.compile(
            optimizer="adam", loss=losses.BinaryCrossentropy(from_logits=True)
        )
        x, y = np.ones((10, 10)), np.ones((10, 10))
        tb_cbk = callbacks.TensorBoard(
            self.logdir,
            embeddings_freq=1,
            embeddings_metadata={"test_embedding": "metadata.tsv"},
        )

        model.fit(
            x,
            y,
            batch_size=2,
            epochs=2,
            validation_data=(x, y),
            callbacks=[tb_cbk],
        )

        with open(os.path.join(self.logdir, "projector_config.pbtxt")) as f:
            # NOTE(review): nested fields are expected with the two-space
            # indentation of protobuf text format -- confirm against the
            # file the callback actually writes.
            self.assertEqual(
                f.readlines(),
                [
                    "embeddings {\n",
                    "  tensor_name: "
                    '"layer_with_weights-0/embeddings/.ATTRIBUTES/'
                    'VARIABLE_VALUE"\n',
                    '  metadata_path: "metadata.tsv"\n',
                    "}\n",
                ],
            )

    def test_custom_summary(self):
        """Summaries written inside a layer's `call` appear in the logs."""

        def scalar_v2_mock(name, data, step=None):
            """A reimplementation of the scalar plugin to avoid circular
            deps."""
            metadata = SummaryMetadata()
            # Should match value in tensorboard/plugins/scalar/metadata.py.
            metadata.plugin_data.plugin_name = "scalars"
            with summary.experimental.summary_scope(
                name, "scalar_summary", values=[data, step]
            ) as (tag, _):
                tensor = ops.convert_to_tensor(data, "float32")
                summary.write(
                    tag=tag,
                    tensor=tensor,
                    step=step,
                    metadata=metadata,
                )

        class LayerWithSummary(Layer):
            def call(self, x):
                scalar_v2_mock("custom_summary", ops.sum(x))
                return x

        model = get_model_from_layers(
            [LayerWithSummary()], input_shape=(5,), name="model"
        )

        model.compile("sgd", "mse", jit_compile=False)  # summary ops can't xla
        tb_cbk = callbacks.TensorBoard(self.logdir, update_freq=1)
        x, y = np.ones((10, 5)), np.ones((10, 5))
        model.fit(
            x, y, batch_size=2, validation_data=(x, y), callbacks=[tb_cbk]
        )
        summary_file = list_summaries(self.logdir)
        # TODO: TensorFlow tags this summary as
        # "model/layer_with_summary/custom_summary" while JAX uses only
        # "custom_summary", so compare on the final tag component until the
        # backends agree.
        self.assertEqual(
            self._strip_to_only_final_name(summary_file.scalars),
            {
                _ObservedSummary(logdir=self.train_dir, tag="batch_loss"),
                _ObservedSummary(logdir=self.train_dir, tag="epoch_loss"),
                _ObservedSummary(logdir=self.validation_dir, tag="epoch_loss"),
                _ObservedSummary(
                    logdir=self.validation_dir,
                    tag="evaluation_loss_vs_iterations",
                ),
                _ObservedSummary(
                    logdir=self.train_dir,
                    tag="custom_summary",
                ),
                _ObservedSummary(
                    logdir=self.validation_dir,
                    tag="custom_summary",
                ),
            },
        )

    def _strip_to_only_final_name(self, summaries):
        """Remove all leading name components from each summary tag.

        For instance, "foo/bar/baz" becomes "baz"; a tag without a slash is
        kept as is.

        Args:
            summaries: A `set` of `_ObservedSummary` values.

        Returns:
            A new `set` of `_ObservedSummary` values stripped of every name
            component except the terminal one.
        """
        # `rsplit` leaves a slash-free tag unchanged, so no special case is
        # needed for tags that have no prefix.
        return {
            s._replace(tag=s.tag.rsplit("/", 1)[-1]) for s in summaries
        }

    def _strip_layer_names(self, summaries, model_type):
        """Deduplicate summary names modulo layer prefix.

        This removes the first slash-component of each tag name: for
        instance, "foo/bar/baz" becomes "bar/baz". When `model_type`
        contains "subclass", the first two components are removed instead.

        Args:
            summaries: A `set` of `_ObservedSummary` values.
            model_type: The model type currently being tested.

        Returns:
            A new `set` of `_ObservedSummary` values with layer prefixes
            removed.

        Raises:
            ValueError: If a tag contains no slash, i.e. no layer name.
        """
        result = set()
        for s in summaries:
            if "/" not in s.tag:
                raise ValueError(f"tag has no layer name: {s.tag!r}")
            start_from = 2 if "subclass" in model_type else 1
            new_tag = "/".join(s.tag.split("/")[start_from:])
            result.add(s._replace(tag=new_tag))
        return result

    def _strip_variable_names(self, summaries):
        """Remove `variable_n` prefixes from summary tags.

        `variable_n` tag names are added with random numbers. Removing them
        ensures deterministic tag names.

        Args:
            summaries: A `set` of `_ObservedSummary` values.

        Returns:
            A new `set` of `_ObservedSummary` values in which every tag
            whose first slash-component contains "variable" is reduced to
            its final component; all other tags are unchanged.
        """
        result = set()
        for s in summaries:
            parts = s.tag.split("/")
            if len(parts) > 1 and "variable" in parts[0]:
                s = s._replace(tag=parts[-1])
            result.add(s)
        return result

    def test_TensorBoard_non_blocking(self):
        """Callback hooks must not call `.numpy()` on tensors in `logs`."""
        model = models.Sequential([Dense(1)])
        model.optimizer = optimizers.Adam()
        tb = callbacks.TensorBoard(self.logdir)
        cb_list = callbacks.CallbackList(
            [tb], model=model, epochs=1, steps=100, verbose=0
        )

        tensor = ops.convert_to_tensor(1.0)

        def mock_numpy():
            raise RuntimeError(
                "If this error is seen, TensorBoard is causing a blocking "
                "NumPy conversion."
            )

        # Any hook that converts the logged tensor to NumPy now fails loudly.
        tensor.numpy = mock_numpy

        logs = {"metric": tensor}

        cb_list.on_train_begin(logs)
        cb_list.on_epoch_begin(0, logs)
        cb_list.on_train_batch_begin(0, logs)
        cb_list.on_train_batch_end(0, logs)
        cb_list.on_epoch_end(0, logs)
        cb_list.on_train_end(logs)

        cb_list.on_test_begin(logs)
        cb_list.on_test_batch_begin(0, logs)
        cb_list.on_test_batch_end(0, logs)
        cb_list.on_test_end(logs)

        cb_list.on_predict_begin(logs)
        # Pass the batch index explicitly, as for the train/test batch hooks
        # above, so that `logs` is not taken as the batch argument.
        cb_list.on_predict_batch_begin(0, logs)
        cb_list.on_predict_batch_end(0, logs)
        cb_list.on_predict_end(logs)
|
||||
|
||||
|
||||
# Note that this test specifies model_type explicitly.
|
||||
class TestTensorBoardV2NonParameterizedTest(testing.TestCase):
    """Graph-writing and profiling tests for `callbacks.TensorBoard`.

    The test methods are currently disabled and report as skipped; see the
    TODO in each one for what they are waiting on.
    """

    def setUp(self):
        """Set the log directory paths shared by the tests."""
        super().setUp()
        self.logdir = os.path.join(self.get_temp_dir(), "tb")
        self.train_dir = os.path.join(self.logdir, "train")
        self.validation_dir = os.path.join(self.logdir, "validation")

    def _get_seq_model(self):
        """Return a compiled Sequential conv model for (10, 10, 1) inputs."""
        model = models.Sequential(
            [
                Input((10, 10, 1)),
                Conv2D(8, (3, 3)),
                Flatten(),
                Dense(1),
            ]
        )
        opt = optimizers.SGD(learning_rate=0.001)
        model.compile(opt, "mse")
        return model

    def _count_xplane_file(self, logdir):
        """Count `*.xplane.pb` files under `<logdir>/plugins/profile`.

        Args:
            logdir: Root log directory to search.

        Returns:
            The number of matching files found in the profile directory or
            any of its descendants (0 if the directory does not exist).
        """
        profile_dir = os.path.join(logdir, "plugins", "profile")
        return sum(
            filename.endswith(".xplane.pb")
            for _, _, filenames in os.walk(profile_dir)
            for filename in filenames
        )

    def fitModelAndAssertKerasModelWritten(self, model):
        """Fit `model` with graph writing on and check the written summaries.

        Asserts that the only tensor summary is the "keras" tag in the train
        directory and, unless the model runs eagerly, that exactly one graph
        was written and that it mentions every non-input layer by name.

        Args:
            model: A compiled model accepting (10, 10, 1) inputs.
        """
        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
        tb_cbk = callbacks.TensorBoard(
            self.logdir, write_graph=True, profile_batch=0
        )
        model.fit(
            x,
            y,
            batch_size=2,
            epochs=3,
            validation_data=(x, y),
            callbacks=[tb_cbk],
        )
        summary_file = list_summaries(self.logdir)
        self.assertEqual(
            summary_file.tensors,
            {
                _ObservedSummary(logdir=self.train_dir, tag="keras"),
            },
        )
        if not model.run_eagerly:
            # There should be one train graph
            self.assertLen(summary_file.graph_defs, 1)
            for graph_def in summary_file.graph_defs:
                graph_def_str = str(graph_def)

                # All the model layers should appear in the graphs
                for layer in model.layers:
                    if "input" not in layer.name:
                        self.assertIn(layer.name, graph_def_str)

    def test_TensorBoard_write_sequential_model_no_input_shape(self):
        # TODO: Requires to_json implementation in trainer
        # Report the test as skipped rather than silently passing.
        self.skipTest("Requires to_json implementation in trainer.")
        # model = models.Sequential(
        #     [
        #         Conv2D(8, (3, 3)),
        #         Flatten(),
        #         Dense(1),
        #     ]
        # )
        # model.compile("sgd", "mse")
        # self.fitModelAndAssertKerasModelWritten(model)

    def test_TensorBoard_write_sequential_model_with_input_shape(self):
        # TODO: Requires to_json implementation in trainer
        # Report the test as skipped rather than silently passing.
        self.skipTest("Requires to_json implementation in trainer.")
        # model = models.Sequential(
        #     [
        #         Input(input_shape=(10, 10, 1)),
        #         Conv2D(8, (3, 3)),
        #         Flatten(),
        #         Dense(1),
        #     ]
        # )
        # model.compile("sgd", "mse")
        # self.fitModelAndAssertKerasModelWritten(model)

    def test_TensorBoard_write_model(self):
        # TODO: Requires to_json implementation in trainer
        # See https://github.com/keras-team/keras/blob/ \
        # a8d4a7f1ffc9de3c5932828a107e4e95e8803fb4/ \
        # keras/engine/training.py#L3313
        # Report the test as skipped rather than silently passing.
        self.skipTest("Requires to_json implementation in trainer.")
        # inputs = Input([10, 10, 1])
        # x = Conv2D(8, (3, 3), activation="relu")(inputs)
        # x = Flatten()(x)
        # x = Dense(1)(x)
        # model = models.Model(inputs=inputs, outputs=[x])
        # model.compile("sgd", "mse")
        # self.fitModelAndAssertKerasModelWritten(model)

    def test_TensorBoard_auto_trace(self):
        # TODO: Waiting for implementation for torch/jax for profiling ops
        # if backend.backend() == "jax":
        #     return
        # TODO: Debug profiling for JAX
        # Report the test as skipped rather than silently passing.
        self.skipTest("Waiting for profiling ops implementation.")
        # model = self._get_seq_model()
        # x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
        # tb_cbk = callbacks.TensorBoard(
        #     self.logdir, histogram_freq=1, profile_batch=1, write_graph=False
        # )

        # model.fit(
        #     x,
        #     y,
        #     batch_size=2,
        #     epochs=2,
        #     validation_data=(x, y),
        #     callbacks=[tb_cbk],
        # )
        # summary_file = list_summaries(self.logdir)

        # self.assertEqual(
        #     summary_file.tensors,
        #     {
        #         _ObservedSummary(logdir=self.train_dir, tag="batch_1"),
        #     },
        # )
        # self.assertEqual(1, self._count_xplane_file(logdir=self.logdir))
|
@ -13,5 +13,8 @@ pandas
|
||||
absl-py
|
||||
requests
|
||||
h5py
|
||||
protobuf
|
||||
google
|
||||
tensorboard-plugin-profile
|
||||
rich
|
||||
build
|
||||
build
|
||||
|
Loading…
Reference in New Issue
Block a user