while_v2: Move more reduction ops to forward graph
When applicable, also move TensorListElementShape and TensorListLength to the forward graph as an optimization to Control Flow v2. PiperOrigin-RevId: 347699857
This commit is contained in:
parent
d592b6ca19
commit
af1a2eb1f5
@ -27,6 +27,7 @@ import shutil
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
import keras
|
||||
from keras import combinations
|
||||
from keras import keras_parameterized
|
||||
@ -580,6 +581,7 @@ class GRUV2Test(keras_parameterized.TestCase):
|
||||
outputs_trimmed = lstm(inputs[:, :masksteps])
|
||||
self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
|
||||
|
||||
@tf_test_util.enable_output_all_intermediates
|
||||
def test_v1_session_behavior(self):
|
||||
with tf.compat.v1.get_default_graph().as_default():
|
||||
# See b/139132348 for more details.
|
||||
|
@ -28,6 +28,7 @@ import time
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
import keras
|
||||
from keras import keras_parameterized
|
||||
from keras import testing_utils
|
||||
@ -781,6 +782,7 @@ class LSTMV2Test(keras_parameterized.TestCase):
|
||||
outputs_trimmed = lstm(inputs[:, :masksteps])
|
||||
self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
|
||||
|
||||
@tf_test_util.enable_output_all_intermediates
|
||||
def test_v1_session_behavior(self):
|
||||
with tf.compat.v1.get_default_graph().as_default():
|
||||
# See b/139132348 for more details.
|
||||
|
@ -33,6 +33,8 @@ from keras.engine import base_layer_utils
|
||||
from keras.layers import core
|
||||
from keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
|
||||
from keras.utils import generic_utils
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.training.tracking import util as trackable_util
|
||||
|
||||
@ -622,6 +624,7 @@ class BidirectionalTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_bidirectional_statefulness(self):
|
||||
# Bidirectional and stateful
|
||||
def run_test():
|
||||
rnn = keras.layers.SimpleRNN
|
||||
samples = 2
|
||||
dim = 2
|
||||
@ -650,6 +653,11 @@ class BidirectionalTest(tf.test.TestCase, parameterized.TestCase):
|
||||
model.compile(loss='mse', optimizer='sgd')
|
||||
model.fit(x, y, epochs=1, batch_size=1)
|
||||
|
||||
if context.executing_eagerly():
|
||||
run_test()
|
||||
else:
|
||||
tf_test_util.enable_output_all_intermediates(run_test)()
|
||||
|
||||
@parameterized.parameters(['sum', 'mul', 'ave', 'concat', None])
|
||||
def test_Bidirectional_merged_value(self, merge_mode):
|
||||
rnn = keras.layers.LSTM
|
||||
|
Loading…
Reference in New Issue
Block a user