Refactor layer benchmark (#333)

* some benchmark setup

* fix something

* fix rnn benchmark

* Fix format

---------

Co-authored-by: chenmoneygithub <chenmoney@chenmoney-gpu-4.us-west1-a.c.keras-team-gcp.internal>
Chen Qian 2023-06-12 16:02:14 -07:00 committed by Francois Chollet
parent dc8b81091c
commit 39bbd2a03d
12 changed files with 174 additions and 82 deletions

@ -1,4 +1,4 @@
""" Benchmark activation layers.
"""Benchmark activation layers.
To run benchmarks, see the following command for an example; change the
flags to your custom values:
@ -6,8 +6,8 @@ flag to your custom value:
```
python3 -m benchmarks.layer_benchmark.activation_benchmark \
--benchmark_name=benchmark_elu \
--num_samples=8192 \
--batch_size=1024 \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
@ -38,6 +38,10 @@ def benchmark_elu(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_prelu(
@ -58,6 +62,10 @@ def benchmark_prelu(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_relu(
@ -78,6 +86,10 @@ def benchmark_relu(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_leaky_relu(
@ -98,6 +110,10 @@ def benchmark_leaky_relu(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_softmax(
@ -118,6 +134,10 @@ def benchmark_softmax(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
BENCHMARK_NAMES = {
@ -136,7 +156,7 @@ def main(_):
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return
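The recurring `BENCHMARK_NAMES.items()` change, applied to every benchmark module in this commit, fixes a real bug: iterating a dict directly yields only its keys, so the old two-variable unpacking raised a `ValueError` before any benchmark ran. A minimal sketch of the difference (the dict entry here is illustrative):

```
# Stand-in for the real mapping of benchmark names to functions.
BENCHMARK_NAMES = {"benchmark_elu": lambda n, b, jit: print("ran", n, b, jit)}

# Old loop: iterating the dict yields bare key strings, so unpacking fails with
#   ValueError: too many values to unpack (expected 2)
# for name, benchmark_fn in BENCHMARK_NAMES:
#     ...

# Fixed loop: .items() yields (name, function) pairs.
for name, benchmark_fn in BENCHMARK_NAMES.items():
    benchmark_fn(2048, 256, True)
```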

@ -1,4 +1,4 @@
""" Benchmark attention layers.
"""Benchmark attention layers.
To run benchmarks, see the following command for an example; change the
flags to your custom values:
@ -6,8 +6,8 @@ flag to your custom value:
```
python3 -m benchmarks.layer_benchmark.attention_benchmark \
--benchmark_name=benchmark_attention \
--num_samples=8192 \
--batch_size=1024 \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
@ -115,7 +115,7 @@ def main(_):
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return

@ -108,15 +108,28 @@ class LayerBenchmark:
input_shape,
flat_call_inputs=True,
jit_compile=True,
keras_core_layer=None,
tf_keras_layer=None,
):
self.layer_name = layer_name
self.input_shape = input_shape
_keras_core_layer_class = getattr(keras_core.layers, layer_name)
_tf_keras_layer_class = getattr(tf.keras.layers, layer_name)
self._keras_core_layer = _keras_core_layer_class(**init_args)
self._tf_keras_layer = _tf_keras_layer_class(**init_args)
if keras_core_layer is None:
# Sometimes the keras_core layer and the tf_keras layer need to be
# initialized differently. For example, the `Bidirectional` layer
# wraps a `keras_core.layers.Layer` and a `tf.keras.layers.Layer`
# respectively.
self._keras_core_layer = _keras_core_layer_class(**init_args)
else:
self._keras_core_layer = keras_core_layer
if tf_keras_layer is None:
self._tf_keras_layer = _tf_keras_layer_class(**init_args)
else:
self._tf_keras_layer = tf_keras_layer
self.input_shape = input_shape
self._keras_core_model = self._build_keras_core_model(
input_shape, flat_call_inputs
)
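The two new constructor arguments are an escape hatch for layers that cannot be built from a single `init_args` dict shared across frameworks; callers pass pre-built layers instead. A sketch of the intended usage, mirroring the `Bidirectional` benchmark further down in this commit (the `base_benchmark` import path is an assumption about the module layout):

```
import tensorflow as tf
import keras_core

from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark  # assumed path

# Bidirectional wraps a framework-specific inner layer, so the keras_core
# and tf.keras versions are constructed separately and handed in directly.
benchmark = LayerBenchmark(
    "Bidirectional",
    init_args={},
    input_shape=[256, 256],
    jit_compile=True,
    keras_core_layer=keras_core.layers.Bidirectional(
        keras_core.layers.LSTM(32)
    ),
    tf_keras_layer=tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),
)
```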

@ -1,4 +1,4 @@
""" Benchmark conv layers.
"""Benchmark conv layers.
To run benchmarks, see the following command for an example; change the
flags to your custom values:
@ -6,8 +6,8 @@ flag to your custom value:
```
python3 -m benchmarks.layer_benchmark.conv_benchmark \
--benchmark_name=benchmark_conv2D \
--num_samples=2000 \
--batch_size=20 \
--num_samples=2046 \
--batch_size=256 \
--jit_compile=True
```
"""
@ -28,13 +28,13 @@ def benchmark_conv1D(
):
layer_name = "Conv1D"
init_args = {
"filters": 16,
"filters": 64,
"kernel_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 16],
input_shape=[1024, 256],
jit_compile=jit_compile,
)
@ -118,7 +118,7 @@ def benchmark_depthwise_conv1D(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 4],
input_shape=[256, 64],
jit_compile=jit_compile,
)
@ -175,7 +175,7 @@ def benchmark_separable_conv1D(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 4],
input_shape=[256, 64],
jit_compile=jit_compile,
)
@ -226,13 +226,13 @@ def benchmark_conv1D_transpose(
):
layer_name = "Conv1DTranspose"
init_args = {
"filters": 16,
"kernel_size": 2,
"filters": 32,
"kernel_size": 4,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 4],
input_shape=[256, 256],
jit_compile=jit_compile,
)

@ -1,4 +1,4 @@
""" Benchmark core layers.
"""Benchmark core layers.
To run benchmarks, see the following command for an example; change the
flags to your custom values:
@ -6,8 +6,8 @@ flag to your custom value:
```
python3 -m benchmarks.layer_benchmark.core_benchmark \
--benchmark_name=benchmark_dense \
--num_samples=8192 \
--batch_size=1024 \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
@ -27,7 +27,7 @@ def benchmark_dense(
jit_compile=True,
):
layer_name = "Dense"
init_args = {"units": 128}
init_args = {"units": 256}
benchmark = LayerBenchmark(
layer_name,
init_args,
@ -54,12 +54,12 @@ def benchmark_einsum_dense(
layer_name = "EinsumDense"
init_args = {
"equation": "abc,cd->abd",
"output_shape": (None, 128),
"output_shape": (None, 256),
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[64, 32],
input_shape=[256, 256],
jit_compile=jit_compile,
)
@ -81,17 +81,19 @@ def benchmark_embedding(
):
layer_name = "Embedding"
init_args = {
"input_dim": 30,
"output_shape": 128,
"input_dim": 128,
"output_dim": 256,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[64, 32],
input_shape=[
256,
],
jit_compile=jit_compile,
)
data = np.random.randint(30, size=(num_samples, 32))
data = [np.random.randint(30, size=(num_samples, 256))]
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
@ -108,6 +110,7 @@ def benchmark_embedding(
BENCHMARK_NAMES = {
"benchmark_dense": benchmark_dense,
"benchmark_einsum_dense": benchmark_einsum_dense,
"benchmark_embedding": benchmark_embedding,
}
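For context on the `Embedding` fix above: the layer's constructor takes `input_dim` (vocabulary size) and `output_dim` (embedding width), not `output_shape`, and it consumes integer token ids whose values must stay below `input_dim`. A standalone sketch with the sizes used in the benchmark (the batch size here is illustrative):

```
import numpy as np
import keras_core

layer = keras_core.layers.Embedding(input_dim=128, output_dim=256)
tokens = np.random.randint(128, size=(32, 256))  # (batch, sequence) of token ids
embedded = layer(tokens)                         # shape: (32, 256, 256)
```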
@ -118,7 +121,7 @@ def main(_):
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return

@ -1,4 +1,4 @@
""" Benchmark merge layers.
"""Benchmark merge layers.
To run benchmarks, see the following command for an example; change the
flags to your custom values:
@ -6,8 +6,8 @@ flag to your custom value:
```
python3 -m benchmarks.layer_benchmark.merge_benchmark \
--benchmark_name=benchmark_add \
--num_samples=8192 \
--batch_size=1024 \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
@ -40,6 +40,11 @@ def benchmark_add(
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_average(
num_samples,
@ -61,6 +66,11 @@ def benchmark_average(
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_concatenate(
num_samples,
@ -82,6 +92,11 @@ def benchmark_concatenate(
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_dot(
num_samples,
@ -89,7 +104,7 @@ def benchmark_dot(
jit_compile=True,
):
layer_name = "Dot"
init_args = {}
init_args = {"axes": [2, 1]}
benchmark = LayerBenchmark(
layer_name,
init_args,
@ -103,6 +118,11 @@ def benchmark_dot(
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
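The `axes` entry added to the `Dot` benchmark's `init_args` is required: the layer has no default contraction axes, so constructing it from an empty dict fails. A small sketch of what `axes=[2, 1]` means for two 3D inputs (the shapes are illustrative):

```
import numpy as np
import keras_core

# axes=[2, 1]: contract the last axis of x1 with axis 1 of x2.
dot = keras_core.layers.Dot(axes=[2, 1])
x1 = np.random.random((4, 256, 64))
x2 = np.random.random((4, 64, 256))
out = dot([x1, x2])  # shape: (4, 256, 256)
```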
def benchmark_maximum(
num_samples,
@ -124,6 +144,11 @@ def benchmark_maximum(
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_minimum(
num_samples,
@ -145,6 +170,11 @@ def benchmark_minimum(
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_multiply(
num_samples,
@ -156,7 +186,7 @@ def benchmark_multiply(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[[256, 256], [256, 32]],
input_shape=[[256, 64], [256, 64]],
flat_call_inputs=False,
jit_compile=jit_compile,
)
@ -166,6 +196,11 @@ def benchmark_multiply(
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_subtract(
num_samples,
@ -187,6 +222,11 @@ def benchmark_subtract(
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
BENCHMARK_NAMES = {
"benchmark_add": benchmark_add,
@ -207,7 +247,7 @@ def main(_):
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return

@ -1,4 +1,4 @@
""" Benchmark normalization layers.
"""Benchmark normalization layers.
To run benchmarks, see the following command for an example; change the
flags to your custom values:
@ -6,8 +6,8 @@ flag to your custom value:
```
python3 -m benchmarks.layer_benchmark.normalization_benchmark \
--benchmark_name=benchmark_batch_normalization \
--num_samples=2400 \
--batch_size=300 \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
@ -82,7 +82,7 @@ def benchmark_layer_normalization(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 4],
input_shape=[256, 128, 4],
jit_compile=jit_compile,
)
@ -107,7 +107,7 @@ def benchmark_unit_normalization(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 4],
input_shape=[256, 128, 4],
jit_compile=jit_compile,
)
@ -116,6 +116,11 @@ def benchmark_unit_normalization(
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
BENCHMARK_NAMES = {
"benchmark_batch_normalization": benchmark_batch_normalization,
@ -132,7 +137,7 @@ def main(_):
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return

@ -1,4 +1,4 @@
""" Benchmark pooling layers.
"""Benchmark pooling layers.
To run benchmarks, see the following command for an example; change the
flags to your custom values:
@ -33,7 +33,7 @@ def benchmark_average_pooling1d(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 3],
input_shape=[1024, 256],
jit_compile=jit_compile,
)
@ -87,7 +87,7 @@ def benchmark_average_pooling3d(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[64, 64, 64, 3],
input_shape=[64, 64, 32, 3],
jit_compile=jit_compile,
)
@ -114,7 +114,7 @@ def benchmark_max_pooling1d(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 3],
input_shape=[1024, 256],
jit_compile=jit_compile,
)
@ -168,7 +168,7 @@ def benchmark_max_pooling3d(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[64, 64, 64, 3],
input_shape=[64, 64, 32, 3],
jit_compile=jit_compile,
)
@ -193,7 +193,7 @@ def benchmark_global_average_pooling1d(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 3],
input_shape=[1024, 256],
jit_compile=jit_compile,
)
@ -243,7 +243,7 @@ def benchmark_global_average_pooling3d(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[64, 64, 64, 3],
input_shape=[64, 64, 32, 3],
jit_compile=jit_compile,
)
@ -268,7 +268,7 @@ def benchmark_global_max_pooling1d(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 3],
input_shape=[1024, 256],
jit_compile=jit_compile,
)
@ -318,7 +318,7 @@ def benchmark_global_max_pooling3d(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[64, 64, 64, 3],
input_shape=[64, 64, 32, 3],
jit_compile=jit_compile,
)
@ -356,7 +356,7 @@ def main(_):
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return

@ -1,4 +1,4 @@
""" Benchmark regularization layers.
"""Benchmark regularization layers.
To run benchmarks, see the following command for an example; change the
flags to your custom values:
@ -52,7 +52,7 @@ def benchmark_gaussian_dropout(
batch_size,
jit_compile=True,
):
layer_name = "GaussionDropout"
layer_name = "GaussianDropout"
init_args = {
"rate": 0.5,
}
@ -79,7 +79,7 @@ def benchmark_gaussian_noise(
batch_size,
jit_compile=True,
):
layer_name = "GaussionNoise"
layer_name = "GaussianNoise"
init_args = {
"stddev": 0.5,
}
@ -199,7 +199,7 @@ def main(_):
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return

@ -1,4 +1,4 @@
""" Benchmark reshaping layers.
"""Benchmark reshaping layers.
To run benchmarks, see the following command for an example; change the
flags to your custom values:
@ -30,7 +30,7 @@ def benchmark_cropping1d(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 3],
input_shape=[1024, 256],
jit_compile=jit_compile,
)
@ -127,7 +127,7 @@ def benchmark_permute(
):
layer_name = "Permute"
init_args = {
"dim": (2, 1),
"dims": (2, 1),
}
benchmark = LayerBenchmark(
layer_name,
@ -182,7 +182,7 @@ def benchmark_up_sampling2d(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 3],
input_shape=[128, 128, 3],
jit_compile=jit_compile,
)
@ -207,7 +207,7 @@ def benchmark_up_sampling3d(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 32, 32, 3],
input_shape=[32, 16, 16, 3],
jit_compile=jit_compile,
)
@ -319,7 +319,7 @@ def main(_):
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return

@ -1,4 +1,4 @@
""" Benchmark rnn layers.
"""Benchmark rnn layers.
To run benchmarks, see the following command for an example; change the
flags to your custom values:
@ -12,6 +12,7 @@ python3 -m benchmarks.layer_benchmark.rnn_benchmark \
```
"""
import tensorflow as tf
from absl import app
from absl import flags
@ -62,7 +63,7 @@ def benchmark_conv_lstm2d(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 64, 64, 3],
input_shape=[32, 32, 32, 3],
jit_compile=jit_compile,
)
@ -84,13 +85,13 @@ def benchmark_conv_lstm3d(
):
layer_name = "ConvLSTM3D"
init_args = {
"filters": 16,
"filters": 8,
"kernel_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[16, 32, 32, 16, 3],
input_shape=[8, 16, 16, 16, 3],
jit_compile=jit_compile,
)
@ -117,7 +118,7 @@ def benchmark_gru(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 256],
input_shape=[256, 256],
jit_compile=jit_compile,
)
@ -144,7 +145,7 @@ def benchmark_lstm(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 256],
input_shape=[256, 256],
jit_compile=jit_compile,
)
@ -171,7 +172,7 @@ def benchmark_simple_rnn(
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 256],
input_shape=[256, 256],
jit_compile=jit_compile,
)
@ -192,14 +193,18 @@ def benchmark_bidirectional(
jit_compile=True,
):
layer_name = "Bidirectional"
init_args = {
"layer": keras_core.layers.LSTM(32),
}
init_args = {}
keras_core_layer = keras_core.layers.Bidirectional(
keras_core.layers.LSTM(32)
)
tf_keras_layer = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32))
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 256],
input_shape=[256, 256],
jit_compile=jit_compile,
keras_core_layer=keras_core_layer,
tf_keras_layer=tf_keras_layer,
)
benchmark.benchmark_predict(
@ -219,14 +224,20 @@ def benchmark_time_distributed(
jit_compile=True,
):
layer_name = "TimeDistributed"
init_args = {
"layer": keras_core.layers.Conv2D(64, (3, 3)),
}
init_args = {}
keras_core_layer = keras_core.layers.TimeDistributed(
keras_core.layers.Conv2D(16, (3, 3))
)
tf_keras_layer = tf.keras.layers.TimeDistributed(
tf.keras.layers.Conv2D(16, (3, 3))
)
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[10, 128, 128, 3],
input_shape=[10, 32, 32, 3],
jit_compile=jit_compile,
keras_core_layer=keras_core_layer,
tf_keras_layer=tf_keras_layer,
)
benchmark.benchmark_predict(
@ -259,7 +270,7 @@ def main(_):
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return

@ -56,8 +56,8 @@ class ZeroPadding1D(Layer):
def compute_output_shape(self, input_shape):
output_shape = list(input_shape)
if input_shape[1] is not None:
input_shape[1] += self.padding[0] + self.padding[1]
if output_shape[1] is not None:
output_shape[1] += self.padding[0] + self.padding[1]
return tuple(output_shape)
def call(self, inputs):
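The `ZeroPadding1D.compute_output_shape` fix above matters because the old code added the padding to `input_shape` instead of `output_shape`, so the returned shape was unpadded and the caller's shape could be modified in place. A quick check of the corrected behavior (the padding values are illustrative):

```
import keras_core

layer = keras_core.layers.ZeroPadding1D(padding=(2, 3))
# 2 + 10 + 3 = 15 along the padded (second) axis; the batch dim stays None.
print(layer.compute_output_shape((None, 10, 8)))  # expected: (None, 15, 8)
```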