Fix formatting. (#19098)
This commit is contained in:
parent
b9db6e44c7
commit
1d0008bc70
@ -633,13 +633,17 @@ class JAXTrainer(base_trainer.Trainer):
|
||||
v.value for v in self.non_trainable_variables
|
||||
]
|
||||
self._purge_model_variables(
|
||||
trainable_variables=False, optimizer_variables=False, metric_variables=False
|
||||
trainable_variables=False,
|
||||
optimizer_variables=False,
|
||||
metric_variables=False,
|
||||
)
|
||||
outputs = None
|
||||
for step, x in epoch_iterator.enumerate_epoch():
|
||||
state = (trainable_variables, non_trainable_variables)
|
||||
callbacks.on_predict_batch_begin(step)
|
||||
batch_outputs, non_trainable_variables = self.predict_function(state, x)
|
||||
batch_outputs, non_trainable_variables = self.predict_function(
|
||||
state, x
|
||||
)
|
||||
outputs = append_to_outputs(batch_outputs, outputs)
|
||||
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
|
||||
if self.stop_predicting:
|
||||
@ -776,7 +780,9 @@ class JAXTrainer(base_trainer.Trainer):
|
||||
v.value for v in self.non_trainable_variables
|
||||
]
|
||||
state = (trainable_variables, non_trainable_variables)
|
||||
batch_outputs, non_trainable_variables = self.predict_function(state, [(x,)])
|
||||
batch_outputs, non_trainable_variables = self.predict_function(
|
||||
state, [(x,)]
|
||||
)
|
||||
self._jax_state = {
|
||||
"non_trainable_variables": non_trainable_variables,
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
from keras.random.random import categorical
|
||||
from keras.random.random import categorical
|
||||
from keras.random.random import dropout
|
||||
from keras.random.random import gamma
|
||||
from keras.random.random import normal
|
||||
|
@ -1156,14 +1156,15 @@ class TestTrainer(testing.TestCase, parameterized.TestCase):
|
||||
|
||||
@pytest.mark.requires_trainable_backend
|
||||
def test_rng_updated_during_predict(self):
|
||||
|
||||
class TestTimeDropout(layers.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.random_generator = keras.random.SeedGenerator()
|
||||
|
||||
def call(self, x):
|
||||
return keras.random.dropout(x, rate=0.5, seed=self.random_generator)
|
||||
return keras.random.dropout(
|
||||
x, rate=0.5, seed=self.random_generator
|
||||
)
|
||||
|
||||
inputs = layers.Input((20,))
|
||||
outputs = TestTimeDropout()(inputs)
|
||||
|
Loading…
Reference in New Issue
Block a user