Give trainer classes explicit names

This commit is contained in:
Francois Chollet 2023-04-20 14:59:20 -07:00
parent ff097e5b02
commit 6603a1ed46
4 changed files with 7 additions and 7 deletions

@ -10,7 +10,7 @@ from keras_core.trainers.data_adapters import data_adapter_utils
from keras_core.trainers.epoch_iterator import EpochIterator
class Trainer(base_trainer.Trainer):
class JAXTrainer(base_trainer.Trainer):
def compute_loss_and_updates(
self, trainable_variables, non_trainable_variables, x, y, sample_weight
):

@ -12,7 +12,7 @@ from keras_core.trainers.data_adapters import data_adapter_utils
from keras_core.trainers.epoch_iterator import EpochIterator
class Trainer(base_trainer.Trainer):
class TensorFlowTrainer(base_trainer.Trainer):
def __init__(self):
super().__init__()
self.train_function = None
@ -290,7 +290,7 @@ class Trainer(base_trainer.Trainer):
if use_cached_eval_dataset:
epoch_iterator = self._eval_epoch_iterator
else:
# Create an iterator that yields batches for one epoch.
# Create an iterator that yields batches of input/target data.
epoch_iterator = TFEpochIterator(
x=x,
y=y,

@ -4,9 +4,9 @@ from keras_core.layers.layer import Layer
from keras_core.utils import summary_utils
if backend.backend() == "tensorflow":
from keras_core.backend.tensorflow.trainer import Trainer
from keras_core.backend.tensorflow.trainer import TensorFlowTrainer as Trainer
elif backend.backend() == "jax":
from keras_core.backend.jax.trainer import Trainer
from keras_core.backend.jax.trainer import JAXTrainer as Trainer
else:
Trainer = None

@ -9,9 +9,9 @@ from keras_core import optimizers
from keras_core import testing
if backend.backend() == "jax":
from keras_core.backend.jax.trainer import Trainer
from keras_core.backend.jax.trainer import JAXTrainer as Trainer
else:
from keras_core.backend.tensorflow.trainer import Trainer
from keras_core.backend.tensorflow.trainer import TensorFlowTrainer as Trainer
# A model is just a layer mixed in with a Trainer.