Add JAX predict flow.
This commit is contained in:
parent
0fa15a4b12
commit
7253945282
@ -1,4 +1,6 @@
|
||||
import jax
|
||||
import numpy as np
|
||||
import tensorflow as tf # for nest
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import callbacks as callbacks_module
|
||||
@ -290,7 +292,7 @@ class Trainer(base_trainer.Trainer):
|
||||
if use_cached_eval_dataset:
|
||||
epoch_iterator = self._eval_epoch_iterator
|
||||
else:
|
||||
# Create an iterator that yields batches for one epoch.
|
||||
# Create an iterator that yields batches of input/target data.
|
||||
epoch_iterator = EpochIterator(
|
||||
x=x,
|
||||
y=y,
|
||||
@ -421,4 +423,70 @@ class Trainer(base_trainer.Trainer):
|
||||
def predict(
|
||||
self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
|
||||
):
|
||||
raise NotImplementedError
|
||||
# Create an iterator that yields batches of input data.
|
||||
epoch_iterator = EpochIterator(
|
||||
x=x,
|
||||
batch_size=batch_size,
|
||||
steps_per_epoch=steps,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
if not self.built:
|
||||
# Build the model on one batch of data.
|
||||
for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
|
||||
# Build model
|
||||
y_pred = self(data)
|
||||
break
|
||||
|
||||
# Container that configures and calls callbacks.
|
||||
if not isinstance(callbacks, callbacks_module.CallbackList):
|
||||
callbacks = callbacks_module.CallbackList(
|
||||
callbacks,
|
||||
add_history=True,
|
||||
add_progbar=verbose != 0,
|
||||
verbose=verbose,
|
||||
epochs=1,
|
||||
steps=epoch_iterator.num_batches,
|
||||
model=self,
|
||||
)
|
||||
|
||||
if not self.run_eagerly and self.jit_compile:
|
||||
|
||||
@jax.jit
|
||||
def predict_step(
|
||||
trainable_variables, non_trainable_variables, data
|
||||
):
|
||||
return self.stateless_call(
|
||||
trainable_variables, non_trainable_variables, data
|
||||
)
|
||||
|
||||
else:
|
||||
predict_step = self.stateless_call
|
||||
|
||||
callbacks.on_predict_begin()
|
||||
|
||||
trainable_variables = self.trainable_variables
|
||||
non_trainable_variables = self.non_trainable_variables
|
||||
outputs = None
|
||||
for step, x in epoch_iterator.enumerate_epoch(return_type="np"):
|
||||
callbacks.on_predict_batch_begin(step)
|
||||
batch_outputs, non_trainable_variables = predict_step(
|
||||
trainable_variables, non_trainable_variables, x
|
||||
)
|
||||
if outputs is None:
|
||||
outputs = tf.nest.map_structure(
|
||||
lambda batch_output: [batch_output],
|
||||
batch_outputs,
|
||||
)
|
||||
else:
|
||||
tf.__internal__.nest.map_structure_up_to(
|
||||
batch_outputs,
|
||||
lambda output, batch_output: output.append(batch_output),
|
||||
outputs,
|
||||
batch_outputs,
|
||||
)
|
||||
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
|
||||
callbacks.on_predict_end()
|
||||
return tf.__internal__.nest.map_structure_up_to(
|
||||
batch_outputs, np.concatenate, outputs
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user