Fix TensorBoard.

Francois Chollet 2023-06-04 11:36:46 -07:00
parent 699e4c3174
commit c63de3adca
2 changed files with 145 additions and 200 deletions

@@ -1,8 +1,8 @@
import logging
import os
import sys
import time
import warnings
import tensorflow.summary as summary
from tensorflow import nest
@@ -16,12 +16,10 @@ from keras_core.callbacks.callback import Callback
from keras_core.layers import Embedding
from keras_core.optimizers import Optimizer
from keras_core.optimizers.schedules import learning_rate_schedule
from keras_core.utils import file_utils
@keras_core_export("keras_core.callbacks.TensorBoard")
class TensorBoard(Callback):
"""Enable visualizations for TensorBoard.
TensorBoard is a visualization tool provided with TensorFlow. A TensorFlow
@@ -34,10 +32,10 @@ class TensorBoard(Callback):
* Weight histograms
* Sampled profiling
When used in `model.evaluate()` or regular validation
in addition to epoch summaries, there will be a summary that records
evaluation metrics vs `model.optimizer.iterations` written. The metric names
will be prepended with `evaluation`, with `model.optimizer.iterations` being
the step in the visualized TensorBoard.
If you have installed TensorFlow with pip, you should be able
@@ -52,34 +50,37 @@ class TensorBoard(Callback):
Args:
log_dir: the path of the directory where to save the log files to be
parsed by TensorBoard. e.g.,
`log_dir = os.path.join(working_dir, 'logs')`.
This directory should not be reused by any other callbacks.
histogram_freq: frequency (in epochs) at which to compute
weight histograms for the layers of the model. If set to 0,
histograms won't be computed. Validation data (or split) must be
specified for histogram visualizations.
write_graph: (Not supported at this time)
Whether to visualize the graph in TensorBoard.
Note that the log file can become quite large
when `write_graph` is set to `True`.
write_images: whether to write model weights to visualize as image in
TensorBoard.
write_steps_per_second: whether to log the training steps per second
into TensorBoard. This supports both epoch and batch frequency
logging.
update_freq: `"batch"` or `"epoch"` or integer. When using `"epoch"`,
writes the losses and metrics to TensorBoard after every epoch.
If using an integer, let's say `1000`, all metrics and losses
(including custom ones added by `Model.compile`) will be logged to
TensorBoard every 1000 batches. `"batch"` is a synonym for 1,
meaning that they will be written every batch.
Note however that writing too frequently to TensorBoard can slow
down your training, especially when used with distribution
strategies as it will incur additional synchronization overhead.
Batch-level summary writing is also available via `train_step`
override. Please see
[TensorBoard Scalars tutorial](
https://www.tensorflow.org/tensorboard/scalars_and_keras#batch-level_logging) # noqa: E501
for more details.
profile_batch: (Not supported at this time)
Profile the batch(es) to sample compute characteristics.
profile_batch must be a non-negative integer or a tuple of integers.
A pair of positive integers signify a range of batches to profile.
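To make the `update_freq` semantics above concrete, a minimal sketch (the constructor call mirrors the docstring example further down):

tb_epoch = keras_core.callbacks.TensorBoard("./logs", update_freq="epoch")
tb_batch = keras_core.callbacks.TensorBoard("./logs", update_freq="batch")
# With an integer, summaries are written every N batches (here N=1000);
# very small values can slow training due to synchronization overhead.
tb_n = keras_core.callbacks.TensorBoard("./logs", update_freq=1000)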
@@ -118,7 +119,7 @@ class TensorBoard(Callback):
model.compile('sgd', 'mse')
# Make sure to set `update_freq=N` to log a batch-level summary every N
# batches. In addition to any `tf.summary` contained in `model.call()`,
# metrics added in `Model.compile` will be logged every N batches.
tb_callback = keras_core.callbacks.TensorBoard('./logs', update_freq=1)
model.fit(x_train, y_train, callbacks=[tb_callback])
@@ -628,9 +629,9 @@ def keras_model_summary(name, data, step=None):
name: A name for this summary. The summary tag used for TensorBoard will
be this name prefixed by any active name scopes.
data: A Keras Model to write.
step: Explicit `int64`-castable monotonic step value for this summary.
If omitted, this defaults to `tf.summary.experimental.get_step()`,
which must not be `None`.
Returns:
True on success, or False if no summary was written because no default
@@ -638,7 +639,7 @@ def keras_model_summary(name, data, step=None):
Raises:
ValueError: if a default writer exists, but no step was provided and
`tf.summary.experimental.get_step()` is `None`.
"""
summary_metadata = SummaryMetadata()
# Hard coding a plugin name. Please refer to go/tb-plugin-name-hardcode for
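A minimal usage sketch for `keras_model_summary` (the `model` below is an assumption, standing in for any built keras_core model; `create_file_writer` is the standard `tf.summary` writer factory):

import tensorflow.summary as summary

writer = summary.create_file_writer("/tmp/logs/graph")
with writer.as_default():
    # No default step is set here, so `step` must be passed explicitly;
    # otherwise a ValueError is raised per the docstring above.
    keras_model_summary("model_graph", model, step=0)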

@@ -1,5 +1,6 @@
import collections
import os
import random
import numpy as np
import tensorflow.summary as summary
@@ -7,19 +8,13 @@ from tensorflow.compat.v1 import SummaryMetadata
from tensorflow.core.util import event_pb2
from tensorflow.python.lib.io import tf_record
from keras_core import backend
from keras_core import callbacks
from keras_core import layers
from keras_core import losses
from keras_core import models
from keras_core import operations as ops
from keras_core import optimizers
from keras_core import testing
from keras_core.optimizers import schedules
# Note: this file, and tensorboard in general, has a dependency on tensorflow
@@ -31,7 +26,7 @@ from keras_core.optimizers import schedules
_ObservedSummary = collections.namedtuple("_ObservedSummary", ("logdir", "tag"))
class _SummaryIterator:
"""Yields `Event` protocol buffers from a given path."""
def __init__(self, path):
@@ -63,19 +58,6 @@ class _SummaryFile:
self.convert_from_v2_summary_proto = False
def list_summaries(logdir):
"""Read all summaries under the logdir into a `_SummaryFile`.
@@ -139,32 +121,31 @@ def list_summaries(logdir):
class TestTensorBoardV2(testing.TestCase):
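# Builds a fresh randomized logdir per call so repeated fit() runs within
# the test process never share event files.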
def _get_log_dirs(self):
logdir = os.path.join(
self.get_temp_dir(), str(random.randint(1, 1e7)), "tb"
)
train_dir = os.path.join(logdir, "train")
validation_dir = os.path.join(logdir, "validation")
return logdir, train_dir, validation_dir
def _get_model(self, compile_model=True):
model = models.Sequential(
[
layers.Input((10, 10, 1)),
layers.Flatten(),
layers.Dense(1),
]
)
if compile_model:
model.compile("sgd", "mse")
return model
def test_TensorBoard_basic(self):
model = self._get_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
logdir, train_dir, validation_dir = self._get_log_dirs()
tb_cbk = callbacks.TensorBoard(logdir)
model.fit(
x,
@@ -175,55 +156,25 @@ class TestTensorBoardV2(testing.TestCase):
callbacks=[tb_cbk],
)
summary_file = list_summaries(logdir)
self.assertEqual(
summary_file.scalars,
{
_ObservedSummary(logdir=train_dir, tag="epoch_loss"),
_ObservedSummary(logdir=validation_dir, tag="epoch_loss"),
_ObservedSummary(
logdir=validation_dir, tag="evaluation_loss_vs_iterations"
),
},
)
def test_TensorBoard_across_invocations(self):
"""Regression test for summary writer resource use-after-free.
See: <https://github.com/tensorflow/tensorflow/issues/25707>
"""
"""Regression test for summary writer resource use-after-free."""
model = self._get_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
logdir, train_dir, validation_dir = self._get_log_dirs()
tb_cbk = callbacks.TensorBoard(logdir)
for _ in (1, 2):
model.fit(
@@ -235,14 +186,14 @@ class TestTensorBoardV2(testing.TestCase):
callbacks=[tb_cbk],
)
summary_file = list_summaries(logdir)
self.assertEqual(
summary_file.scalars,
{
_ObservedSummary(logdir=train_dir, tag="epoch_loss"),
_ObservedSummary(logdir=validation_dir, tag="epoch_loss"),
_ObservedSummary(
logdir=validation_dir,
tag="evaluation_loss_vs_iterations",
),
},
@@ -251,12 +202,12 @@ class TestTensorBoardV2(testing.TestCase):
def test_TensorBoard_no_spurious_event_files(self):
model = self._get_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
logdir, train_dir, _ = self._get_log_dirs()
tb_cbk = callbacks.TensorBoard(logdir)
model.fit(x, y, batch_size=2, epochs=2, callbacks=[tb_cbk])
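# No validation data is passed, so only the `train` run directory should
# receive event files; the walk below asserts exactly that.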
events_file_run_basenames = set()
for dirpath, _, filenames in os.walk(train_dir):
if any(fn.startswith("events.out.") for fn in filenames):
events_file_run_basenames.add(os.path.basename(dirpath))
self.assertEqual(events_file_run_basenames, {"train"})
@@ -264,7 +215,8 @@ class TestTensorBoardV2(testing.TestCase):
def test_TensorBoard_batch_metrics(self):
model = self._get_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
logdir, train_dir, validation_dir = self._get_log_dirs()
tb_cbk = callbacks.TensorBoard(logdir, update_freq=1)
model.fit(
x,
@@ -275,15 +227,15 @@ class TestTensorBoardV2(testing.TestCase):
callbacks=[tb_cbk],
)
summary_file = list_summaries(logdir)
self.assertEqual(
summary_file.scalars,
{
_ObservedSummary(logdir=train_dir, tag="batch_loss"),
_ObservedSummary(logdir=train_dir, tag="epoch_loss"),
_ObservedSummary(logdir=validation_dir, tag="epoch_loss"),
_ObservedSummary(
logdir=validation_dir,
tag="evaluation_loss_vs_iterations",
),
},
@@ -293,7 +245,7 @@ class TestTensorBoardV2(testing.TestCase):
model = self._get_model(compile_model=False)
opt = optimizers.SGD(schedules.CosineDecay(0.01, 1))
model.compile(opt, "mse")
logdir, train_dir, _ = self._get_log_dirs()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
model.fit(
@@ -301,17 +253,15 @@ class TestTensorBoardV2(testing.TestCase):
y,
batch_size=2,
epochs=2,
callbacks=[callbacks.TensorBoard(logdir)],
)
summary_file = list_summaries(logdir)
self.assertEqual(
summary_file.scalars,
{
_ObservedSummary(logdir=train_dir, tag="epoch_loss"),
_ObservedSummary(logdir=train_dir, tag="epoch_learning_rate"),
},
)
@@ -319,7 +269,7 @@ class TestTensorBoardV2(testing.TestCase):
model = self._get_model(compile_model=False)
opt = optimizers.SGD(schedules.CosineDecay(0.01, 1))
model.compile(opt, "mse")
logdir, train_dir, _ = self._get_log_dirs()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
model.fit(
@@ -330,7 +280,7 @@ class TestTensorBoardV2(testing.TestCase):
verbose=0,
callbacks=[
callbacks.TensorBoard(
logdir,
update_freq=1,
profile_batch=0,
write_steps_per_second=True,
@@ -338,20 +288,18 @@ class TestTensorBoardV2(testing.TestCase):
],
)
summary_file = list_summaries(logdir)
self.assertEqual(
summary_file.scalars,
{
_ObservedSummary(logdir=train_dir, tag="batch_loss"),
_ObservedSummary(logdir=train_dir, tag="epoch_loss"),
_ObservedSummary(logdir=train_dir, tag="epoch_learning_rate"),
_ObservedSummary(
logdir=train_dir, tag="epoch_steps_per_second"
),
_ObservedSummary(
logdir=train_dir, tag="batch_steps_per_second"
),
},
)
@@ -359,7 +307,8 @@ class TestTensorBoardV2(testing.TestCase):
def test_TensorBoard_weight_histograms(self):
model = self._get_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
logdir, train_dir, validation_dir = self._get_log_dirs()
tb_cbk = callbacks.TensorBoard(logdir, histogram_freq=1)
model.fit(
x,
@@ -369,32 +318,41 @@ class TestTensorBoardV2(testing.TestCase):
validation_data=(x, y),
callbacks=[tb_cbk],
)
summary_file = list_summaries(logdir)
self.assertEqual(
summary_file.scalars,
{
_ObservedSummary(logdir=train_dir, tag="epoch_loss"),
_ObservedSummary(logdir=validation_dir, tag="epoch_loss"),
_ObservedSummary(
logdir=validation_dir,
tag="evaluation_loss_vs_iterations",
),
},
)
self.assertEqual(
self._strip_layer_names(summary_file.histograms, "sequential"),
{_ObservedSummary(logdir=train_dir, tag="histogram")},
)
def test_TensorBoard_weight_images(self):
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
logdir, train_dir, validation_dir = self._get_log_dirs()
tb_cbk = callbacks.TensorBoard(
logdir, histogram_freq=1, write_images=True
)
model_type = "sequential"
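# A Conv2D layer is used so the model has kernel and bias weights that
# `write_images=True` can render as the image summaries asserted below.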
model = models.Sequential(
[
layers.Input((10, 10, 1)),
layers.Conv2D(3, 10),
layers.GlobalAveragePooling2D(),
layers.Dense(1),
]
)
model.compile("sgd", "mse")
model.fit(
x,
y,
@@ -403,15 +361,15 @@ class TestTensorBoardV2(testing.TestCase):
validation_data=(x, y),
callbacks=[tb_cbk],
)
summary_file = list_summaries(logdir)
self.assertEqual(
summary_file.scalars,
{
_ObservedSummary(logdir=train_dir, tag="epoch_loss"),
_ObservedSummary(logdir=validation_dir, tag="epoch_loss"),
_ObservedSummary(
logdir=validation_dir,
tag="evaluation_loss_vs_iterations",
),
},
@@ -419,13 +377,13 @@ class TestTensorBoardV2(testing.TestCase):
self.assertEqual(
self._strip_layer_names(summary_file.histograms, model_type),
{
_ObservedSummary(logdir=train_dir, tag="histogram"),
},
)
expected_image_summaries = {
_ObservedSummary(logdir=train_dir, tag="image"),
_ObservedSummary(logdir=train_dir, tag="bias/image"),
_ObservedSummary(logdir=train_dir, tag="kernel/image"),
}
self.assertEqual(
self._strip_variable_names(summary_file.images),
@@ -433,18 +391,20 @@ class TestTensorBoardV2(testing.TestCase):
)
def test_TensorBoard_projector_callback(self):
model = models.Sequential(
[
layers.Input((10,)),
layers.Embedding(10, 10, name="test_embedding"),
layers.Dense(1, activation="sigmoid"),
]
)
model.compile(
optimizer="adam", loss=losses.BinaryCrossentropy(from_logits=True)
)
x, y = np.ones((10, 10)), np.ones((10, 10))
logdir, _, _ = self._get_log_dirs()
tb_cbk = callbacks.TensorBoard(
logdir,
embeddings_freq=1,
embeddings_metadata={"test_embedding": "metadata.tsv"},
)
@@ -458,7 +418,7 @@ class TestTensorBoardV2(testing.TestCase):
callbacks=[tb_cbk],
)
with open(os.path.join(logdir, "projector_config.pbtxt")) as f:
self.assertEqual(
f.readlines(),
[
@@ -489,40 +449,45 @@ class TestTensorBoardV2(testing.TestCase):
metadata=metadata,
)
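# A layer whose forward pass emits a scalar summary on every call; the
# test checks these custom summaries are collected for both the train
# and validation runs.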
class LayerWithSummary(layers.Layer):
def call(self, x):
scalar_v2_mock("custom_summary", ops.sum(x))
return x
model = models.Sequential(
[
layers.Input((5,)),
LayerWithSummary(),
]
)
model.compile("sgd", "mse", jit_compile=False) # summary ops can't xla
tb_cbk = callbacks.TensorBoard(self.logdir, update_freq=1)
# summary ops not compatible with XLA
model.compile("sgd", "mse", jit_compile=False)
logdir, train_dir, validation_dir = self._get_log_dirs()
tb_cbk = callbacks.TensorBoard(logdir, update_freq=1)
x, y = np.ones((10, 5)), np.ones((10, 5))
model.fit(
x, y, batch_size=2, validation_data=(x, y), callbacks=[tb_cbk]
)
summary_file = list_summaries(logdir)
# TODO: tensorflow will tag with model/layer_with_summary/custom_summary
# JAX will only use the custom_summary tag
self.assertEqual(
self._strip_to_only_final_name(summary_file.scalars),
{
_ObservedSummary(logdir=train_dir, tag="batch_loss"),
_ObservedSummary(logdir=train_dir, tag="epoch_loss"),
_ObservedSummary(logdir=validation_dir, tag="epoch_loss"),
_ObservedSummary(
logdir=validation_dir,
tag="evaluation_loss_vs_iterations",
),
_ObservedSummary(
logdir=train_dir,
tag="custom_summary",
),
_ObservedSummary(
logdir=validation_dir,
tag="custom_summary",
),
},
@@ -530,20 +495,20 @@ class TestTensorBoardV2(testing.TestCase):
# self.assertEqual(
# summary_file.scalars,
# {
# _ObservedSummary(logdir=train_dir, tag="batch_loss"),
# _ObservedSummary(logdir=train_dir, tag="epoch_loss"),
# _ObservedSummary(logdir=validation_dir,
# tag="epoch_loss"),
# _ObservedSummary(
# logdir=validation_dir,
# tag="evaluation_loss_vs_iterations",
# ),
# _ObservedSummary(
# logdir=train_dir,
# tag="model/layer_with_summary/custom_summary",
# ),
# _ObservedSummary(
# logdir=validation_dir,
# tag="model/layer_with_summary/custom_summary",
# ),
# },
@@ -618,13 +583,13 @@ class TestTensorBoardV2(testing.TestCase):
return result
def test_TensorBoard_non_blocking(self):
logdir, _, _ = self._get_log_dirs()
model = models.Sequential([layers.Dense(1)])
model.optimizer = optimizers.Adam()
tb = callbacks.TensorBoard(logdir)
cb_list = callbacks.CallbackList(
[tb], model=model, epochs=1, steps=100, verbose=0
)
tensor = ops.convert_to_tensor(1.0)
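# The mocked host conversion below lets the test detect whether the
# TensorBoard callback ever forces a blocking `.numpy()` call from
# inside the callback hooks.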
def mock_numpy():
@@ -654,28 +619,6 @@ class TestTensorBoardV2(testing.TestCase):
cb_list.on_predict_batch_end(logs)
cb_list.on_predict_end(logs)
class TestTensorBoardV2NonParameterizedTest(testing.TestCase):
def _count_xplane_file(self, logdir):
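# The profiler writes its output under `<logdir>/plugins/profile`;
# count the xplane files it produced there.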
profile_dir = os.path.join(logdir, "plugins", "profile")
count = 0
@@ -689,8 +632,9 @@ class TestTensorBoardV2NonParameterizedTest(testing.TestCase):
def fitModelAndAssertKerasModelWritten(self, model):
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
logdir, train_dir, validation_dir = self._get_log_dirs()
tb_cbk = callbacks.TensorBoard(
logdir, write_graph=True, profile_batch=0
)
model.fit(
x,
@@ -700,11 +644,11 @@ class TestTensorBoardV2NonParameterizedTest(testing.TestCase):
validation_data=(x, y),
callbacks=[tb_cbk],
)
summary_file = list_summaries(logdir)
self.assertEqual(
summary_file.tensors,
{
_ObservedSummary(logdir=train_dir, tag="keras"),
},
)
if not model.run_eagerly:
@@ -762,13 +706,13 @@ class TestTensorBoardV2NonParameterizedTest(testing.TestCase):
def test_TensorBoard_auto_trace(self):
# TODO: Waiting for implementation for torch/jax for profiling ops
# if backend.backend() == "jax":
# return
# TODO: Debug profiling for JAX
# model = self._get_seq_model()
# x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
# tb_cbk = callbacks.TensorBoard(
# logdir, histogram_freq=1, profile_batch=1, write_graph=False
# )
# model.fit(
@@ -779,13 +723,13 @@ class TestTensorBoardV2NonParameterizedTest(testing.TestCase):
# validation_data=(x, y),
# callbacks=[tb_cbk],
# )
# summary_file = list_summaries(logdir)
# self.assertEqual(
# summary_file.tensors,
# {
# _ObservedSummary(logdir=train_dir, tag="batch_1"),
# },
# )
# self.assertEqual(1, self._count_xplane_file(logdir=logdir))
pass