Add JAX predict flow.

This commit is contained in:
Francois Chollet 2023-04-20 13:21:41 -07:00
parent 0fa15a4b12
commit 7253945282

@ -1,4 +1,6 @@
import jax
import numpy as np
import tensorflow as tf # for nest
from keras_core import backend
from keras_core import callbacks as callbacks_module
@ -290,7 +292,7 @@ class Trainer(base_trainer.Trainer):
if use_cached_eval_dataset:
epoch_iterator = self._eval_epoch_iterator
else:
# Create an iterator that yields batches for one epoch.
# Create an iterator that yields batches of input/target data.
epoch_iterator = EpochIterator(
x=x,
y=y,
@ -421,4 +423,70 @@ class Trainer(base_trainer.Trainer):
def predict(
self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
):
raise NotImplementedError
# Create an iterator that yields batches of input data.
epoch_iterator = EpochIterator(
x=x,
batch_size=batch_size,
steps_per_epoch=steps,
shuffle=False,
)
if not self.built:
# Build the model on one batch of data.
for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
# Build model
y_pred = self(data)
break
# Container that configures and calls callbacks.
if not isinstance(callbacks, callbacks_module.CallbackList):
callbacks = callbacks_module.CallbackList(
callbacks,
add_history=True,
add_progbar=verbose != 0,
verbose=verbose,
epochs=1,
steps=epoch_iterator.num_batches,
model=self,
)
if not self.run_eagerly and self.jit_compile:
@jax.jit
def predict_step(
trainable_variables, non_trainable_variables, data
):
return self.stateless_call(
trainable_variables, non_trainable_variables, data
)
else:
predict_step = self.stateless_call
callbacks.on_predict_begin()
trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
outputs = None
for step, x in epoch_iterator.enumerate_epoch(return_type="np"):
callbacks.on_predict_batch_begin(step)
batch_outputs, non_trainable_variables = predict_step(
trainable_variables, non_trainable_variables, x
)
if outputs is None:
outputs = tf.nest.map_structure(
lambda batch_output: [batch_output],
batch_outputs,
)
else:
tf.__internal__.nest.map_structure_up_to(
batch_outputs,
lambda output, batch_output: output.append(batch_output),
outputs,
batch_outputs,
)
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
callbacks.on_predict_end()
return tf.__internal__.nest.map_structure_up_to(
batch_outputs, np.concatenate, outputs
)