eabdb87f9f
* Add numpy ops (initial batch) and some config * Add unit test * fix call * Revert "fix call" This reverts commit 6748ad183029ff4b97317b77ceed8661916bb9a0. * full unit test coverage * fix setup.py
24 lines
756 B
Python
24 lines
756 B
Python
import tensorflow as tf
|
|
|
|
from keras_core.trainers import trainer
|
|
|
|
|
|
class Trainer(trainer.Trainer):
    """Trainer whose training step is driven by `tf.GradientTape`."""

    def train_step(self, data):
        """Perform one forward/backward pass and one optimizer update.

        Args:
            data: An iterable that unpacks into exactly two items, used
                here as `(inputs, targets)`. Its concrete structure
                depends on the model and on what is passed to `fit()`.

        Returns:
            The value produced by `self.loss(targets, predictions)`
            (the loss function is the one configured in `compile()`).
        """
        inputs, targets = data

        # Record the forward pass and the loss so they can be
        # differentiated afterwards.
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            loss_value = self.loss(targets, predictions)

        # The weights are looked up only after the forward pass has run,
        # then the loss is differentiated with respect to them.
        variables = self.trainable_weights
        grads = tape.gradient(loss_value, variables)

        # Hand (gradient, variable) pairs to the optimizer.
        grads_and_vars = zip(grads, variables)
        self.optimizer.apply_gradients(grads_and_vars)
        return loss_value
|