Add torch and TF custom training loops guides

This commit is contained in:
Francois Chollet 2023-06-26 16:06:00 -07:00
parent 9e9ae0f65e
commit 6ec8a6160c
16 changed files with 950 additions and 63 deletions

@ -1,7 +0,0 @@
try:
# When using torch and tensorflow, torch needs to be imported first,
# otherwise it will segfault upon import. This should force the torch
# import to happen first for all tests.
import torch # noqa: F401
except ImportError:
pass

@ -2,7 +2,7 @@
Title: Making new layers and models via subclassing Title: Making new layers and models via subclassing
Author: [fchollet](https://twitter.com/fchollet) Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/03/01 Date created: 2019/03/01
Last modified: 2023/06/21 Last modified: 2023/06/25
Description: Complete guide to writing `Layer` and `Model` objects from scratch. Description: Complete guide to writing `Layer` and `Model` objects from scratch.
Accelerator: None Accelerator: None
""" """

@ -2,7 +2,7 @@
Title: The Sequential model Title: The Sequential model
Author: [fchollet](https://twitter.com/fchollet) Author: [fchollet](https://twitter.com/fchollet)
Date created: 2020/04/12 Date created: 2020/04/12
Last modified: 2020/04/12 Last modified: 2023/06/25
Description: Complete guide to the Sequential model. Description: Complete guide to the Sequential model.
Accelerator: GPU Accelerator: GPU
""" """

@ -2,7 +2,7 @@
Title: Training & evaluation with the built-in methods Title: Training & evaluation with the built-in methods
Author: [fchollet](https://twitter.com/fchollet) Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/03/01 Date created: 2019/03/01
Last modified: 2023/03/20 Last modified: 2023/03/25
Description: Complete guide to training & evaluation with `fit()` and `evaluate()`. Description: Complete guide to training & evaluation with `fit()` and `evaluate()`.
Accelerator: GPU Accelerator: GPU
""" """

@ -2,7 +2,7 @@
Title: Understanding masking & padding Title: Understanding masking & padding
Authors: Scott Zhu, Francois Chollet Authors: Scott Zhu, Francois Chollet
Date created: 2019/07/16 Date created: 2019/07/16
Last modified: 2020/04/14 Last modified: 2023/06/25
Description: Complete guide to using mask-aware sequence layers in Keras. Description: Complete guide to using mask-aware sequence layers in Keras.
Accelerator: None Accelerator: None
""" """
@ -376,5 +376,4 @@ automatically.
manually. manually.
- You can easily write layers that modify the current mask, that generate a new mask, - You can easily write layers that modify the current mask, that generate a new mask,
or that consume the mask associated with the inputs. or that consume the mask associated with the inputs.
""" """

@ -0,0 +1,531 @@
"""
Title: Writing a training loop from scratch in TensorFlow
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/03/01
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in TensorFlow.
Accelerator: None
"""
"""
## Setup
"""
import time
import os
# This guide can only be run with the TensorFlow backend.
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras_core as keras
import numpy as np
"""
## Introduction
Keras provides default training and evaluation loops, `fit()` and `evaluate()`.
Their usage is covered in the guide
[Training & evaluation with the built-in methods](https://keras.io/guides/training_with_built_in_methods/).
If you want to customize the learning algorithm of your model while still leveraging
the convenience of `fit()`
(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and
implement your own `train_step()` method, which
is called repeatedly during `fit()`.
Now, if you want very low-level control over training & evaluation, you should write
your own training & evaluation loops from scratch. This is what this guide is about.
"""
"""
## A first end-to-end example
Calling a model inside a `GradientTape` scope enables you to retrieve the gradients of
the trainable weights of the layer with respect to a loss value. Using an optimizer
instance, you can use these gradients to update these variables (which you can
retrieve using `model.trainable_weights`).
Let's consider a simple MNIST model:
"""
def get_model():
inputs = keras.Input(shape=(784,), name="digits")
x1 = keras.layers.Dense(64, activation="relu")(inputs)
x2 = keras.layers.Dense(64, activation="relu")(x1)
outputs = keras.layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)
return model
model = get_model()
"""
Let's train it using mini-batch gradient with a custom training loop.
First, we're going to need an optimizer, a loss function, and a dataset:
"""
# Instantiate an optimizer.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))
# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
"""
Here's our training loop, step by step:
- We open a `for` loop that iterates over epochs
- For each epoch, we open a `for` loop that iterates over the dataset, in batches
- For each batch, we open a `GradientTape()` scope
- Inside this scope, we call the model (forward pass) and compute the loss
- Outside the scope, we retrieve the gradients of the weights
of the model with regard to the loss
- Finally, we use the optimizer to update the weights of the model based on the
gradients
"""
epochs = 3
for epoch in range(epochs):
print(f"\nStart of epoch {epoch}")
# Iterate over the batches of the dataset.
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
# Open a GradientTape to record the operations run
# during the forward pass, which enables auto-differentiation.
with tf.GradientTape() as tape:
# Run the forward pass of the layer.
# The operations that the layer applies
# to its inputs are going to be recorded
# on the GradientTape.
logits = model(
x_batch_train, training=True
) # Logits for this minibatch
# Compute the loss value for this minibatch.
loss_value = loss_fn(y_batch_train, logits)
# Use the gradient tape to automatically retrieve
# the gradients of the trainable variables with respect to the loss.
grads = tape.gradient(loss_value, model.trainable_weights)
# Run one step of gradient descent by updating
# the value of the variables to minimize the loss.
optimizer.apply_gradients(zip(grads, model.trainable_weights))
# Log every 200 batches.
if step % 200 == 0:
print(
f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
)
print(f"Seen so far: {(step + 1) * batch_size} samples")
"""
## Low-level handling of metrics
Let's add metrics monitoring to this basic loop.
You can readily reuse the built-in metrics (or custom ones you wrote) in such training
loops written from scratch. Here's the flow:
- Instantiate the metric at the start of the loop
- Call `metric.update_state()` after each batch
- Call `metric.result()` when you need to display the current value of the metric
- Call `metric.reset_state()` when you need to clear the state of the metric
(typically at the end of an epoch)
Let's use this knowledge to compute `SparseCategoricalAccuracy` on validation data at
the end of each epoch:
"""
# Get a fresh model
model = get_model()
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()
"""
Here's our training & evaluation loop:
"""
epochs = 2
for epoch in range(epochs):
print(f"\nStart of epoch {epoch}")
start_time = time.time()
# Iterate over the batches of the dataset.
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
with tf.GradientTape() as tape:
logits = model(x_batch_train, training=True)
loss_value = loss_fn(y_batch_train, logits)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
# Update training metric.
train_acc_metric.update_state(y_batch_train, logits)
# Log every 200 batches.
if step % 200 == 0:
print(
f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
)
print(f"Seen so far: {(step + 1) * batch_size} samples")
# Display metrics at the end of each epoch.
train_acc = train_acc_metric.result()
print(f"Training acc over epoch: {float(train_acc):.4f}")
# Reset training metrics at the end of each epoch
train_acc_metric.reset_state()
# Run a validation loop at the end of each epoch.
for x_batch_val, y_batch_val in val_dataset:
val_logits = model(x_batch_val, training=False)
# Update val metrics
val_acc_metric.update_state(y_batch_val, val_logits)
val_acc = val_acc_metric.result()
val_acc_metric.reset_state()
print(f"Validation acc: {float(val_acc):.4f}")
print(f"Time taken: {time.time() - start_time:.2f}s")
"""
## Speeding-up your training step with `tf.function`
The default runtime in TensorFlow is eager execution.
As such, our training loop above executes eagerly.
This is great for debugging, but graph compilation has a definite performance
advantage. Describing your computation as a static graph enables the framework
to apply global performance optimizations. This is impossible when
the framework is constrained to greedily execute one operation after another,
with no knowledge of what comes next.
You can compile into a static graph any function that takes tensors as input.
Just add a `@tf.function` decorator on it, like this:
"""
@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
logits = model(x, training=True)
loss_value = loss_fn(y, logits)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
train_acc_metric.update_state(y, logits)
return loss_value
"""
Let's do the same with the evaluation step:
"""
@tf.function
def test_step(x, y):
val_logits = model(x, training=False)
val_acc_metric.update_state(y, val_logits)
"""
Now, let's re-run our training loop with this compiled training step:
"""
epochs = 2
for epoch in range(epochs):
print(f"\nStart of epoch {epoch}")
start_time = time.time()
# Iterate over the batches of the dataset.
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
loss_value = train_step(x_batch_train, y_batch_train)
# Log every 200 batches.
if step % 200 == 0:
print(
f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
)
print(f"Seen so far: {(step + 1) * batch_size} samples")
# Display metrics at the end of each epoch.
train_acc = train_acc_metric.result()
print(f"Training acc over epoch: {float(train_acc):.4f}")
# Reset training metrics at the end of each epoch
train_acc_metric.reset_state()
# Run a validation loop at the end of each epoch.
for x_batch_val, y_batch_val in val_dataset:
test_step(x_batch_val, y_batch_val)
val_acc = val_acc_metric.result()
val_acc_metric.reset_state()
print(f"Validation acc: {float(val_acc):.4f}")
print(f"Time taken: {time.time() - start_time:.2f}s")
"""
Much faster, isn't it?
"""
"""
## Low-level handling of losses tracked by the model
Layers & models recursively track any losses created during the forward pass
by layers that call `self.add_loss(value)`. The resulting list of scalar loss
values are available via the property `model.losses`
at the end of the forward pass.
If you want to be using these loss components, you should sum them
and add them to the main loss in your training step.
Consider this layer, that creates an activity regularization loss:
"""
class ActivityRegularizationLayer(keras.layers.Layer):
def call(self, inputs):
self.add_loss(1e-2 * tf.reduce_sum(inputs))
return inputs
"""
Let's build a really simple model that uses it:
"""
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
"""
Here's what our training step should look like now:
"""
@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
logits = model(x, training=True)
loss_value = loss_fn(y, logits)
# Add any extra losses created during the forward pass.
loss_value += sum(model.losses)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
train_acc_metric.update_state(y, logits)
return loss_value
"""
## Summary
Now you know everything there is to know about using built-in training loops and
writing your own from scratch.
To conclude, here's a simple end-to-end example that ties together everything
you've learned in this guide: a DCGAN trained on MNIST digits.
"""
"""
## End-to-end example: a GAN training loop from scratch
You may be familiar with Generative Adversarial Networks (GANs). GANs can generate new
images that look almost real, by learning the latent distribution of a training
dataset of images (the "latent space" of the images).
A GAN is made of two parts: a "generator" model that maps points in the latent
space to points in image space, a "discriminator" model, a classifier
that can tell the difference between real images (from the training dataset)
and fake images (the output of the generator network).
A GAN training loop looks like this:
1) Train the discriminator.
- Sample a batch of random points in the latent space.
- Turn the points into fake images via the "generator" model.
- Get a batch of real images and combine them with the generated images.
- Train the "discriminator" model to classify generated vs. real images.
2) Train the generator.
- Sample random points in the latent space.
- Turn the points into fake images via the "generator" network.
- Get a batch of real images and combine them with the generated images.
- Train the "generator" model to "fool" the discriminator and classify the fake images
as real.
For a much more detailed overview of how GANs works, see
[Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python).
Let's implement this training loop. First, create the discriminator meant to classify
fake vs real digits:
"""
discriminator = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
keras.layers.LeakyReLU(negative_slope=0.2),
keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
keras.layers.LeakyReLU(negative_slope=0.2),
keras.layers.GlobalMaxPooling2D(),
keras.layers.Dense(1),
],
name="discriminator",
)
discriminator.summary()
"""
Then let's create a generator network,
that turns latent vectors into outputs of shape `(28, 28, 1)` (representing
MNIST digits):
"""
latent_dim = 128
generator = keras.Sequential(
[
keras.Input(shape=(latent_dim,)),
# We want to generate 128 coefficients to reshape into a 7x7x128 map
keras.layers.Dense(7 * 7 * 128),
keras.layers.LeakyReLU(negative_slope=0.2),
keras.layers.Reshape((7, 7, 128)),
keras.layers.Conv2DTranspose(
128, (4, 4), strides=(2, 2), padding="same"
),
keras.layers.LeakyReLU(negative_slope=0.2),
keras.layers.Conv2DTranspose(
128, (4, 4), strides=(2, 2), padding="same"
),
keras.layers.LeakyReLU(negative_slope=0.2),
keras.layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
],
name="generator",
)
"""
Here's the key bit: the training loop. As you can see it is quite straightforward. The
training step function only takes 17 lines.
"""
# Instantiate one optimizer for the discriminator and another for the generator.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)
# Instantiate a loss function.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
@tf.function
def train_step(real_images):
# Sample random points in the latent space
random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
# Decode them to fake images
generated_images = generator(random_latent_vectors)
# Combine them with real images
combined_images = tf.concat([generated_images, real_images], axis=0)
# Assemble labels discriminating real from fake images
labels = tf.concat(
[tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0
)
# Add random noise to the labels - important trick!
labels += 0.05 * tf.random.uniform(labels.shape)
# Train the discriminator
with tf.GradientTape() as tape:
predictions = discriminator(combined_images)
d_loss = loss_fn(labels, predictions)
grads = tape.gradient(d_loss, discriminator.trainable_weights)
d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))
# Sample random points in the latent space
random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
# Assemble labels that say "all real images"
misleading_labels = tf.zeros((batch_size, 1))
# Train the generator (note that we should *not* update the weights
# of the discriminator)!
with tf.GradientTape() as tape:
predictions = discriminator(generator(random_latent_vectors))
g_loss = loss_fn(misleading_labels, predictions)
grads = tape.gradient(g_loss, generator.trainable_weights)
g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))
return d_loss, g_loss, generated_images
"""
Let's train our GAN, by repeatedly calling `train_step` on batches of images.
Since our discriminator and generator are convnets, you're going to want to
run this code on a GPU.
"""
# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
epochs = 1 # In practice you need at least 20 epochs to generate nice digits.
save_dir = "./"
for epoch in range(epochs):
print(f"\nStart epoch {epoch}")
for step, real_images in enumerate(dataset):
# Train the discriminator & generator on one batch of real images.
d_loss, g_loss, generated_images = train_step(real_images)
# Logging.
if step % 200 == 0:
# Print metrics
print(f"discriminator loss at step {step}: {d_loss:.2f}")
print(f"adversarial loss at step {step}: {g_loss:.2f}")
# Save one generated image
img = keras.utils.array_to_img(
generated_images[0] * 255.0, scale=False
)
img.save(os.path.join(save_dir, f"generated_img_{step}.png"))
# To limit execution time we stop after 10 steps.
# Remove the lines below to actually train the model!
if step > 10:
break
"""
That's it! You'll get nice-looking fake MNIST digits after just ~30s of training on the
Colab GPU.
"""

