Fix Test breakage on with context for dropout (#38)
This commit is contained in:
parent
1f80d8c1ed
commit
6034134d95
@ -106,10 +106,9 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
|
|||||||
|
|
||||||
def dropout(inputs, rate, noise_shape=None, seed=None):
|
def dropout(inputs, rate, noise_shape=None, seed=None):
|
||||||
seed = tf_draw_seed(seed)
|
seed = tf_draw_seed(seed)
|
||||||
with tf.init_scope():
|
return tf.nn.experimental.stateless_dropout(
|
||||||
return tf.nn.experimental.stateless_dropout(
|
inputs,
|
||||||
inputs,
|
rate=rate,
|
||||||
rate=rate,
|
noise_shape=noise_shape,
|
||||||
noise_shape=noise_shape,
|
seed=seed,
|
||||||
seed=seed,
|
)
|
||||||
)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user