Give trainer classes explicit names
This commit is contained in:
parent
ff097e5b02
commit
6603a1ed46
@ -10,7 +10,7 @@ from keras_core.trainers.data_adapters import data_adapter_utils
|
|||||||
from keras_core.trainers.epoch_iterator import EpochIterator
|
from keras_core.trainers.epoch_iterator import EpochIterator
|
||||||
|
|
||||||
|
|
||||||
class Trainer(base_trainer.Trainer):
|
class JAXTrainer(base_trainer.Trainer):
|
||||||
def compute_loss_and_updates(
|
def compute_loss_and_updates(
|
||||||
self, trainable_variables, non_trainable_variables, x, y, sample_weight
|
self, trainable_variables, non_trainable_variables, x, y, sample_weight
|
||||||
):
|
):
|
||||||
|
@ -12,7 +12,7 @@ from keras_core.trainers.data_adapters import data_adapter_utils
|
|||||||
from keras_core.trainers.epoch_iterator import EpochIterator
|
from keras_core.trainers.epoch_iterator import EpochIterator
|
||||||
|
|
||||||
|
|
||||||
class Trainer(base_trainer.Trainer):
|
class TensorFlowTrainer(base_trainer.Trainer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.train_function = None
|
self.train_function = None
|
||||||
@ -290,7 +290,7 @@ class Trainer(base_trainer.Trainer):
|
|||||||
if use_cached_eval_dataset:
|
if use_cached_eval_dataset:
|
||||||
epoch_iterator = self._eval_epoch_iterator
|
epoch_iterator = self._eval_epoch_iterator
|
||||||
else:
|
else:
|
||||||
# Create an iterator that yields batches for one epoch.
|
# Create an iterator that yields batches of input/target data.
|
||||||
epoch_iterator = TFEpochIterator(
|
epoch_iterator = TFEpochIterator(
|
||||||
x=x,
|
x=x,
|
||||||
y=y,
|
y=y,
|
||||||
|
@ -4,9 +4,9 @@ from keras_core.layers.layer import Layer
|
|||||||
from keras_core.utils import summary_utils
|
from keras_core.utils import summary_utils
|
||||||
|
|
||||||
if backend.backend() == "tensorflow":
|
if backend.backend() == "tensorflow":
|
||||||
from keras_core.backend.tensorflow.trainer import Trainer
|
from keras_core.backend.tensorflow.trainer import TensorFlowTrainer as Trainer
|
||||||
elif backend.backend() == "jax":
|
elif backend.backend() == "jax":
|
||||||
from keras_core.backend.jax.trainer import Trainer
|
from keras_core.backend.jax.trainer import JAXTrainer as Trainer
|
||||||
else:
|
else:
|
||||||
Trainer = None
|
Trainer = None
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ from keras_core import optimizers
|
|||||||
from keras_core import testing
|
from keras_core import testing
|
||||||
|
|
||||||
if backend.backend() == "jax":
|
if backend.backend() == "jax":
|
||||||
from keras_core.backend.jax.trainer import Trainer
|
from keras_core.backend.jax.trainer import JAXTrainer as Trainer
|
||||||
else:
|
else:
|
||||||
from keras_core.backend.tensorflow.trainer import Trainer
|
from keras_core.backend.tensorflow.trainer import TensorFlowTrainer as Trainer
|
||||||
|
|
||||||
|
|
||||||
# A model is just a layer mixed in with a Trainer.
|
# A model is just a layer mixed in with a Trainer.
|
||||||
|
Loading…
Reference in New Issue
Block a user