while_v2: Move more reduction ops to forward graph

When applicable, also move TensorListElementShape and
TensorListLength to the forward graph as an optimization
to Control Flow v2.

PiperOrigin-RevId: 347699857
This commit is contained in:
Victor de Souza 2020-12-15 15:01:54 -08:00 committed by TensorFlower Gardener
parent d592b6ca19
commit af1a2eb1f5
3 changed files with 34 additions and 22 deletions

@ -27,6 +27,7 @@ import shutil
from absl.testing import parameterized
import numpy as np
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import test_util as tf_test_util
import keras
from keras import combinations
from keras import keras_parameterized
@ -580,6 +581,7 @@ class GRUV2Test(keras_parameterized.TestCase):
outputs_trimmed = lstm(inputs[:, :masksteps])
self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
@tf_test_util.enable_output_all_intermediates
def test_v1_session_behavior(self):
with tf.compat.v1.get_default_graph().as_default():
# See b/139132348 for more details.

@ -28,6 +28,7 @@ import time
from absl.testing import parameterized
import numpy as np
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import test_util as tf_test_util
import keras
from keras import keras_parameterized
from keras import testing_utils
@ -781,6 +782,7 @@ class LSTMV2Test(keras_parameterized.TestCase):
outputs_trimmed = lstm(inputs[:, :masksteps])
self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
@tf_test_util.enable_output_all_intermediates
def test_v1_session_behavior(self):
with tf.compat.v1.get_default_graph().as_default():
# See b/139132348 for more details.

@ -33,6 +33,8 @@ from keras.engine import base_layer_utils
from keras.layers import core
from keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
from keras.utils import generic_utils
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.training.tracking import util as trackable_util
@ -622,33 +624,39 @@ class BidirectionalTest(tf.test.TestCase, parameterized.TestCase):
def test_bidirectional_statefulness(self):
# Bidirectional and stateful
rnn = keras.layers.SimpleRNN
samples = 2
dim = 2
timesteps = 2
output_dim = 2
mode = 'sum'
def run_test():
rnn = keras.layers.SimpleRNN
samples = 2
dim = 2
timesteps = 2
output_dim = 2
mode = 'sum'
with self.cached_session():
x = np.random.random((samples, timesteps, dim))
target_dim = 2 * output_dim if mode == 'concat' else output_dim
y = np.random.random((samples, target_dim))
with self.cached_session():
x = np.random.random((samples, timesteps, dim))
target_dim = 2 * output_dim if mode == 'concat' else output_dim
y = np.random.random((samples, target_dim))
inputs = keras.layers.Input(batch_shape=(1, timesteps, dim))
bidi_rnn = keras.layers.Bidirectional(
rnn(output_dim, stateful=True), merge_mode=mode)
self.assertTrue(bidi_rnn.stateful)
output = bidi_rnn(inputs)
model = keras.models.Model(inputs, output)
inputs = keras.layers.Input(batch_shape=(1, timesteps, dim))
bidi_rnn = keras.layers.Bidirectional(
rnn(output_dim, stateful=True), merge_mode=mode)
self.assertTrue(bidi_rnn.stateful)
output = bidi_rnn(inputs)
model = keras.models.Model(inputs, output)
y_1 = model.predict(x, batch_size=1)
model.reset_states()
y_2 = model.predict(x, batch_size=1)
y_1 = model.predict(x, batch_size=1)
model.reset_states()
y_2 = model.predict(x, batch_size=1)
self.assertAllClose(y_1, y_2)
self.assertAllClose(y_1, y_2)
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y, epochs=1, batch_size=1)
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y, epochs=1, batch_size=1)
if context.executing_eagerly():
run_test()
else:
tf_test_util.enable_output_all_intermediates(run_test)()
@parameterized.parameters(['sum', 'mul', 'ave', 'concat', None])
def test_Bidirectional_merged_value(self, merge_mode):