keras/benchmarks/layer_benchmark/base_benchmark.py
Chen Qian 35082455e5 Add more layer benchmarks (#245)
* initials

* add benchmark class

* add train benchmark

* Add conv benchmark

* flag

* fix comments

* docstring

* remove redundant flags

* More benchmarks

* better

* add more benchmarks

* remove weird files

---------

Co-authored-by: chenmoneygithub <chenmoney@chenmoney-gpu3.us-west1-a.c.keras-team-gcp.internal>
Co-authored-by: chenmoneygithub <chenmoney@chenmoney-gpu-4.us-west1-a.c.keras-team-gcp.internal>
2023-06-05 11:05:12 -07:00

259 lines
8.4 KiB
Python

import time
import numpy as np
import tensorflow as tf
from absl import flags
import keras_core
# Command-line flags for the layer benchmarks. The DEFINE_* calls below
# register them on absl's global flag registry, exposed here as FLAGS.
FLAGS = flags.FLAGS

# Selects a single benchmark by name; leaving it unset (None) runs all of them.
flags.DEFINE_string(
    name="benchmark_name",
    default=None,
    help=(
        "The name of benchmark to run. If None, all benchmarks in the file "
        "will be run."
    ),
)

# Total number of random samples generated for each benchmark run.
flags.DEFINE_integer(
    name="num_samples",
    default=1000,
    help="Number of input data samples.",
)

# Number of samples per batch passed to fit()/predict().
flags.DEFINE_integer(
    name="batch_size",
    default=20,
    help="Batch size of data.",
)

# Toggles XLA compilation of the compiled models.
flags.DEFINE_bool(
    name="jit_compile",
    default=True,
    help="If True, the benchmark will run with XLA compilation.",
)
class BenchmarkMetricsCallback:
    """Framework-agnostic timer that measures batch throughput.

    Timing starts at the beginning of batch ``start_batch`` and stops at the
    end of batch ``stop_batch``. Both batches are included in the
    measurement, so ``stop_batch - start_batch + 1`` batches are counted.
    Batches before ``start_batch`` are excluded from the measurement.

    Results are written to ``self.state``:

    - ``"benchmark_begin"`` / ``"benchmark_end"``: ``time.perf_counter()``
      readings taken at the start and end of the measured window.
    - ``"throughput"``: measured batches per second.

    Args:
        start_batch: Index of the first batch included in the measurement.
        stop_batch: Index of the last batch included in the measurement. If
            None, throughput is never computed.

    Raises:
        ValueError: If ``stop_batch`` is smaller than ``start_batch``; the
            window would be empty and the end hook would fire before any
            start time was recorded.
    """

    def __init__(self, start_batch=1, stop_batch=None):
        if stop_batch is not None and stop_batch < start_batch:
            raise ValueError(
                f"`stop_batch` ({stop_batch}) must be >= `start_batch` "
                f"({start_batch})."
            )
        self.start_batch = start_batch
        self.stop_batch = stop_batch
        self.state = {}

    def _record_begin(self, batch):
        """Start the clock when ``batch`` is the first measured batch."""
        if batch == self.start_batch:
            # perf_counter is monotonic and high-resolution, unlike
            # time.time(), which can jump with system clock adjustments.
            self.state["benchmark_begin"] = time.perf_counter()

    def _record_end(self, batch):
        """Stop the clock and compute throughput at the last measured batch."""
        if batch != self.stop_batch:
            return
        end = time.perf_counter()
        self.state["benchmark_end"] = end
        elapsed = end - self.state["benchmark_begin"]
        num_batches = self.stop_batch - self.start_batch + 1
        # Avoid ZeroDivisionError if the clock did not advance.
        self.state["throughput"] = (
            num_batches / elapsed if elapsed > 0 else float("inf")
        )

    def on_train_batch_begin(self, batch, logs=None):
        self._record_begin(batch)

    def on_train_batch_end(self, batch, logs=None):
        self._record_end(batch)

    def on_predict_batch_begin(self, batch, logs=None):
        self._record_begin(batch)

    def on_predict_batch_end(self, batch, logs=None):
        self._record_end(batch)
class KerasCoreBenchmarkMetricsCallback(keras_core.callbacks.Callback):
    """Keras Core callback that forwards batch hooks to a timer.

    Train and predict batch begin/end events are delegated to a
    ``BenchmarkMetricsCallback``; its measurements are available on
    ``self._callback.state``.

    Args:
        start_batch: Index of the first batch included in the measurement.
        stop_batch: Index of the last batch included in the measurement.
    """

    def __init__(self, start_batch=1, stop_batch=None):
        # Run the base Callback constructor so any state it initializes is
        # present before Keras starts using this callback.
        super().__init__()
        self._callback = BenchmarkMetricsCallback(start_batch, stop_batch)

    def on_train_batch_begin(self, batch, logs=None):
        self._callback.on_train_batch_begin(batch, logs)

    def on_train_batch_end(self, batch, logs=None):
        self._callback.on_train_batch_end(batch, logs)

    def on_predict_batch_begin(self, batch, logs=None):
        self._callback.on_predict_batch_begin(batch, logs)

    def on_predict_batch_end(self, batch, logs=None):
        self._callback.on_predict_batch_end(batch, logs)
class TFKerasBenchmarkMetricsCallback(tf.keras.callbacks.Callback):
    """tf.keras callback that forwards batch hooks to a timer.

    Train and predict batch begin/end events are delegated to a
    ``BenchmarkMetricsCallback``; its measurements are available on
    ``self._callback.state``.

    Args:
        start_batch: Index of the first batch included in the measurement.
        stop_batch: Index of the last batch included in the measurement.
    """

    def __init__(self, start_batch=1, stop_batch=None):
        # Run the base Callback constructor so any state it initializes is
        # present before tf.keras starts using this callback.
        super().__init__()
        self._callback = BenchmarkMetricsCallback(start_batch, stop_batch)

    def on_train_batch_begin(self, batch, logs=None):
        self._callback.on_train_batch_begin(batch, logs)

    def on_train_batch_end(self, batch, logs=None):
        self._callback.on_train_batch_end(batch, logs)

    def on_predict_batch_begin(self, batch, logs=None):
        self._callback.on_predict_batch_begin(batch, logs)

    def on_predict_batch_end(self, batch, logs=None):
        self._callback.on_predict_batch_end(batch, logs)
class LayerBenchmark:
    """Compare the throughput of one layer in Keras Core and tf.keras.

    The layer class named ``layer_name`` is looked up in both
    ``keras_core.layers`` and ``tf.keras.layers`` and instantiated with the
    same ``init_args``. Each instance is wrapped in a functional model
    compiled with MSE loss and SGD. The models are then timed on
    random-normal data by ``benchmark_predict`` (forward pass) or
    ``benchmark_train`` (forward & backward pass).

    Args:
        layer_name: Name of the layer class, e.g. ``"Dense"``.
        init_args: Keyword arguments passed to both layer constructors.
        input_shape: Per-sample input shape (without the batch dimension),
            or a list/tuple of such shapes for a multi-input layer.
        flat_call_inputs: If True, the layer is called as ``layer(*inputs)``;
            otherwise it receives the list of inputs, ``layer(inputs)``.
        jit_compile: Forwarded to ``Model.compile`` for both models.
    """

    def __init__(
        self,
        layer_name,
        init_args,
        input_shape,
        flat_call_inputs=True,
        jit_compile=True,
    ):
        self.layer_name = layer_name
        self.input_shape = input_shape
        self.flat_call_inputs = flat_call_inputs
        self.jit_compile = jit_compile

        keras_core_layer_class = getattr(keras_core.layers, layer_name)
        tf_keras_layer_class = getattr(tf.keras.layers, layer_name)
        self._keras_core_layer = keras_core_layer_class(**init_args)
        self._tf_keras_layer = tf_keras_layer_class(**init_args)

        self._keras_core_model = self._build_keras_core_model(
            input_shape, flat_call_inputs
        )
        self._tf_keras_model = self._build_tf_keras_model(
            input_shape, flat_call_inputs
        )
        for model in (self._keras_core_model, self._tf_keras_model):
            model.compile(loss="mse", optimizer="sgd", jit_compile=jit_compile)

    @staticmethod
    def _normalize_input_shapes(input_shape):
        """Return ``input_shape`` as a list holding one shape per model input."""
        if isinstance(input_shape[0], (tuple, list)):
            return list(input_shape)
        return [input_shape]

    @classmethod
    def _build_model(
        cls, layer, input_fn, model_cls, input_shape, flat_call_inputs
    ):
        """Wrap ``layer`` in a functional model made by ``input_fn``/``model_cls``."""
        inputs = [
            input_fn(shape=shape)
            for shape in cls._normalize_input_shapes(input_shape)
        ]
        if flat_call_inputs:
            outputs = layer(*inputs)
        else:
            outputs = layer(inputs)
        return model_cls(inputs=inputs, outputs=outputs)

    def _build_keras_core_model(self, input_shape, flat_call_inputs=True):
        """Build the Keras Core functional model around the Keras Core layer."""
        return self._build_model(
            self._keras_core_layer,
            keras_core.Input,
            keras_core.Model,
            input_shape,
            flat_call_inputs,
        )

    def _build_tf_keras_model(self, input_shape, flat_call_inputs=True):
        """Build the tf.keras functional model around the tf.keras layer."""
        return self._build_model(
            self._tf_keras_layer,
            tf.keras.Input,
            tf.keras.Model,
            input_shape,
            flat_call_inputs,
        )

    def _is_multi_input(self):
        """True when ``input_shape`` describes several inputs."""
        return isinstance(self.input_shape[0], (tuple, list))

    def _generate_data(self, num_samples):
        """Return a list with one random-normal array per model input.

        Each array has shape ``(num_samples, *per_input_shape)``.
        """
        return [
            np.random.normal(size=[num_samples] + list(shape))
            for shape in self._normalize_input_shapes(self.input_shape)
        ]

    @staticmethod
    def _compute_stop_batch(num_samples, batch_size):
        """Return the index of the last full batch, which closes the timed window.

        The timing callbacks start at batch 1 by default (batch 0 is not
        measured), so at least two full batches are required.

        Raises:
            ValueError: If ``batch_size`` is not positive or fewer than two
                full batches fit into ``num_samples``.
        """
        if batch_size <= 0:
            raise ValueError(f"`batch_size` must be positive, got {batch_size}.")
        num_full_batches = num_samples // batch_size
        if num_full_batches < 2:
            raise ValueError(
                f"`num_samples` ({num_samples}) must hold at least two full "
                f"batches of size {batch_size} to measure throughput."
            )
        return num_full_batches - 1

    @staticmethod
    def _make_callbacks(stop_batch):
        """Create the (Keras Core, tf.keras) timing callbacks for one run."""
        return (
            KerasCoreBenchmarkMetricsCallback(stop_batch=stop_batch),
            TFKerasBenchmarkMetricsCallback(stop_batch=stop_batch),
        )

    def _report_throughput(
        self, keras_core_callback, tf_keras_callback, batch_size, pass_name
    ):
        """Print samples/sec for both frameworks.

        The callbacks record batches/sec; multiplying by ``batch_size``
        converts that to samples/sec.
        """
        keras_core_throughput = (
            keras_core_callback._callback.state["throughput"] * batch_size
        )
        tf_keras_throughput = (
            tf_keras_callback._callback.state["throughput"] * batch_size
        )
        print(
            f"Keras Core throughput of {pass_name} of {self.layer_name}: "
            f"{keras_core_throughput:.2f} samples/sec."
        )
        print(
            f"TF Keras throughput of {pass_name} of {self.layer_name}: "
            f"{tf_keras_throughput:.2f} samples/sec."
        )

    def benchmark_predict(self, num_samples, batch_size):
        """Time ``Model.predict`` for both frameworks and print samples/sec.

        Raises:
            ValueError: If fewer than two full batches fit into ``num_samples``.
        """
        stop_batch = self._compute_stop_batch(num_samples, batch_size)
        data = self._generate_data(num_samples)
        # A single-input model is fed the bare array rather than a 1-list.
        predict_data = data if self._is_multi_input() else data[0]
        keras_core_callback, tf_keras_callback = self._make_callbacks(stop_batch)
        self._keras_core_model.predict(
            predict_data,
            batch_size=batch_size,
            callbacks=[keras_core_callback],
        )
        self._tf_keras_model.predict(
            predict_data,
            batch_size=batch_size,
            callbacks=[tf_keras_callback],
        )
        self._report_throughput(
            keras_core_callback, tf_keras_callback, batch_size, "forward pass"
        )

    def benchmark_train(self, num_samples, batch_size):
        """Time ``Model.fit`` for both frameworks and print samples/sec.

        Raises:
            ValueError: If fewer than two full batches fit into ``num_samples``.
        """
        stop_batch = self._compute_stop_batch(num_samples, batch_size)
        data = self._generate_data(num_samples)
        # Targets are the Keras Core layer's own outputs scaled by a small
        # factor. That model wraps the same layer, so unscaled targets would
        # give zero loss and zero gradients.
        if self.flat_call_inputs:
            outputs = self._keras_core_layer(*data)
        else:
            outputs = self._keras_core_layer(data)
        label = np.array(outputs) * 1.001
        keras_core_callback, tf_keras_callback = self._make_callbacks(stop_batch)
        self._keras_core_model.fit(
            data,
            label,
            batch_size=batch_size,
            callbacks=[keras_core_callback],
        )
        self._tf_keras_model.fit(
            data,
            label,
            batch_size=batch_size,
            callbacks=[tf_keras_callback],
        )
        self._report_throughput(
            keras_core_callback,
            tf_keras_callback,
            batch_size,
            "forward & backward pass",
        )