Use tf.select instead of tf.where (compat TF 0.11)
This commit is contained in:
parent
914d976801
commit
30fa61d457
@ -1707,7 +1707,7 @@ def rnn(step_function, inputs, initial_states,
|
||||
for input, mask_t in zip(input_list, mask_list):
|
||||
output, new_states = step_function(input, states + constants)
|
||||
|
||||
# tf.where needs its condition tensor
|
||||
# tf.select needs its condition tensor
|
||||
# to be the same shape as its two
|
||||
# result tensors, but in our case
|
||||
# the condition (mask) tensor is
|
||||
@ -1725,16 +1725,16 @@ def rnn(step_function, inputs, initial_states,
|
||||
else:
|
||||
prev_output = successive_outputs[-1]
|
||||
|
||||
output = tf.where(tiled_mask_t, output, prev_output)
|
||||
output = tf.select(tiled_mask_t, output, prev_output)
|
||||
|
||||
return_states = []
|
||||
for state, new_state in zip(states, new_states):
|
||||
# (see earlier comment for tile explanation)
|
||||
tiled_mask_t = tf.tile(mask_t,
|
||||
stack([1, tf.shape(new_state)[1]]))
|
||||
return_states.append(tf.where(tiled_mask_t,
|
||||
new_state,
|
||||
state))
|
||||
return_states.append(tf.select(tiled_mask_t,
|
||||
new_state,
|
||||
state))
|
||||
states = return_states
|
||||
successive_outputs.append(output)
|
||||
successive_states.append(states)
|
||||
@ -1795,8 +1795,8 @@ def rnn(step_function, inputs, initial_states,
|
||||
new_state.set_shape(state.get_shape())
|
||||
tiled_mask_t = tf.tile(mask_t,
|
||||
stack([1, tf.shape(output)[1]]))
|
||||
output = tf.where(tiled_mask_t, output, states[0])
|
||||
new_states = [tf.where(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
|
||||
output = tf.select(tiled_mask_t, output, states[0])
|
||||
new_states = [tf.select(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
|
||||
output_ta_t = output_ta_t.write(time, output)
|
||||
return (time + 1, output_ta_t) + tuple(new_states)
|
||||
else:
|
||||
@ -1921,7 +1921,7 @@ def elu(x, alpha=1.):
|
||||
if alpha == 1:
|
||||
return res
|
||||
else:
|
||||
return tf.where(x > 0, res, alpha * res)
|
||||
return tf.select(x > 0, res, alpha * res)
|
||||
|
||||
|
||||
def softmax(x):
|
||||
@ -2384,9 +2384,9 @@ def random_uniform(shape, low=0.0, high=1.0, dtype=_FLOATX, seed=None):
|
||||
def random_binomial(shape, p=0.0, dtype=_FLOATX, seed=None):
|
||||
if seed is None:
|
||||
seed = np.random.randint(10e6)
|
||||
return tf.where(tf.random_uniform(shape, dtype=dtype, seed=seed) <= p,
|
||||
tf.ones(shape, dtype=dtype),
|
||||
tf.zeros(shape, dtype=dtype))
|
||||
return tf.select(tf.random_uniform(shape, dtype=dtype, seed=seed) <= p,
|
||||
tf.ones(shape, dtype=dtype),
|
||||
tf.zeros(shape, dtype=dtype))
|
||||
|
||||
|
||||
# CTC
|
||||
|
Loading…
Reference in New Issue
Block a user