while_v2: Move more reduction ops to forward graph

When applicable, also move TensorListElementShape and
TensorListLength to the forward graph as an optimization
to Control Flow v2.

PiperOrigin-RevId: 347699857
This commit is contained in:
Victor de Souza 2020-12-15 15:01:54 -08:00 committed by TensorFlower Gardener
parent d592b6ca19
commit af1a2eb1f5
3 changed files with 34 additions and 22 deletions

@ -27,6 +27,7 @@ import shutil
from absl.testing import parameterized
import numpy as np
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import test_util as tf_test_util
import keras
from keras import combinations
from keras import keras_parameterized
@ -580,6 +581,7 @@ class GRUV2Test(keras_parameterized.TestCase):
outputs_trimmed = lstm(inputs[:, :masksteps])
self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
@tf_test_util.enable_output_all_intermediates
def test_v1_session_behavior(self):
with tf.compat.v1.get_default_graph().as_default():
# See b/139132348 for more details.

@ -28,6 +28,7 @@ import time
from absl.testing import parameterized
import numpy as np
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import test_util as tf_test_util
import keras
from keras import keras_parameterized
from keras import testing_utils
@ -781,6 +782,7 @@ class LSTMV2Test(keras_parameterized.TestCase):
outputs_trimmed = lstm(inputs[:, :masksteps])
self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
@tf_test_util.enable_output_all_intermediates
def test_v1_session_behavior(self):
with tf.compat.v1.get_default_graph().as_default():
# See b/139132348 for more details.

@ -33,6 +33,8 @@ from keras.engine import base_layer_utils
from keras.layers import core
from keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
from keras.utils import generic_utils
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.training.tracking import util as trackable_util
@ -622,6 +624,7 @@ class BidirectionalTest(tf.test.TestCase, parameterized.TestCase):
def test_bidirectional_statefulness(self):
# Bidirectional and stateful
def run_test():
rnn = keras.layers.SimpleRNN
samples = 2
dim = 2
@ -650,6 +653,11 @@ class BidirectionalTest(tf.test.TestCase, parameterized.TestCase):
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y, epochs=1, batch_size=1)
if context.executing_eagerly():
run_test()
else:
tf_test_util.enable_output_all_intermediates(run_test)()
@parameterized.parameters(['sum', 'mul', 'ave', 'concat', None])
def test_Bidirectional_merged_value(self, merge_mode):
rnn = keras.layers.LSTM