From 7253945282b596e56f6844af305975cd2753e70f Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Thu, 20 Apr 2023 13:21:41 -0700 Subject: [PATCH] Add JAX predict flow. --- keras_core/backend/jax/trainer.py | 72 ++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/keras_core/backend/jax/trainer.py b/keras_core/backend/jax/trainer.py index e6343124a..5dc5157d0 100644 --- a/keras_core/backend/jax/trainer.py +++ b/keras_core/backend/jax/trainer.py @@ -1,4 +1,6 @@ import jax +import numpy as np +import tensorflow as tf # for nest from keras_core import backend from keras_core import callbacks as callbacks_module @@ -290,7 +292,7 @@ class Trainer(base_trainer.Trainer): if use_cached_eval_dataset: epoch_iterator = self._eval_epoch_iterator else: - # Create an iterator that yields batches for one epoch. + # Create an iterator that yields batches of input/target data. epoch_iterator = EpochIterator( x=x, y=y, @@ -421,4 +423,70 @@ class Trainer(base_trainer.Trainer): def predict( self, x, batch_size=None, verbose="auto", steps=None, callbacks=None ): - raise NotImplementedError + # Create an iterator that yields batches of input data. + epoch_iterator = EpochIterator( + x=x, + batch_size=batch_size, + steps_per_epoch=steps, + shuffle=False, + ) + + if not self.built: + # Build the model on one batch of data. + for _, data in epoch_iterator.enumerate_epoch(return_type="np"): + # Build model + y_pred = self(data) + break + + # Container that configures and calls callbacks. + if not isinstance(callbacks, callbacks_module.CallbackList): + callbacks = callbacks_module.CallbackList( + callbacks, + add_history=True, + add_progbar=verbose != 0, + verbose=verbose, + epochs=1, + steps=epoch_iterator.num_batches, + model=self, + ) + + if not self.run_eagerly and self.jit_compile: + + @jax.jit + def predict_step( + trainable_variables, non_trainable_variables, data + ): + return self.stateless_call( + trainable_variables, non_trainable_variables, data + ) + + else: + predict_step = self.stateless_call + + callbacks.on_predict_begin() + + trainable_variables = self.trainable_variables + non_trainable_variables = self.non_trainable_variables + outputs = None + for step, x in epoch_iterator.enumerate_epoch(return_type="np"): + callbacks.on_predict_batch_begin(step) + batch_outputs, non_trainable_variables = predict_step( + trainable_variables, non_trainable_variables, x + ) + if outputs is None: + outputs = tf.nest.map_structure( + lambda batch_output: [batch_output], + batch_outputs, + ) + else: + tf.__internal__.nest.map_structure_up_to( + batch_outputs, + lambda output, batch_output: output.append(batch_output), + outputs, + batch_outputs, + ) + callbacks.on_predict_batch_end(step, {"outputs": batch_outputs}) + callbacks.on_predict_end() + return tf.__internal__.nest.map_structure_up_to( + batch_outputs, np.concatenate, outputs + )