while_v2: Move more reduction ops to forward graph
When applicable, also move TensorListElementShape and TensorListLength to the forward graph as an optimization to Control Flow v2. PiperOrigin-RevId: 347699857
This commit is contained in:
parent
d592b6ca19
commit
af1a2eb1f5
@ -27,6 +27,7 @@ import shutil
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
import keras
|
||||
from keras import combinations
|
||||
from keras import keras_parameterized
|
||||
@ -580,6 +581,7 @@ class GRUV2Test(keras_parameterized.TestCase):
|
||||
outputs_trimmed = lstm(inputs[:, :masksteps])
|
||||
self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
|
||||
|
||||
@tf_test_util.enable_output_all_intermediates
|
||||
def test_v1_session_behavior(self):
|
||||
with tf.compat.v1.get_default_graph().as_default():
|
||||
# See b/139132348 for more details.
|
||||
|
@ -28,6 +28,7 @@ import time
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
import keras
|
||||
from keras import keras_parameterized
|
||||
from keras import testing_utils
|
||||
@ -781,6 +782,7 @@ class LSTMV2Test(keras_parameterized.TestCase):
|
||||
outputs_trimmed = lstm(inputs[:, :masksteps])
|
||||
self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
|
||||
|
||||
@tf_test_util.enable_output_all_intermediates
|
||||
def test_v1_session_behavior(self):
|
||||
with tf.compat.v1.get_default_graph().as_default():
|
||||
# See b/139132348 for more details.
|
||||
|
@ -33,6 +33,8 @@ from keras.engine import base_layer_utils
|
||||
from keras.layers import core
|
||||
from keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
|
||||
from keras.utils import generic_utils
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.training.tracking import util as trackable_util
|
||||
|
||||
@ -622,33 +624,39 @@ class BidirectionalTest(tf.test.TestCase, parameterized.TestCase):
|
||||
|
||||
def test_bidirectional_statefulness(self):
|
||||
# Bidirectional and stateful
|
||||
rnn = keras.layers.SimpleRNN
|
||||
samples = 2
|
||||
dim = 2
|
||||
timesteps = 2
|
||||
output_dim = 2
|
||||
mode = 'sum'
|
||||
def run_test():
|
||||
rnn = keras.layers.SimpleRNN
|
||||
samples = 2
|
||||
dim = 2
|
||||
timesteps = 2
|
||||
output_dim = 2
|
||||
mode = 'sum'
|
||||
|
||||
with self.cached_session():
|
||||
x = np.random.random((samples, timesteps, dim))
|
||||
target_dim = 2 * output_dim if mode == 'concat' else output_dim
|
||||
y = np.random.random((samples, target_dim))
|
||||
with self.cached_session():
|
||||
x = np.random.random((samples, timesteps, dim))
|
||||
target_dim = 2 * output_dim if mode == 'concat' else output_dim
|
||||
y = np.random.random((samples, target_dim))
|
||||
|
||||
inputs = keras.layers.Input(batch_shape=(1, timesteps, dim))
|
||||
bidi_rnn = keras.layers.Bidirectional(
|
||||
rnn(output_dim, stateful=True), merge_mode=mode)
|
||||
self.assertTrue(bidi_rnn.stateful)
|
||||
output = bidi_rnn(inputs)
|
||||
model = keras.models.Model(inputs, output)
|
||||
inputs = keras.layers.Input(batch_shape=(1, timesteps, dim))
|
||||
bidi_rnn = keras.layers.Bidirectional(
|
||||
rnn(output_dim, stateful=True), merge_mode=mode)
|
||||
self.assertTrue(bidi_rnn.stateful)
|
||||
output = bidi_rnn(inputs)
|
||||
model = keras.models.Model(inputs, output)
|
||||
|
||||
y_1 = model.predict(x, batch_size=1)
|
||||
model.reset_states()
|
||||
y_2 = model.predict(x, batch_size=1)
|
||||
y_1 = model.predict(x, batch_size=1)
|
||||
model.reset_states()
|
||||
y_2 = model.predict(x, batch_size=1)
|
||||
|
||||
self.assertAllClose(y_1, y_2)
|
||||
self.assertAllClose(y_1, y_2)
|
||||
|
||||
model.compile(loss='mse', optimizer='sgd')
|
||||
model.fit(x, y, epochs=1, batch_size=1)
|
||||
model.compile(loss='mse', optimizer='sgd')
|
||||
model.fit(x, y, epochs=1, batch_size=1)
|
||||
|
||||
if context.executing_eagerly():
|
||||
run_test()
|
||||
else:
|
||||
tf_test_util.enable_output_all_intermediates(run_test)()
|
||||
|
||||
@parameterized.parameters(['sum', 'mul', 'ave', 'concat', None])
|
||||
def test_Bidirectional_merged_value(self, merge_mode):
|
||||
|
Loading…
Reference in New Issue
Block a user