@ -0,0 +1,381 @@
"""
Title: Writing a training loop from scratch in PyTorch
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2023/06/25
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in PyTorch.
Accelerator: None
"""
"""
## Setup
"""
import os
# This guide can only be run with the torch backend.
os.environ["KERAS_BACKEND"] = "torch"
import torch
import keras_core as keras
import numpy as np
"""
## Introduction
Keras provides default training and evaluation loops, `fit()` and `evaluate()`.
Their usage is covered in the guide
[Training & evaluation with the built-in methods](https://keras.io/guides/training_with_built_in_methods/).
If you want to customize the learning algorithm of your model while still leveraging
the convenience of `fit()`
(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and
implement your own `train_step()` method, which
is called repeatedly during `fit()`.
Now, if you want very low-level control over training & evaluation, you should write
your own training & evaluation loops from scratch. This is what this guide is about.
"""
"""
## A first end-to-end example
To write a custom training loop, you need the following ingredients:
- A model to train, of course.
- An optimizer. You could either use a `keras_core.optimizers` optimizer,
or a native PyTorch optimizer from `torch.optim`.
- A loss function. You could either use a `keras_core.losses` loss,
or a native PyTorch loss from `torch.nn`.
- A dataset. You could use any format: a `tf.data.Dataset`,
a PyTorch `DataLoader`, a Python generator, etc.
Let's line them up. We'll use torch-native objects in each case --
except, of course, for the Keras model.
First, let's get the model and the MNIST dataset:
"""
# Let's consider a simple MNIST model
def get_model():
inputs = keras.Input(shape=(784,), name="digits")
x1 = keras.layers.Dense(64, activation="relu")(inputs)
x2 = keras.layers.Dense(64, activation="relu")(x1)
outputs = keras.layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)
return model
# Create load up the MNIST dataset and put it in a torch DataLoader
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
# Create torch Datasets
train_dataset = torch.utils.data.TensorDataset(
torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_dataset = torch.utils.data.TensorDataset(
torch.from_numpy(x_val), torch.from_numpy(y_val)
)
# Create DataLoaders for the Datasets
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset, batch_size=batch_size, shuffle=False
)
"""
Next, here's our PyTorch optimizer and our PyTorch loss function:
"""
# Instantiate a torch optimizer
model = get_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Instantiate a torch loss function
loss_fn = torch.nn.CrossEntropyLoss()
"""
Let's train our model using mini-batch gradient with a custom training loop.
Calling `loss.backward()` on a loss tensor triggers backpropagation.
Once that's done, your optimizer is magically aware of the gradients for each variable
and can update its variables, which is done via `optimizer.step()`.
Tensors, variables, optimizers are all interconnected to one another via hidden global state.
Also, don't forget to call `model.zero_grad()` before `loss.backward()`, or you won't
get the right gradients for your variables.
Here's our training loop, step by step:
- We open a `for` loop that iterates over epochs
- For each epoch, we open a `for` loop that iterates over the dataset, in batches
- For each batch, we call the model on the input data to retrive the predictions,
then we use them to compute a loss value
- We call `loss.backward()` to
- Outside the scope, we retrieve the gradients of the weights
of the model with regard to the loss
- Finally, we use the optimizer to update the weights of the model based on the
gradients
"""
epochs = 3
for epoch in range(epochs):
for step, (inputs, targets) in enumerate(train_dataloader):
# Forward pass
logits = model(inputs)
loss = loss_fn(logits, targets)
# Backward pass
model.zero_grad()
loss.backward()
# Optimizer variable updates
optimizer.step()
# Log every 200 batches.
if step % 200 == 0:
print(
f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
)
print(f"Seen so far: {(step + 1) * batch_size} samples")
"""
As an alternative, let's look at what the loop looks like when using a Keras optimizer
and a Keras loss function.
Important differences:
- You retrieve the gradients for the variables via `v.value.grad`,
called on each trainable variable.
- You update your variables via `optimizer.apply_gradients()`, which must be
called in a `torch.no_grad()` scope.
**Also, a big gotcha:** while all NumPy/TensorFlow/JAX/Keras APIs
as well as Python `unittest` APIs use the argument order convention
`fn(y_true, y_pred)` (reference values first, predicted values second),
PyTorch actually uses `fn(y_pred, y_true)` for its losses.
So make sure to invert the order of `logits` and `targets`.
"""
model = get_model()
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
for epoch in range(epochs):
print(f"\nStart of epoch {epoch}")
for step, (inputs, targets) in enumerate(train_dataloader):
# Forward pass
logits = model(inputs)
loss = loss_fn(targets, logits)
# Backward pass
model.zero_grad()
trainable_weights = [v for v in model.trainable_weights]
# Call torch.Tensor.backward() on the loss to compute gradients
# for the weights.
loss.backward()
gradients = [v.value.grad for v in trainable_weights]
# Update weights
with torch.no_grad():
optimizer.apply_gradients(zip(gradients, trainable_weights))
# Log every 200 batches.
if step % 200 == 0:
print(
f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
)
print(f"Seen so far: {(step + 1) * batch_size} samples")
"""
## Low-level handling of metrics
Let's add metrics monitoring to this basic training loop.
You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training
loops written from scratch. Here's the flow:
- Instantiate the metric at the start of the loop
- Call `metric.update_state()` after each batch
- Call `metric.result()` when you need to display the current value of the metric
- Call `metric.reset_state()` when you need to clear the state of the metric
(typically at the end of an epoch)
Let's use this knowledge to compute `CategoricalAccuracy` on validation data at
the end of each epoch:
"""
# Get a fresh model
model = get_model()
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()
"""
Here's our training & evaluation loop:
"""
for epoch in range(epochs):
print(f"\nStart of epoch {epoch}")
for step, (inputs, targets) in enumerate(train_dataloader):
# Forward pass
logits = model(inputs)
loss = loss_fn(targets, logits)
# Backward pass
model.zero_grad()
trainable_weights = [v for v in model.trainable_weights]
# Call torch.Tensor.backward() on the loss to compute gradients
# for the weights.
loss.backward()
gradients = [v.value.grad for v in trainable_weights]
# Update weights
with torch.no_grad():
optimizer.apply_gradients(zip(gradients, trainable_weights))
# Update training metric.
train_acc_metric.update_state(targets, logits)
# Log every 200 batches.
if step % 200 == 0:
print(
f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
)
print(f"Seen so far: {(step + 1) * batch_size} samples")
# Display metrics at the end of each epoch.
train_acc = train_acc_metric.result()
print(f"Training acc over epoch: {float(train_acc):.4f}")
# Reset training metrics at the end of each epoch
train_acc_metric.reset_state()
# Run a validation loop at the end of each epoch.
for x_batch_val, y_batch_val in val_dataloader:
val_logits = model(x_batch_val, training=False)
# Update val metrics
val_acc_metric.update_state(y_batch_val, val_logits)
val_acc = val_acc_metric.result()
val_acc_metric.reset_state()
print(f"Validation acc: {float(val_acc):.4f}")
"""
## Low-level handling of losses tracked by the model
Layers & models recursively track any losses created during the forward pass
by layers that call `self.add_loss(value)`. The resulting list of scalar loss
values are available via the property `model.losses`
at the end of the forward pass.
If you want to be using these loss components, you should sum them
and add them to the main loss in your training step.
Consider this layer, that creates an activity regularization loss:
"""
class ActivityRegularizationLayer(keras.layers.Layer):
def call(self, inputs):
self.add_loss(1e-2 * torch.sum(inputs))
return inputs
"""
Let's build a really simple model that uses it:
"""
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
"""
Here's what our training loop should look like now:
"""
# Get a fresh model
model = get_model()
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()
for epoch in range(epochs):
print(f"\nStart of epoch {epoch}")
for step, (inputs, targets) in enumerate(train_dataloader):
# Forward pass
logits = model(inputs)
loss = loss_fn(targets, logits)
if model.losses:
loss = loss + torch.sum(*model.losses)
# Backward pass
model.zero_grad()
trainable_weights = [v for v in model.trainable_weights]
# Call torch.Tensor.backward() on the loss to compute gradients
# for the weights.
loss.backward()
gradients = [v.value.grad for v in trainable_weights]
# Update weights
with torch.no_grad():
optimizer.apply_gradients(zip(gradients, trainable_weights))
# Update training metric.
train_acc_metric.update_state(targets, logits)
# Log every 200 batches.
if step % 200 == 0:
print(
f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
)
print(f"Seen so far: {(step + 1) * batch_size} samples")
# Display metrics at the end of each epoch.
train_acc = train_acc_metric.result()
print(f"Training acc over epoch: {float(train_acc):.4f}")
# Reset training metrics at the end of each epoch
train_acc_metric.reset_state()
# Run a validation loop at the end of each epoch.
for x_batch_val, y_batch_val in val_dataloader:
val_logits = model(x_batch_val, training=False)
# Update val metrics
val_acc_metric.update_state(y_batch_val, val_logits)
val_acc = val_acc_metric.result()
val_acc_metric.reset_state()
print(f"Validation acc: {float(val_acc):.4f}")

