while_v2: Move more reduction ops to forward graph
When applicable, also move TensorListElementShape and TensorListLength to the forward graph as an optimization to Control Flow v2. PiperOrigin-RevId: 347699857
This commit is contained in:
parent
d592b6ca19
commit
af1a2eb1f5
@ -27,6 +27,7 @@ import shutil
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||||
|
from tensorflow.python.framework import test_util as tf_test_util
|
||||||
import keras
|
import keras
|
||||||
from keras import combinations
|
from keras import combinations
|
||||||
from keras import keras_parameterized
|
from keras import keras_parameterized
|
||||||
@ -580,6 +581,7 @@ class GRUV2Test(keras_parameterized.TestCase):
|
|||||||
outputs_trimmed = lstm(inputs[:, :masksteps])
|
outputs_trimmed = lstm(inputs[:, :masksteps])
|
||||||
self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
|
self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
|
||||||
|
|
||||||
|
@tf_test_util.enable_output_all_intermediates
|
||||||
def test_v1_session_behavior(self):
|
def test_v1_session_behavior(self):
|
||||||
with tf.compat.v1.get_default_graph().as_default():
|
with tf.compat.v1.get_default_graph().as_default():
|
||||||
# See b/139132348 for more details.
|
# See b/139132348 for more details.
|
||||||
|
@ -28,6 +28,7 @@ import time
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||||
|
from tensorflow.python.framework import test_util as tf_test_util
|
||||||
import keras
|
import keras
|
||||||
from keras import keras_parameterized
|
from keras import keras_parameterized
|
||||||
from keras import testing_utils
|
from keras import testing_utils
|
||||||
@ -781,6 +782,7 @@ class LSTMV2Test(keras_parameterized.TestCase):
|
|||||||
outputs_trimmed = lstm(inputs[:, :masksteps])
|
outputs_trimmed = lstm(inputs[:, :masksteps])
|
||||||
self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
|
self.assertAllClose(outputs_masked[:, -masksteps:], outputs_trimmed)
|
||||||
|
|
||||||
|
@tf_test_util.enable_output_all_intermediates
|
||||||
def test_v1_session_behavior(self):
|
def test_v1_session_behavior(self):
|
||||||
with tf.compat.v1.get_default_graph().as_default():
|
with tf.compat.v1.get_default_graph().as_default():
|
||||||
# See b/139132348 for more details.
|
# See b/139132348 for more details.
|
||||||
|
@ -33,6 +33,8 @@ from keras.engine import base_layer_utils
|
|||||||
from keras.layers import core
|
from keras.layers import core
|
||||||
from keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
|
from keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
|
||||||
from keras.utils import generic_utils
|
from keras.utils import generic_utils
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import test_util as tf_test_util
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.training.tracking import util as trackable_util
|
from tensorflow.python.training.tracking import util as trackable_util
|
||||||
|
|
||||||
@ -622,33 +624,39 @@ class BidirectionalTest(tf.test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
def test_bidirectional_statefulness(self):
|
def test_bidirectional_statefulness(self):
|
||||||
# Bidirectional and stateful
|
# Bidirectional and stateful
|
||||||
rnn = keras.layers.SimpleRNN
|
def run_test():
|
||||||
samples = 2
|
rnn = keras.layers.SimpleRNN
|
||||||
dim = 2
|
samples = 2
|
||||||
timesteps = 2
|
dim = 2
|
||||||
output_dim = 2
|
timesteps = 2
|
||||||
mode = 'sum'
|
output_dim = 2
|
||||||
|
mode = 'sum'
|
||||||
|
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
x = np.random.random((samples, timesteps, dim))
|
x = np.random.random((samples, timesteps, dim))
|
||||||
target_dim = 2 * output_dim if mode == 'concat' else output_dim
|
target_dim = 2 * output_dim if mode == 'concat' else output_dim
|
||||||
y = np.random.random((samples, target_dim))
|
y = np.random.random((samples, target_dim))
|
||||||
|
|
||||||
inputs = keras.layers.Input(batch_shape=(1, timesteps, dim))
|
inputs = keras.layers.Input(batch_shape=(1, timesteps, dim))
|
||||||
bidi_rnn = keras.layers.Bidirectional(
|
bidi_rnn = keras.layers.Bidirectional(
|
||||||
rnn(output_dim, stateful=True), merge_mode=mode)
|
rnn(output_dim, stateful=True), merge_mode=mode)
|
||||||
self.assertTrue(bidi_rnn.stateful)
|
self.assertTrue(bidi_rnn.stateful)
|
||||||
output = bidi_rnn(inputs)
|
output = bidi_rnn(inputs)
|
||||||
model = keras.models.Model(inputs, output)
|
model = keras.models.Model(inputs, output)
|
||||||
|
|
||||||
y_1 = model.predict(x, batch_size=1)
|
y_1 = model.predict(x, batch_size=1)
|
||||||
model.reset_states()
|
model.reset_states()
|
||||||
y_2 = model.predict(x, batch_size=1)
|
y_2 = model.predict(x, batch_size=1)
|
||||||
|
|
||||||
self.assertAllClose(y_1, y_2)
|
self.assertAllClose(y_1, y_2)
|
||||||
|
|
||||||
model.compile(loss='mse', optimizer='sgd')
|
model.compile(loss='mse', optimizer='sgd')
|
||||||
model.fit(x, y, epochs=1, batch_size=1)
|
model.fit(x, y, epochs=1, batch_size=1)
|
||||||
|
|
||||||
|
if context.executing_eagerly():
|
||||||
|
run_test()
|
||||||
|
else:
|
||||||
|
tf_test_util.enable_output_all_intermediates(run_test)()
|
||||||
|
|
||||||
@parameterized.parameters(['sum', 'mul', 'ave', 'concat', None])
|
@parameterized.parameters(['sum', 'mul', 'ave', 'concat', None])
|
||||||
def test_Bidirectional_merged_value(self, merge_mode):
|
def test_Bidirectional_merged_value(self, merge_mode):
|
||||||
|
Loading…
Reference in New Issue
Block a user