Fix formatting. (#19098)

This commit is contained in:
hertschuh 2024-01-24 13:50:55 -08:00 committed by GitHub
parent b9db6e44c7
commit 1d0008bc70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 6 deletions

@ -633,13 +633,17 @@ class JAXTrainer(base_trainer.Trainer):
v.value for v in self.non_trainable_variables
]
self._purge_model_variables(
trainable_variables=False, optimizer_variables=False, metric_variables=False
trainable_variables=False,
optimizer_variables=False,
metric_variables=False,
)
outputs = None
for step, x in epoch_iterator.enumerate_epoch():
state = (trainable_variables, non_trainable_variables)
callbacks.on_predict_batch_begin(step)
batch_outputs, non_trainable_variables = self.predict_function(state, x)
batch_outputs, non_trainable_variables = self.predict_function(
state, x
)
outputs = append_to_outputs(batch_outputs, outputs)
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
if self.stop_predicting:
@ -776,7 +780,9 @@ class JAXTrainer(base_trainer.Trainer):
v.value for v in self.non_trainable_variables
]
state = (trainable_variables, non_trainable_variables)
batch_outputs, non_trainable_variables = self.predict_function(state, [(x,)])
batch_outputs, non_trainable_variables = self.predict_function(
state, [(x,)]
)
self._jax_state = {
"non_trainable_variables": non_trainable_variables,
}

@ -1,4 +1,4 @@
from keras.random.random import categorical
from keras.random.random import categorical
from keras.random.random import dropout
from keras.random.random import gamma
from keras.random.random import normal

@ -1156,14 +1156,15 @@ class TestTrainer(testing.TestCase, parameterized.TestCase):
@pytest.mark.requires_trainable_backend
def test_rng_updated_during_predict(self):
class TestTimeDropout(layers.Layer):
def __init__(self):
super().__init__()
self.random_generator = keras.random.SeedGenerator()
def call(self, x):
return keras.random.dropout(x, rate=0.5, seed=self.random_generator)
return keras.random.dropout(
x, rate=0.5, seed=self.random_generator
)
inputs = layers.Input((20,))
outputs = TestTimeDropout()(inputs)