Use tf.select instead of tf.where (compat TF 0.11)

This commit is contained in:
Francois Chollet 2016-12-16 16:37:55 -08:00
parent 914d976801
commit 30fa61d457

@ -1707,7 +1707,7 @@ def rnn(step_function, inputs, initial_states,
for input, mask_t in zip(input_list, mask_list):
output, new_states = step_function(input, states + constants)
# tf.where needs its condition tensor
# tf.select needs its condition tensor
# to be the same shape as its two
# result tensors, but in our case
# the condition (mask) tensor is
@ -1725,16 +1725,16 @@ def rnn(step_function, inputs, initial_states,
else:
prev_output = successive_outputs[-1]
output = tf.where(tiled_mask_t, output, prev_output)
output = tf.select(tiled_mask_t, output, prev_output)
return_states = []
for state, new_state in zip(states, new_states):
# (see earlier comment for tile explanation)
tiled_mask_t = tf.tile(mask_t,
stack([1, tf.shape(new_state)[1]]))
return_states.append(tf.where(tiled_mask_t,
new_state,
state))
return_states.append(tf.select(tiled_mask_t,
new_state,
state))
states = return_states
successive_outputs.append(output)
successive_states.append(states)
@ -1795,8 +1795,8 @@ def rnn(step_function, inputs, initial_states,
new_state.set_shape(state.get_shape())
tiled_mask_t = tf.tile(mask_t,
stack([1, tf.shape(output)[1]]))
output = tf.where(tiled_mask_t, output, states[0])
new_states = [tf.where(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
output = tf.select(tiled_mask_t, output, states[0])
new_states = [tf.select(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
output_ta_t = output_ta_t.write(time, output)
return (time + 1, output_ta_t) + tuple(new_states)
else:
@ -1921,7 +1921,7 @@ def elu(x, alpha=1.):
if alpha == 1:
return res
else:
return tf.where(x > 0, res, alpha * res)
return tf.select(x > 0, res, alpha * res)
def softmax(x):
@ -2384,9 +2384,9 @@ def random_uniform(shape, low=0.0, high=1.0, dtype=_FLOATX, seed=None):
def random_binomial(shape, p=0.0, dtype=_FLOATX, seed=None):
if seed is None:
seed = np.random.randint(10e6)
return tf.where(tf.random_uniform(shape, dtype=dtype, seed=seed) <= p,
tf.ones(shape, dtype=dtype),
tf.zeros(shape, dtype=dtype))
return tf.select(tf.random_uniform(shape, dtype=dtype, seed=seed) <= p,
tf.ones(shape, dtype=dtype),
tf.zeros(shape, dtype=dtype))
# CTC