Give trainer classes explicit names

This commit is contained in:
Francois Chollet 2023-04-20 14:59:20 -07:00
parent ff097e5b02
commit 6603a1ed46
4 changed files with 7 additions and 7 deletions

@ -10,7 +10,7 @@ from keras_core.trainers.data_adapters import data_adapter_utils
from keras_core.trainers.epoch_iterator import EpochIterator from keras_core.trainers.epoch_iterator import EpochIterator
class Trainer(base_trainer.Trainer): class JAXTrainer(base_trainer.Trainer):
def compute_loss_and_updates( def compute_loss_and_updates(
self, trainable_variables, non_trainable_variables, x, y, sample_weight self, trainable_variables, non_trainable_variables, x, y, sample_weight
): ):

@ -12,7 +12,7 @@ from keras_core.trainers.data_adapters import data_adapter_utils
from keras_core.trainers.epoch_iterator import EpochIterator from keras_core.trainers.epoch_iterator import EpochIterator
class Trainer(base_trainer.Trainer): class TensorFlowTrainer(base_trainer.Trainer):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.train_function = None self.train_function = None
@ -290,7 +290,7 @@ class Trainer(base_trainer.Trainer):
if use_cached_eval_dataset: if use_cached_eval_dataset:
epoch_iterator = self._eval_epoch_iterator epoch_iterator = self._eval_epoch_iterator
else: else:
# Create an iterator that yields batches for one epoch. # Create an iterator that yields batches of input/target data.
epoch_iterator = TFEpochIterator( epoch_iterator = TFEpochIterator(
x=x, x=x,
y=y, y=y,

@ -4,9 +4,9 @@ from keras_core.layers.layer import Layer
from keras_core.utils import summary_utils from keras_core.utils import summary_utils
if backend.backend() == "tensorflow": if backend.backend() == "tensorflow":
from keras_core.backend.tensorflow.trainer import Trainer from keras_core.backend.tensorflow.trainer import TensorFlowTrainer as Trainer
elif backend.backend() == "jax": elif backend.backend() == "jax":
from keras_core.backend.jax.trainer import Trainer from keras_core.backend.jax.trainer import JAXTrainer as Trainer
else: else:
Trainer = None Trainer = None

@ -9,9 +9,9 @@ from keras_core import optimizers
from keras_core import testing from keras_core import testing
if backend.backend() == "jax": if backend.backend() == "jax":
from keras_core.backend.jax.trainer import Trainer from keras_core.backend.jax.trainer import JAXTrainer as Trainer
else: else:
from keras_core.backend.tensorflow.trainer import Trainer from keras_core.backend.tensorflow.trainer import TensorFlowTrainer as Trainer
# A model is just a layer mixed in with a Trainer. # A model is just a layer mixed in with a Trainer.