diff --git a/keras_core/trainers/trainer.py b/keras_core/trainers/trainer.py index 457b27793..cc1bcfd82 100644 --- a/keras_core/trainers/trainer.py +++ b/keras_core/trainers/trainer.py @@ -88,17 +88,18 @@ class Trainer: `keras_core.metrics.BinaryAccuracy`, `keras_core.metrics.CategoricalAccuracy`, `keras_core.metrics.SparseCategoricalAccuracy` based on the - shapes of the targets and of the model output. We do a similar - conversion for the strings 'crossentropy' and 'ce' as well. + shapes of the targets and of the model output. A similar + conversion is done for the strings `"crossentropy"` + and `"ce"` as well. The metrics passed here are evaluated without sample weighting; if you would like sample weighting to apply, you can specify your metrics via the `weighted_metrics` argument instead. weighted_metrics: List of metrics to be evaluated and weighted by `sample_weight` or `class_weight` during training and testing. - run_eagerly: Bool. If `True`, this `Model`'s logic will never be - compiled (e.g. with `tf.function` or `jax.jit`). Recommended to - leave this as `False` when training for best performance, and - `True` when debugging. + run_eagerly: Bool. If `True`, this model's forward pass + will never be compiled. It is recommended to leave this + as `False` when training (for best performance), + and to set it to `True` when debugging. steps_per_execution: Int. The number of batches to run during each a single compiled function call. Running multiple batches inside a single a single compiled function call can @@ -110,15 +111,11 @@ class Trainer: `Callback.on_batch_begin` and `Callback.on_batch_end` methods will only be called every `N` batches (i.e. before/after each compiled function execution). + Not supported with the PyTorch backend. jit_compile: Bool or `"auto"`. Whether to use XLA compilation when - compiling a model. This value should currently never be `True` - on the torch backed, and should always be `True` or `"auto"` on - the jax backend. On tensorflow, this value can be `True` or - `False`, and will toggle the `jit_compile` option for any - `tf.function` owned by the model. See - https://www.tensorflow.org/xla/tutorials/jit_compile for more - details. If `"auto"`, XLA compilation will be enabled if the - backend supports it, and disabled otherwise. + compiling a model. Not supported with the PyTorch backend. + If `"auto"`, XLA compilation will be enabled if the + the model supports it, and disabled otherwise. """ self.optimizer = optimizers.get(optimizer) if hasattr(self, "output_names"):