keras/keras_core/backend/tensorflow/trainer.py
Chen Qian eabdb87f9f Add some numpy ops (#1)
* Add numpy ops (initial batch) and some config

* Add unit test

* fix call

* Revert "fix call"

This reverts commit 6748ad183029ff4b97317b77ceed8661916bb9a0.

* full unit test coverage

* fix setup.py
2023-04-12 11:31:58 -07:00

24 lines
756 B
Python

import tensorflow as tf
from keras_core.trainers import trainer
class Trainer(trainer.Trainer):
    """Trainer whose `train_step` applies a `tf.GradientTape` gradient update."""

    def train_step(self, data):
        """Run a forward pass, compute the loss, and apply one optimizer update.

        Args:
            data: A value that unpacks into exactly two items `(x, y)`.
                `x` is passed to the model call and `y` is passed as the
                first argument of `self.loss`. What these contain depends
                on the model and on what the caller feeds in.

        Returns:
            Whatever `self.loss(y, y_pred)` returned for `data`.
        """
        inputs, targets = data

        # The forward pass and the loss are evaluated under the tape so
        # that they can be differentiated afterwards.
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            # `self.loss` is presumably the loss configured at compile
            # time -- confirm against the base trainer.
            loss_value = self.loss(targets, predictions)

        # Read the weight list once so the gradients and the variables
        # handed to the optimizer come from the same list.
        variables = self.trainable_weights
        grads = tape.gradient(loss_value, variables)

        self.optimizer.apply_gradients(zip(grads, variables))
        return loss_value