@ -2,7 +2,7 @@
Title: Writing your own callbacks Title: Writing your own callbacks
Authors: Rick Chao, Francois Chollet Authors: Rick Chao, Francois Chollet
Date created: 2019/03/20 Date created: 2019/03/20
Last modified: 2020/07/12 Last modified: 2023/06/25
Description: Complete guide to writing new Keras callbacks. Description: Complete guide to writing new Keras callbacks.
Accelerator: GPU Accelerator: GPU
""" """

@ -28,7 +28,7 @@ class Variable(
) )
def _direct_assign(self, value): def _direct_assign(self, value):
self._value.assign(tf.cast(value, self._value.dtype)) self.value.assign(value)
def _convert_to_tensor(self, value, dtype=None): def _convert_to_tensor(self, value, dtype=None):
return convert_to_tensor(value, dtype=dtype) return convert_to_tensor(value, dtype=dtype)

@ -129,9 +129,9 @@ def convert_to_tensor(x, dtype=None):
# Convert to np in case of any array-like that is not list or tuple. # Convert to np in case of any array-like that is not list or tuple.
if not isinstance(x, (list, tuple)): if not isinstance(x, (list, tuple)):
x = np.array(x) x = np.array(x)
elif len(x) > 0 and any(isinstance(x1, torch.Tensor) for x1 in x): elif len(x) > 0 and isinstance(x[0], torch.Tensor):
# Handle list or tuple of torch tensors # Handle list or tuple of torch tensors
return torch.stack([convert_to_tensor(x1) for x1 in x]) return torch.stack(x)
if isinstance(x, np.ndarray) and x.dtype == np.uint32: if isinstance(x, np.ndarray) and x.dtype == np.uint32:
# Torch backend does not support uint32. # Torch backend does not support uint32.
x = x.astype(np.int64) x = x.astype(np.int64)

