Give trainer classes explicit names
This commit is contained in:
parent
ff097e5b02
commit
6603a1ed46
@ -10,7 +10,7 @@ from keras_core.trainers.data_adapters import data_adapter_utils
|
||||
from keras_core.trainers.epoch_iterator import EpochIterator
|
||||
|
||||
|
||||
class Trainer(base_trainer.Trainer):
|
||||
class JAXTrainer(base_trainer.Trainer):
|
||||
def compute_loss_and_updates(
|
||||
self, trainable_variables, non_trainable_variables, x, y, sample_weight
|
||||
):
|
||||
|
@ -12,7 +12,7 @@ from keras_core.trainers.data_adapters import data_adapter_utils
|
||||
from keras_core.trainers.epoch_iterator import EpochIterator
|
||||
|
||||
|
||||
class Trainer(base_trainer.Trainer):
|
||||
class TensorFlowTrainer(base_trainer.Trainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.train_function = None
|
||||
@ -290,7 +290,7 @@ class Trainer(base_trainer.Trainer):
|
||||
if use_cached_eval_dataset:
|
||||
epoch_iterator = self._eval_epoch_iterator
|
||||
else:
|
||||
# Create an iterator that yields batches for one epoch.
|
||||
# Create an iterator that yields batches of input/target data.
|
||||
epoch_iterator = TFEpochIterator(
|
||||
x=x,
|
||||
y=y,
|
||||
|
@ -4,9 +4,9 @@ from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import summary_utils
|
||||
|
||||
if backend.backend() == "tensorflow":
|
||||
from keras_core.backend.tensorflow.trainer import Trainer
|
||||
from keras_core.backend.tensorflow.trainer import TensorFlowTrainer as Trainer
|
||||
elif backend.backend() == "jax":
|
||||
from keras_core.backend.jax.trainer import Trainer
|
||||
from keras_core.backend.jax.trainer import JAXTrainer as Trainer
|
||||
else:
|
||||
Trainer = None
|
||||
|
||||
|
@ -9,9 +9,9 @@ from keras_core import optimizers
|
||||
from keras_core import testing
|
||||
|
||||
if backend.backend() == "jax":
|
||||
from keras_core.backend.jax.trainer import Trainer
|
||||
from keras_core.backend.jax.trainer import JAXTrainer as Trainer
|
||||
else:
|
||||
from keras_core.backend.tensorflow.trainer import Trainer
|
||||
from keras_core.backend.tensorflow.trainer import TensorFlowTrainer as Trainer
|
||||
|
||||
|
||||
# A model is just a layer mixed in with a Trainer.
|
||||
|
Loading…
Reference in New Issue
Block a user