Add JAX predict flow.
This commit is contained in:
parent
0fa15a4b12
commit
7253945282
@ -1,4 +1,6 @@
|
|||||||
import jax
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf # for nest
|
||||||
|
|
||||||
from keras_core import backend
|
from keras_core import backend
|
||||||
from keras_core import callbacks as callbacks_module
|
from keras_core import callbacks as callbacks_module
|
||||||
@ -290,7 +292,7 @@ class Trainer(base_trainer.Trainer):
|
|||||||
if use_cached_eval_dataset:
|
if use_cached_eval_dataset:
|
||||||
epoch_iterator = self._eval_epoch_iterator
|
epoch_iterator = self._eval_epoch_iterator
|
||||||
else:
|
else:
|
||||||
# Create an iterator that yields batches for one epoch.
|
# Create an iterator that yields batches of input/target data.
|
||||||
epoch_iterator = EpochIterator(
|
epoch_iterator = EpochIterator(
|
||||||
x=x,
|
x=x,
|
||||||
y=y,
|
y=y,
|
||||||
@ -421,4 +423,70 @@ class Trainer(base_trainer.Trainer):
|
|||||||
def predict(
|
def predict(
|
||||||
self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
|
self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
|
||||||
):
|
):
|
||||||
raise NotImplementedError
|
# Create an iterator that yields batches of input data.
|
||||||
|
epoch_iterator = EpochIterator(
|
||||||
|
x=x,
|
||||||
|
batch_size=batch_size,
|
||||||
|
steps_per_epoch=steps,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.built:
|
||||||
|
# Build the model on one batch of data.
|
||||||
|
for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
|
||||||
|
# Build model
|
||||||
|
y_pred = self(data)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Container that configures and calls callbacks.
|
||||||
|
if not isinstance(callbacks, callbacks_module.CallbackList):
|
||||||
|
callbacks = callbacks_module.CallbackList(
|
||||||
|
callbacks,
|
||||||
|
add_history=True,
|
||||||
|
add_progbar=verbose != 0,
|
||||||
|
verbose=verbose,
|
||||||
|
epochs=1,
|
||||||
|
steps=epoch_iterator.num_batches,
|
||||||
|
model=self,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.run_eagerly and self.jit_compile:
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def predict_step(
|
||||||
|
trainable_variables, non_trainable_variables, data
|
||||||
|
):
|
||||||
|
return self.stateless_call(
|
||||||
|
trainable_variables, non_trainable_variables, data
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
predict_step = self.stateless_call
|
||||||
|
|
||||||
|
callbacks.on_predict_begin()
|
||||||
|
|
||||||
|
trainable_variables = self.trainable_variables
|
||||||
|
non_trainable_variables = self.non_trainable_variables
|
||||||
|
outputs = None
|
||||||
|
for step, x in epoch_iterator.enumerate_epoch(return_type="np"):
|
||||||
|
callbacks.on_predict_batch_begin(step)
|
||||||
|
batch_outputs, non_trainable_variables = predict_step(
|
||||||
|
trainable_variables, non_trainable_variables, x
|
||||||
|
)
|
||||||
|
if outputs is None:
|
||||||
|
outputs = tf.nest.map_structure(
|
||||||
|
lambda batch_output: [batch_output],
|
||||||
|
batch_outputs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tf.__internal__.nest.map_structure_up_to(
|
||||||
|
batch_outputs,
|
||||||
|
lambda output, batch_output: output.append(batch_output),
|
||||||
|
outputs,
|
||||||
|
batch_outputs,
|
||||||
|
)
|
||||||
|
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
|
||||||
|
callbacks.on_predict_end()
|
||||||
|
return tf.__internal__.nest.map_structure_up_to(
|
||||||
|
batch_outputs, np.concatenate, outputs
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user