@ -239,9 +239,5 @@ class CoreOpsCorrectnessTest(testing.TestCase):
self.assertAllEqual(x, (1, 1)) self.assertAllEqual(x, (1, 1))
self.assertIsInstance(x, np.ndarray) self.assertIsInstance(x, np.ndarray)
# Partially converted.
x = ops.convert_to_tensor((1, ops.array(2), 3))
self.assertAllEqual(x, (1, 2, 3))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
ops.convert_to_numpy(KerasTensor((2,))) ops.convert_to_numpy(KerasTensor((2,)))

@ -5,6 +5,7 @@ import inspect
import types import types
import warnings import warnings
import jax
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -161,7 +162,11 @@ def serialize_keras_object(obj):
} }
if isinstance(obj, tf.TensorShape): if isinstance(obj, tf.TensorShape):
return obj.as_list() if obj._dims is not None else None return obj.as_list() if obj._dims is not None else None
if backend.is_tensor(obj): if isinstance(obj, (tf.Tensor, jax.numpy.ndarray)) or hasattr(
obj, "device"
):
# Import torch creates circular dependency, so we use
# `hasattr(obj, "device")` to check if obj is a torch tensor.
return { return {
"class_name": "__tensor__", "class_name": "__tensor__",
"config": { "config": {

@ -27,7 +27,11 @@ class ArrayDataAdapter(DataAdapter):
shuffle=False, shuffle=False,
class_weight=None, class_weight=None,
): ):
if not can_convert_arrays((x, y, sample_weight)): types_struct = nest.map_structure(lambda x: type(x), x)
flat_types = nest.flatten(types_struct)
if not all(
issubclass(c, data_adapter_utils.ARRAY_TYPES) for c in flat_types
):
raise ValueError( raise ValueError(
"Expected all elements of `x` to be array-like. " "Expected all elements of `x` to be array-like. "
f"Received invalid types: x={x}" f"Received invalid types: x={x}"
@ -248,28 +252,6 @@ class ArrayDataAdapter(DataAdapter):
return self._partial_batch_size or None return self._partial_batch_size or None
def can_convert_arrays(arrays):
"""Check if array like-inputs can be handled by `ArrayDataAdapter`
Args:
inputs: Structure of `Tensor`s, NumPy arrays, or tensor-like.
Returns:
`True` if `arrays` can be handled by `ArrayDataAdapter`, `False`
otherwise.
"""
def can_convert_single_array(x):
is_none = x is None
known_type = isinstance(x, data_adapter_utils.ARRAY_TYPES)
convertable_type = hasattr(x, "__array__")
return is_none or known_type or convertable_type
return all(
tf.nest.flatten(tf.nest.map_structure(can_convert_single_array, arrays))
)
def convert_to_arrays(arrays, dtype=None): def convert_to_arrays(arrays, dtype=None):
"""Process array-like inputs. """Process array-like inputs.
@ -280,7 +262,7 @@ def convert_to_arrays(arrays, dtype=None):
- Converts `list`s to `tuple`s (for `tf.data` support). - Converts `list`s to `tuple`s (for `tf.data` support).
Args: Args:
inputs: Structure of `Tensor`s, NumPy arrays, or tensor-like. inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like.
Returns: Returns:
Structure of NumPy `ndarray`s. Structure of NumPy `ndarray`s.
@ -295,16 +277,15 @@ def convert_to_arrays(arrays, dtype=None):
x = np.expand_dims(x.to_numpy(dtype=dtype), axis=-1) x = np.expand_dims(x.to_numpy(dtype=dtype), axis=-1)
elif isinstance(x, pandas.DataFrame): elif isinstance(x, pandas.DataFrame):
x = x.to_numpy(dtype=dtype) x = x.to_numpy(dtype=dtype)
if isinstance(x, (tf.Tensor, tf.Variable)):
x = x.numpy()
if not isinstance(x, np.ndarray): if not isinstance(x, np.ndarray):
# Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`,
# `torch.Tensor`, as well as any other tensor-like object that has
# added numpy support.
if hasattr(x, "__array__"): if hasattr(x, "__array__"):
x = np.array(x, dtype=dtype) x = np.array(x, dtype=dtype)
else: else:
raise ValueError( raise ValueError(
"Expected a NumPy array, tf.Tensor, jax.np.ndarray, " "Expected a NumPy array, tf.Tensor, "
"torch.Tensor, Pandas Dataframe, or Pandas Series. " "Pandas Dataframe, or Pandas Series. "
f"Received invalid input: {x} (of type {type(x)})" f"Received invalid input: {x} (of type {type(x)})"
) )
if x.dtype == object: if x.dtype == object:

@ -2,7 +2,6 @@ import jax
import numpy as np import numpy as np
import pandas import pandas
import tensorflow as tf import tensorflow as tf
import torch
from absl.testing import parameterized from absl.testing import parameterized
from keras_core import backend from keras_core import backend
@ -11,9 +10,7 @@ from keras_core.trainers.data_adapters import array_data_adapter
class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase): class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase):
@parameterized.parameters( @parameterized.parameters([("np",), ("tf",), ("pandas")])
[("np",), ("tf",), ("jax",), ("torch",), ("pandas")]
)
def test_basic_flow(self, array_type): def test_basic_flow(self, array_type):
if array_type == "np": if array_type == "np":
x = np.random.random((34, 4)) x = np.random.random((34, 4))
@ -24,9 +21,6 @@ class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase):
elif array_type == "jax": elif array_type == "jax":
x = jax.numpy.ones((34, 4)) x = jax.numpy.ones((34, 4))
y = jax.numpy.ones((34, 2)) y = jax.numpy.ones((34, 2))
elif array_type == "torch":
x = torch.ones((34, 4))
y = torch.ones((34, 2))
elif array_type == "pandas": elif array_type == "pandas":
x = pandas.DataFrame(np.random.random((34, 4))) x = pandas.DataFrame(np.random.random((34, 4)))
y = pandas.DataFrame(np.random.random((34, 2))) y = pandas.DataFrame(np.random.random((34, 2)))

@ -1,5 +1,6 @@
import math import math
import jax
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@ -11,12 +12,15 @@ except ImportError:
pandas = None pandas = None
# Leave jax, tf, and torch arrays off this list. Instead we will use ARRAY_TYPES = (tf.Tensor, np.ndarray, jax.numpy.ndarray)
# `__array__` to detect these types. Doing so allows us to avoid importing a
# backend framework we are not currently using just to do type-checking.
ARRAY_TYPES = (np.ndarray,)
if pandas: if pandas:
ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame) ARRAY_TYPES = ARRAY_TYPES + (
tf.Tensor,
np.ndarray,
pandas.Series,
pandas.DataFrame,
)
# TODO: support torch tensors?
@keras_core_export("keras_core.utils.unpack_x_y_sample_weight") @keras_core_export("keras_core.utils.unpack_x_y_sample_weight")

