Fix Test breakage on with context for dropout (#38)
This commit is contained in:
parent
1f80d8c1ed
commit
6034134d95
@ -106,10 +106,9 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
|
||||
|
||||
def dropout(inputs, rate, noise_shape=None, seed=None):
|
||||
seed = tf_draw_seed(seed)
|
||||
with tf.init_scope():
|
||||
return tf.nn.experimental.stateless_dropout(
|
||||
inputs,
|
||||
rate=rate,
|
||||
noise_shape=noise_shape,
|
||||
seed=seed,
|
||||
)
|
||||
return tf.nn.experimental.stateless_dropout(
|
||||
inputs,
|
||||
rate=rate,
|
||||
noise_shape=noise_shape,
|
||||
seed=seed,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user