Add torch and TF custom training loops guides

parent 9e9ae0f65e
commit 6ec8a6160c
@@ -1,7 +0,0 @@
-try:
-    # When using torch and tensorflow, torch needs to be imported first,
-    # otherwise it will segfault upon import. This should force the torch
-    # import to happen first for all tests.
-    import torch  # noqa: F401
-except ImportError:
-    pass
@@ -2,7 +2,7 @@
 Title: Making new layers and models via subclassing
 Author: [fchollet](https://twitter.com/fchollet)
 Date created: 2019/03/01
-Last modified: 2023/06/21
+Last modified: 2023/06/25
 Description: Complete guide to writing `Layer` and `Model` objects from scratch.
 Accelerator: None
 """
@@ -2,7 +2,7 @@
 Title: The Sequential model
 Author: [fchollet](https://twitter.com/fchollet)
 Date created: 2020/04/12
-Last modified: 2020/04/12
+Last modified: 2023/06/25
 Description: Complete guide to the Sequential model.
 Accelerator: GPU
 """
@@ -2,7 +2,7 @@
 Title: Training & evaluation with the built-in methods
 Author: [fchollet](https://twitter.com/fchollet)
 Date created: 2019/03/01
-Last modified: 2023/03/20
+Last modified: 2023/03/25
 Description: Complete guide to training & evaluation with `fit()` and `evaluate()`.
 Accelerator: GPU
 """
@@ -2,7 +2,7 @@
 Title: Understanding masking & padding
 Authors: Scott Zhu, Francois Chollet
 Date created: 2019/07/16
-Last modified: 2020/04/14
+Last modified: 2023/06/25
 Description: Complete guide to using mask-aware sequence layers in Keras.
 Accelerator: None
 """
@@ -376,5 +376,4 @@ automatically.
 manually.
 - You can easily write layers that modify the current mask, that generate a new mask,
   or that consume the mask associated with the inputs.
-
 """
531 guides/writing_a_custom_training_loop_in_tensorflow.py (new file)
@@ -0,0 +1,531 @@
"""
Title: Writing a training loop from scratch in TensorFlow
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/03/01
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in TensorFlow.
Accelerator: None
"""

"""
## Setup
"""

import time
import os

# This guide can only be run with the TensorFlow backend.
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras_core as keras
import numpy as np
"""
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
Keras provides default training and evaluation loops, `fit()` and `evaluate()`.
|
||||||
|
Their usage is covered in the guide
|
||||||
|
[Training & evaluation with the built-in methods](https://keras.io/guides/training_with_built_in_methods/).
|
||||||
|
|
||||||
|
If you want to customize the learning algorithm of your model while still leveraging
|
||||||
|
the convenience of `fit()`
|
||||||
|
(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and
|
||||||
|
implement your own `train_step()` method, which
|
||||||
|
is called repeatedly during `fit()`.
|
||||||
|
|
||||||
|
Now, if you want very low-level control over training & evaluation, you should write
|
||||||
|
your own training & evaluation loops from scratch. This is what this guide is about.
|
||||||
|
"""

"""
## A first end-to-end example

Calling a model inside a `GradientTape` scope enables you to retrieve the gradients of
the trainable weights of the layer with respect to a loss value. Using an optimizer
instance, you can use these gradients to update these variables (which you can
retrieve using `model.trainable_weights`).

Let's consider a simple MNIST model:
"""


def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = get_model()
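
"""
Before building the full loop, here is a quick sketch of that `GradientTape` mechanic
on a single dummy batch (illustrative only; the `demo_*` and `*_dummy` names are made
up for this demo):
"""

x_dummy = tf.random.normal((2, 784))
y_dummy = tf.constant([3, 7])
with tf.GradientTape() as demo_tape:
    demo_logits = model(x_dummy, training=True)
    demo_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)(
        y_dummy, demo_logits
    )
# One gradient tensor per trainable weight, ready for an optimizer update.
demo_grads = demo_tape.gradient(demo_loss, model.trainable_weights)
print(f"Number of gradient tensors: {len(demo_grads)}")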

"""
Let's train it using mini-batch gradient descent with a custom training loop.

First, we're going to need an optimizer, a loss function, and a dataset:
"""

# Instantiate an optimizer.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
"""
|
||||||
|
Here's our training loop, step by step:
|
||||||
|
|
||||||
|
- We open a `for` loop that iterates over epochs
|
||||||
|
- For each epoch, we open a `for` loop that iterates over the dataset, in batches
|
||||||
|
- For each batch, we open a `GradientTape()` scope
|
||||||
|
- Inside this scope, we call the model (forward pass) and compute the loss
|
||||||
|
- Outside the scope, we retrieve the gradients of the weights
|
||||||
|
of the model with regard to the loss
|
||||||
|
- Finally, we use the optimizer to update the weights of the model based on the
|
||||||
|
gradients
|
||||||
|
"""
|
||||||
|
|
||||||
|
epochs = 3
|
||||||
|
for epoch in range(epochs):
|
||||||
|
print(f"\nStart of epoch {epoch}")
|
||||||
|
|
||||||
|
# Iterate over the batches of the dataset.
|
||||||
|
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
|
||||||
|
# Open a GradientTape to record the operations run
|
||||||
|
# during the forward pass, which enables auto-differentiation.
|
||||||
|
with tf.GradientTape() as tape:
|
||||||
|
# Run the forward pass of the layer.
|
||||||
|
# The operations that the layer applies
|
||||||
|
# to its inputs are going to be recorded
|
||||||
|
# on the GradientTape.
|
||||||
|
logits = model(
|
||||||
|
x_batch_train, training=True
|
||||||
|
) # Logits for this minibatch
|
||||||
|
|
||||||
|
# Compute the loss value for this minibatch.
|
||||||
|
loss_value = loss_fn(y_batch_train, logits)
|
||||||
|
|
||||||
|
# Use the gradient tape to automatically retrieve
|
||||||
|
# the gradients of the trainable variables with respect to the loss.
|
||||||
|
grads = tape.gradient(loss_value, model.trainable_weights)
|
||||||
|
|
||||||
|
# Run one step of gradient descent by updating
|
||||||
|
# the value of the variables to minimize the loss.
|
||||||
|
optimizer.apply_gradients(zip(grads, model.trainable_weights))
|
||||||
|
|
||||||
|
# Log every 200 batches.
|
||||||
|
if step % 200 == 0:
|
||||||
|
print(
|
||||||
|
f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
|
||||||
|
)
|
||||||
|
print(f"Seen so far: {(step + 1) * batch_size} samples")
|
||||||
|
|
||||||
|
"""
|
||||||
|
## Low-level handling of metrics
|
||||||
|
|
||||||
|
Let's add metrics monitoring to this basic loop.
|
||||||
|
|
||||||
|
You can readily reuse the built-in metrics (or custom ones you wrote) in such training
|
||||||
|
loops written from scratch. Here's the flow:
|
||||||
|
|
||||||
|
- Instantiate the metric at the start of the loop
|
||||||
|
- Call `metric.update_state()` after each batch
|
||||||
|
- Call `metric.result()` when you need to display the current value of the metric
|
||||||
|
- Call `metric.reset_state()` when you need to clear the state of the metric
|
||||||
|
(typically at the end of an epoch)
|
||||||
|
|
||||||
|
Let's use this knowledge to compute `SparseCategoricalAccuracy` on validation data at
|
||||||
|
the end of each epoch:
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Get a fresh model
|
||||||
|
model = get_model()
|
||||||
|
|
||||||
|
# Instantiate an optimizer to train the model.
|
||||||
|
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
|
||||||
|
# Instantiate a loss function.
|
||||||
|
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||||
|
|
||||||
|
# Prepare the metrics.
|
||||||
|
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
|
||||||
|
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()
|
||||||
|
|
||||||
|
"""
|
||||||
|
Here's our training & evaluation loop:
|
||||||
|
"""
|
||||||
|
|
||||||
|
epochs = 2
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
    print(f"Time taken: {time.time() - start_time:.2f}s")
"""
|
||||||
|
## Speeding-up your training step with `tf.function`
|
||||||
|
|
||||||
|
The default runtime in TensorFlow is eager execution.
|
||||||
|
As such, our training loop above executes eagerly.
|
||||||
|
|
||||||
|
This is great for debugging, but graph compilation has a definite performance
|
||||||
|
advantage. Describing your computation as a static graph enables the framework
|
||||||
|
to apply global performance optimizations. This is impossible when
|
||||||
|
the framework is constrained to greedily execute one operation after another,
|
||||||
|
with no knowledge of what comes next.
|
||||||
|
|
||||||
|
You can compile into a static graph any function that takes tensors as input.
|
||||||
|
Just add a `@tf.function` decorator on it, like this:
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def train_step(x, y):
|
||||||
|
with tf.GradientTape() as tape:
|
||||||
|
logits = model(x, training=True)
|
||||||
|
loss_value = loss_fn(y, logits)
|
||||||
|
grads = tape.gradient(loss_value, model.trainable_weights)
|
||||||
|
optimizer.apply_gradients(zip(grads, model.trainable_weights))
|
||||||
|
train_acc_metric.update_state(y, logits)
|
||||||
|
return loss_value
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Let's do the same with the evaluation step:
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def test_step(x, y):
|
||||||
|
val_logits = model(x, training=False)
|
||||||
|
val_acc_metric.update_state(y, val_logits)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Now, let's re-run our training loop with this compiled training step:
|
||||||
|
"""
|
||||||
|
|
||||||
|
epochs = 2
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
    print(f"Time taken: {time.time() - start_time:.2f}s")

"""
Much faster, isn't it?
"""
"""
|
||||||
|
## Low-level handling of losses tracked by the model
|
||||||
|
|
||||||
|
Layers & models recursively track any losses created during the forward pass
|
||||||
|
by layers that call `self.add_loss(value)`. The resulting list of scalar loss
|
||||||
|
values are available via the property `model.losses`
|
||||||
|
at the end of the forward pass.
|
||||||
|
|
||||||
|
If you want to be using these loss components, you should sum them
|
||||||
|
and add them to the main loss in your training step.
|
||||||
|
|
||||||
|
Consider this layer, that creates an activity regularization loss:
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ActivityRegularizationLayer(keras.layers.Layer):
|
||||||
|
def call(self, inputs):
|
||||||
|
self.add_loss(1e-2 * tf.reduce_sum(inputs))
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Let's build a really simple model that uses it:
|
||||||
|
"""
|
||||||
|
|
||||||
|
inputs = keras.Input(shape=(784,), name="digits")
|
||||||
|
x = keras.layers.Dense(64, activation="relu")(inputs)
|
||||||
|
# Insert activity regularization as a layer
|
||||||
|
x = ActivityRegularizationLayer()(x)
|
||||||
|
x = keras.layers.Dense(64, activation="relu")(x)
|
||||||
|
outputs = keras.layers.Dense(10, name="predictions")(x)
|
||||||
|
|
||||||
|
model = keras.Model(inputs=inputs, outputs=outputs)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Here's what our training step should look like now:
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def train_step(x, y):
|
||||||
|
with tf.GradientTape() as tape:
|
||||||
|
logits = model(x, training=True)
|
||||||
|
loss_value = loss_fn(y, logits)
|
||||||
|
# Add any extra losses created during the forward pass.
|
||||||
|
loss_value += sum(model.losses)
|
||||||
|
grads = tape.gradient(loss_value, model.trainable_weights)
|
||||||
|
optimizer.apply_gradients(zip(grads, model.trainable_weights))
|
||||||
|
train_acc_metric.update_state(y, logits)
|
||||||
|
return loss_value
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
## Summary
|
||||||
|
|
||||||
|
Now you know everything there is to know about using built-in training loops and
|
||||||
|
writing your own from scratch.
|
||||||
|
|
||||||
|
To conclude, here's a simple end-to-end example that ties together everything
|
||||||
|
you've learned in this guide: a DCGAN trained on MNIST digits.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
## End-to-end example: a GAN training loop from scratch
|
||||||
|
|
||||||
|
You may be familiar with Generative Adversarial Networks (GANs). GANs can generate new
|
||||||
|
images that look almost real, by learning the latent distribution of a training
|
||||||
|
dataset of images (the "latent space" of the images).
|
||||||
|
|
||||||
|
A GAN is made of two parts: a "generator" model that maps points in the latent
|
||||||
|
space to points in image space, a "discriminator" model, a classifier
|
||||||
|
that can tell the difference between real images (from the training dataset)
|
||||||
|
and fake images (the output of the generator network).
|
||||||
|
|
||||||
|
A GAN training loop looks like this:
|
||||||
|
|
||||||
|
1) Train the discriminator.
|
||||||
|
- Sample a batch of random points in the latent space.
|
||||||
|
- Turn the points into fake images via the "generator" model.
|
||||||
|
- Get a batch of real images and combine them with the generated images.
|
||||||
|
- Train the "discriminator" model to classify generated vs. real images.
|
||||||
|
|
||||||
|
2) Train the generator.
|
||||||
|
- Sample random points in the latent space.
|
||||||
|
- Turn the points into fake images via the "generator" network.
|
||||||
|
- Get a batch of real images and combine them with the generated images.
|
||||||
|
- Train the "generator" model to "fool" the discriminator and classify the fake images
|
||||||
|
as real.
|
||||||
|
|
||||||
|
For a much more detailed overview of how GANs works, see
|
||||||
|
[Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python).
|
||||||
|
|
||||||
|
Let's implement this training loop. First, create the discriminator meant to classify
|
||||||
|
fake vs real digits:
|
||||||
|
"""

discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.GlobalMaxPooling2D(),
        keras.layers.Dense(1),
    ],
    name="discriminator",
)
discriminator.summary()

"""
Then let's create a generator network,
that turns latent vectors into outputs of shape `(28, 28, 1)` (representing
MNIST digits):
"""

latent_dim = 128

generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        keras.layers.Dense(7 * 7 * 128),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Reshape((7, 7, 128)),
        keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

"""
Here's the key bit: the training loop. As you can see it is quite straightforward. The
training step function only takes 17 lines.
"""

# Instantiate one optimizer for the discriminator and another for the generator.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)

# Instantiate a loss function.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)


@tf.function
def train_step(real_images):
    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Decode them to fake images
    generated_images = generator(random_latent_vectors)
    # Combine them with real images
    combined_images = tf.concat([generated_images, real_images], axis=0)

    # Assemble labels discriminating real from fake images
    labels = tf.concat(
        [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0
    )
    # Add random noise to the labels - important trick!
    labels += 0.05 * tf.random.uniform(labels.shape)

    # Train the discriminator
    with tf.GradientTape() as tape:
        predictions = discriminator(combined_images)
        d_loss = loss_fn(labels, predictions)
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Assemble labels that say "all real images"
    misleading_labels = tf.zeros((batch_size, 1))

    # Train the generator (note that we should *not* update the weights
    # of the discriminator)!
    with tf.GradientTape() as tape:
        predictions = discriminator(generator(random_latent_vectors))
        g_loss = loss_fn(misleading_labels, predictions)
    grads = tape.gradient(g_loss, generator.trainable_weights)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))
    return d_loss, g_loss, generated_images
"""
|
||||||
|
Let's train our GAN, by repeatedly calling `train_step` on batches of images.
|
||||||
|
|
||||||
|
Since our discriminator and generator are convnets, you're going to want to
|
||||||
|
run this code on a GPU.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Prepare the dataset. We use both the training & test MNIST digits.
|
||||||
|
batch_size = 64
|
||||||
|
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
|
||||||
|
all_digits = np.concatenate([x_train, x_test])
|
||||||
|
all_digits = all_digits.astype("float32") / 255.0
|
||||||
|
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
|
||||||
|
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
|
||||||
|
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
|
||||||
|
|
||||||
|
epochs = 1 # In practice you need at least 20 epochs to generate nice digits.
|
||||||
|
save_dir = "./"
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
print(f"\nStart epoch {epoch}")
|
||||||
|
|
||||||
|
for step, real_images in enumerate(dataset):
|
||||||
|
# Train the discriminator & generator on one batch of real images.
|
||||||
|
d_loss, g_loss, generated_images = train_step(real_images)
|
||||||
|
|
||||||
|
# Logging.
|
||||||
|
if step % 200 == 0:
|
||||||
|
# Print metrics
|
||||||
|
print(f"discriminator loss at step {step}: {d_loss:.2f}")
|
||||||
|
print(f"adversarial loss at step {step}: {g_loss:.2f}")
|
||||||
|
|
||||||
|
# Save one generated image
|
||||||
|
img = keras.utils.array_to_img(
|
||||||
|
generated_images[0] * 255.0, scale=False
|
||||||
|
)
|
||||||
|
img.save(os.path.join(save_dir, f"generated_img_{step}.png"))
|
||||||
|
|
||||||
|
# To limit execution time we stop after 10 steps.
|
||||||
|
# Remove the lines below to actually train the model!
|
||||||
|
if step > 10:
|
||||||
|
break
|
||||||
|
|
||||||
|
"""
|
||||||
|
That's it! You'll get nice-looking fake MNIST digits after just ~30s of training on the
|
||||||
|
Colab GPU.
|
||||||
|
"""
|
381 guides/writing_a_custom_training_loop_in_torch.py (new file)
@@ -0,0 +1,381 @@
"""
Title: Writing a training loop from scratch in PyTorch
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2023/06/25
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in PyTorch.
Accelerator: None
"""

"""
## Setup
"""

import os

# This guide can only be run with the torch backend.
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras_core as keras
import numpy as np
"""
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
Keras provides default training and evaluation loops, `fit()` and `evaluate()`.
|
||||||
|
Their usage is covered in the guide
|
||||||
|
[Training & evaluation with the built-in methods](https://keras.io/guides/training_with_built_in_methods/).
|
||||||
|
|
||||||
|
If you want to customize the learning algorithm of your model while still leveraging
|
||||||
|
the convenience of `fit()`
|
||||||
|
(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and
|
||||||
|
implement your own `train_step()` method, which
|
||||||
|
is called repeatedly during `fit()`.
|
||||||
|
|
||||||
|
Now, if you want very low-level control over training & evaluation, you should write
|
||||||
|
your own training & evaluation loops from scratch. This is what this guide is about.
|
||||||
|
"""

"""
## A first end-to-end example

To write a custom training loop, you need the following ingredients:

- A model to train, of course.
- An optimizer. You could either use a `keras_core.optimizers` optimizer,
  or a native PyTorch optimizer from `torch.optim`.
- A loss function. You could either use a `keras_core.losses` loss,
  or a native PyTorch loss from `torch.nn`.
- A dataset. You could use any format: a `tf.data.Dataset`,
  a PyTorch `DataLoader`, a Python generator, etc.

Let's line them up. We'll use torch-native objects in each case --
except, of course, for the Keras model.

First, let's get the model and the MNIST dataset:
"""

# Let's consider a simple MNIST model
def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


# Load up the MNIST dataset and put it in a torch DataLoader
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Create torch Datasets
train_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_val), torch.from_numpy(y_val)
)

# Create DataLoaders for the Datasets
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False
)

"""
Next, here's our PyTorch optimizer and our PyTorch loss function:
"""

# Instantiate a torch optimizer
model = get_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Instantiate a torch loss function
loss_fn = torch.nn.CrossEntropyLoss()
"""
|
||||||
|
Let's train our model using mini-batch gradient with a custom training loop.
|
||||||
|
|
||||||
|
Calling `loss.backward()` on a loss tensor triggers backpropagation.
|
||||||
|
Once that's done, your optimizer is magically aware of the gradients for each variable
|
||||||
|
and can update its variables, which is done via `optimizer.step()`.
|
||||||
|
Tensors, variables, optimizers are all interconnected to one another via hidden global state.
|
||||||
|
Also, don't forget to call `model.zero_grad()` before `loss.backward()`, or you won't
|
||||||
|
get the right gradients for your variables.
|
||||||
|
|
||||||
|
Here's our training loop, step by step:
|
||||||
|
|
||||||
|
- We open a `for` loop that iterates over epochs
|
||||||
|
- For each epoch, we open a `for` loop that iterates over the dataset, in batches
|
||||||
|
- For each batch, we call the model on the input data to retrive the predictions,
|
||||||
|
then we use them to compute a loss value
|
||||||
|
- We call `loss.backward()` to
|
||||||
|
- Outside the scope, we retrieve the gradients of the weights
|
||||||
|
of the model with regard to the loss
|
||||||
|
- Finally, we use the optimizer to update the weights of the model based on the
|
||||||
|
gradients
|
||||||
|
"""

epochs = 3
for epoch in range(epochs):
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(logits, targets)

        # Backward pass
        model.zero_grad()
        loss.backward()

        # Optimizer variable updates
        optimizer.step()

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")
"""
|
||||||
|
As an alternative, let's look at what the loop looks like when using a Keras optimizer
|
||||||
|
and a Keras loss function.
|
||||||
|
|
||||||
|
Important differences:
|
||||||
|
|
||||||
|
- You retrieve the gradients for the variables via `v.value.grad`,
|
||||||
|
called on each trainable variable.
|
||||||
|
- You update your variables via `optimizer.apply_gradients()`, which must be
|
||||||
|
called in a `torch.no_grad()` scope.
|
||||||
|
|
||||||
|
**Also, a big gotcha:** while all NumPy/TensorFlow/JAX/Keras APIs
|
||||||
|
as well as Python `unittest` APIs use the argument order convention
|
||||||
|
`fn(y_true, y_pred)` (reference values first, predicted values second),
|
||||||
|
PyTorch actually uses `fn(y_pred, y_true)` for its losses.
|
||||||
|
So make sure to invert the order of `logits` and `targets`.
|
||||||
|
"""

model = get_model()
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply_gradients(zip(gradients, trainable_weights))

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")
"""
|
||||||
|
## Low-level handling of metrics
|
||||||
|
|
||||||
|
Let's add metrics monitoring to this basic training loop.
|
||||||
|
|
||||||
|
You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training
|
||||||
|
loops written from scratch. Here's the flow:
|
||||||
|
|
||||||
|
- Instantiate the metric at the start of the loop
|
||||||
|
- Call `metric.update_state()` after each batch
|
||||||
|
- Call `metric.result()` when you need to display the current value of the metric
|
||||||
|
- Call `metric.reset_state()` when you need to clear the state of the metric
|
||||||
|
(typically at the end of an epoch)
|
||||||
|
|
||||||
|
Let's use this knowledge to compute `CategoricalAccuracy` on validation data at
|
||||||
|
the end of each epoch:
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Get a fresh model
|
||||||
|
model = get_model()
|
||||||
|
|
||||||
|
# Instantiate an optimizer to train the model.
|
||||||
|
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
|
||||||
|
# Instantiate a loss function.
|
||||||
|
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
|
||||||
|
|
||||||
|
# Prepare the metrics.
|
||||||
|
train_acc_metric = keras.metrics.CategoricalAccuracy()
|
||||||
|
val_acc_metric = keras.metrics.CategoricalAccuracy()
|
||||||
|
|
||||||
|
"""
|
||||||
|
Here's our training & evaluation loop:
|
||||||
|
"""

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply_gradients(zip(gradients, trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(targets, logits)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataloader:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
"""
|
||||||
|
## Low-level handling of losses tracked by the model
|
||||||
|
|
||||||
|
Layers & models recursively track any losses created during the forward pass
|
||||||
|
by layers that call `self.add_loss(value)`. The resulting list of scalar loss
|
||||||
|
values are available via the property `model.losses`
|
||||||
|
at the end of the forward pass.
|
||||||
|
|
||||||
|
If you want to be using these loss components, you should sum them
|
||||||
|
and add them to the main loss in your training step.
|
||||||
|
|
||||||
|
Consider this layer, that creates an activity regularization loss:
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ActivityRegularizationLayer(keras.layers.Layer):
|
||||||
|
def call(self, inputs):
|
||||||
|
self.add_loss(1e-2 * torch.sum(inputs))
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Let's build a really simple model that uses it:
|
||||||
|
"""

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

"""
Here's what our training loop should look like now:
"""

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)
        if model.losses:
            # Sum the scalar losses registered via `add_loss()`.
            loss = loss + sum(model.losses)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply_gradients(zip(gradients, trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(targets, logits)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataloader:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
@@ -2,7 +2,7 @@
 Title: Writing your own callbacks
 Authors: Rick Chao, Francois Chollet
 Date created: 2019/03/20
-Last modified: 2020/07/12
+Last modified: 2023/06/25
 Description: Complete guide to writing new Keras callbacks.
 Accelerator: GPU
 """
@@ -28,7 +28,7 @@ class Variable(
     )

     def _direct_assign(self, value):
-        self._value.assign(tf.cast(value, self._value.dtype))
+        self.value.assign(value)

     def _convert_to_tensor(self, value, dtype=None):
         return convert_to_tensor(value, dtype=dtype)
@@ -129,9 +129,9 @@ def convert_to_tensor(x, dtype=None):
     # Convert to np in case of any array-like that is not list or tuple.
     if not isinstance(x, (list, tuple)):
         x = np.array(x)
-    elif len(x) > 0 and any(isinstance(x1, torch.Tensor) for x1 in x):
+    elif len(x) > 0 and isinstance(x[0], torch.Tensor):
         # Handle list or tuple of torch tensors
-        return torch.stack([convert_to_tensor(x1) for x1 in x])
+        return torch.stack(x)
     if isinstance(x, np.ndarray) and x.dtype == np.uint32:
         # Torch backend does not support uint32.
         x = x.astype(np.int64)
@@ -239,9 +239,5 @@ class CoreOpsCorrectnessTest(testing.TestCase):
         self.assertAllEqual(x, (1, 1))
         self.assertIsInstance(x, np.ndarray)

-        # Partially converted.
-        x = ops.convert_to_tensor((1, ops.array(2), 3))
-        self.assertAllEqual(x, (1, 2, 3))
-
         with self.assertRaises(ValueError):
             ops.convert_to_numpy(KerasTensor((2,)))
@@ -5,6 +5,7 @@ import inspect
 import types
 import warnings

+import jax
 import numpy as np
 import tensorflow as tf

@@ -161,7 +162,11 @@ def serialize_keras_object(obj):
         }
     if isinstance(obj, tf.TensorShape):
         return obj.as_list() if obj._dims is not None else None
-    if backend.is_tensor(obj):
+    if isinstance(obj, (tf.Tensor, jax.numpy.ndarray)) or hasattr(
+        obj, "device"
+    ):
+        # Import torch creates circular dependency, so we use
+        # `hasattr(obj, "device")` to check if obj is a torch tensor.
         return {
             "class_name": "__tensor__",
             "config": {
@@ -27,7 +27,11 @@ class ArrayDataAdapter(DataAdapter):
         shuffle=False,
         class_weight=None,
     ):
-        if not can_convert_arrays((x, y, sample_weight)):
+        types_struct = nest.map_structure(lambda x: type(x), x)
+        flat_types = nest.flatten(types_struct)
+        if not all(
+            issubclass(c, data_adapter_utils.ARRAY_TYPES) for c in flat_types
+        ):
             raise ValueError(
                 "Expected all elements of `x` to be array-like. "
                 f"Received invalid types: x={x}"
@@ -248,28 +252,6 @@ class ArrayDataAdapter(DataAdapter):
         return self._partial_batch_size or None


-def can_convert_arrays(arrays):
-    """Check if array like-inputs can be handled by `ArrayDataAdapter`
-
-    Args:
-        inputs: Structure of `Tensor`s, NumPy arrays, or tensor-like.
-
-    Returns:
-        `True` if `arrays` can be handled by `ArrayDataAdapter`, `False`
-        otherwise.
-    """
-
-    def can_convert_single_array(x):
-        is_none = x is None
-        known_type = isinstance(x, data_adapter_utils.ARRAY_TYPES)
-        convertable_type = hasattr(x, "__array__")
-        return is_none or known_type or convertable_type
-
-    return all(
-        tf.nest.flatten(tf.nest.map_structure(can_convert_single_array, arrays))
-    )
-
-
 def convert_to_arrays(arrays, dtype=None):
     """Process array-like inputs.

@@ -280,7 +262,7 @@ def convert_to_arrays(arrays, dtype=None):
     - Converts `list`s to `tuple`s (for `tf.data` support).

     Args:
-        inputs: Structure of `Tensor`s, NumPy arrays, or tensor-like.
+        inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like.

     Returns:
         Structure of NumPy `ndarray`s.
@@ -295,16 +277,15 @@ def convert_to_arrays(arrays, dtype=None):
             x = np.expand_dims(x.to_numpy(dtype=dtype), axis=-1)
         elif isinstance(x, pandas.DataFrame):
             x = x.to_numpy(dtype=dtype)
+        if isinstance(x, (tf.Tensor, tf.Variable)):
+            x = x.numpy()
         if not isinstance(x, np.ndarray):
-            # Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`,
-            # `torch.Tensor`, as well as any other tensor-like object that has
-            # added numpy support.
             if hasattr(x, "__array__"):
                 x = np.array(x, dtype=dtype)
             else:
                 raise ValueError(
-                    "Expected a NumPy array, tf.Tensor, jax.np.ndarray, "
-                    "torch.Tensor, Pandas Dataframe, or Pandas Series. "
+                    "Expected a NumPy array, tf.Tensor, "
+                    "Pandas Dataframe, or Pandas Series. "
                     f"Received invalid input: {x} (of type {type(x)})"
                 )
         if x.dtype == object:
@@ -2,7 +2,6 @@ import jax
 import numpy as np
 import pandas
 import tensorflow as tf
-import torch
 from absl.testing import parameterized

 from keras_core import backend
@@ -11,9 +10,7 @@ from keras_core.trainers.data_adapters import array_data_adapter


 class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase):
-    @parameterized.parameters(
-        [("np",), ("tf",), ("jax",), ("torch",), ("pandas")]
-    )
+    @parameterized.parameters([("np",), ("tf",), ("pandas")])
     def test_basic_flow(self, array_type):
         if array_type == "np":
             x = np.random.random((34, 4))
@@ -24,9 +21,6 @@ class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase):
         elif array_type == "jax":
             x = jax.numpy.ones((34, 4))
             y = jax.numpy.ones((34, 2))
-        elif array_type == "torch":
-            x = torch.ones((34, 4))
-            y = torch.ones((34, 2))
         elif array_type == "pandas":
             x = pandas.DataFrame(np.random.random((34, 4)))
             y = pandas.DataFrame(np.random.random((34, 2)))
@@ -1,5 +1,6 @@
 import math

+import jax
 import numpy as np
 import tensorflow as tf

@@ -11,12 +12,15 @@ except ImportError:
     pandas = None


-# Leave jax, tf, and torch arrays off this list. Instead we will use
-# `__array__` to detect these types. Doing so allows us to avoid importing a
-# backend framework we are not currently using just to do type-checking.
-ARRAY_TYPES = (np.ndarray,)
+ARRAY_TYPES = (tf.Tensor, np.ndarray, jax.numpy.ndarray)
 if pandas:
-    ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame)
+    ARRAY_TYPES = ARRAY_TYPES + (
+        tf.Tensor,
+        np.ndarray,
+        pandas.Series,
+        pandas.DataFrame,
+    )
+# TODO: support torch tensors?


 @keras_core_export("keras_core.utils.unpack_x_y_sample_weight")
@@ -42,8 +42,10 @@ import types
 import warnings

 import tensorflow as tf
+from tensorflow import nest

 from keras_core.trainers.data_adapters import array_data_adapter
+from keras_core.trainers.data_adapters import data_adapter_utils
 from keras_core.trainers.data_adapters import generator_data_adapter
 from keras_core.trainers.data_adapters import py_dataset_adapter
 from keras_core.trainers.data_adapters import tf_dataset_adapter
@@ -67,7 +69,8 @@ class EpochIterator:
         if steps_per_epoch:
             self._current_iterator = None
         self._insufficient_data = False
-        if array_data_adapter.can_convert_arrays((x, y, sample_weight)):
+        first_element = next(iter(nest.flatten(x)), None)
+        if isinstance(first_element, data_adapter_utils.ARRAY_TYPES):
             self.data_adapter = array_data_adapter.ArrayDataAdapter(
                 x,
                 y,