@ -42,8 +42,10 @@ import types
import warnings import warnings
import tensorflow as tf import tensorflow as tf
from tensorflow import nest
from keras_core.trainers.data_adapters import array_data_adapter from keras_core.trainers.data_adapters import array_data_adapter
from keras_core.trainers.data_adapters import data_adapter_utils
from keras_core.trainers.data_adapters import generator_data_adapter from keras_core.trainers.data_adapters import generator_data_adapter
from keras_core.trainers.data_adapters import py_dataset_adapter from keras_core.trainers.data_adapters import py_dataset_adapter
from keras_core.trainers.data_adapters import tf_dataset_adapter from keras_core.trainers.data_adapters import tf_dataset_adapter
@ -67,7 +69,8 @@ class EpochIterator:
if steps_per_epoch: if steps_per_epoch:
self._current_iterator = None self._current_iterator = None
self._insufficient_data = False self._insufficient_data = False
if array_data_adapter.can_convert_arrays((x, y, sample_weight)): first_element = next(iter(nest.flatten(x)), None)
if isinstance(first_element, data_adapter_utils.ARRAY_TYPES):
self.data_adapter = array_data_adapter.ArrayDataAdapter( self.data_adapter = array_data_adapter.ArrayDataAdapter(
x, x,
y, y,