Add torch and TF custom training loops guides
This commit is contained in:
parent
9e9ae0f65e
commit
6ec8a6160c
@ -1,7 +0,0 @@
|
||||
try:
|
||||
# When using torch and tensorflow, torch needs to be imported first,
|
||||
# otherwise it will segfault upon import. This should force the torch
|
||||
# import to happen first for all tests.
|
||||
import torch # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
@ -2,7 +2,7 @@
|
||||
Title: Making new layers and models via subclassing
|
||||
Author: [fchollet](https://twitter.com/fchollet)
|
||||
Date created: 2019/03/01
|
||||
Last modified: 2023/06/21
|
||||
Last modified: 2023/06/25
|
||||
Description: Complete guide to writing `Layer` and `Model` objects from scratch.
|
||||
Accelerator: None
|
||||
"""
|
||||
|
@ -2,7 +2,7 @@
|
||||
Title: The Sequential model
|
||||
Author: [fchollet](https://twitter.com/fchollet)
|
||||
Date created: 2020/04/12
|
||||
Last modified: 2020/04/12
|
||||
Last modified: 2023/06/25
|
||||
Description: Complete guide to the Sequential model.
|
||||
Accelerator: GPU
|
||||
"""
|
||||
|
@ -2,7 +2,7 @@
|
||||
Title: Training & evaluation with the built-in methods
|
||||
Author: [fchollet](https://twitter.com/fchollet)
|
||||
Date created: 2019/03/01
|
||||
Last modified: 2023/03/20
|
||||
Last modified: 2023/03/25
|
||||
Description: Complete guide to training & evaluation with `fit()` and `evaluate()`.
|
||||
Accelerator: GPU
|
||||
"""
|
||||
|
@ -2,7 +2,7 @@
|
||||
Title: Understanding masking & padding
|
||||
Authors: Scott Zhu, Francois Chollet
|
||||
Date created: 2019/07/16
|
||||
Last modified: 2020/04/14
|
||||
Last modified: 2023/06/25
|
||||
Description: Complete guide to using mask-aware sequence layers in Keras.
|
||||
Accelerator: None
|
||||
"""
|
||||
@ -376,5 +376,4 @@ automatically.
|
||||
manually.
|
||||
- You can easily write layers that modify the current mask, that generate a new mask,
|
||||
or that consume the mask associated with the inputs.
|
||||
|
||||
"""
|
||||
|
531
guides/writing_a_custom_training_loop_in_tensorflow.py
Normal file
531
guides/writing_a_custom_training_loop_in_tensorflow.py
Normal file
@ -0,0 +1,531 @@
|
||||
"""
|
||||
Title: Writing a training loop from scratch in TensorFlow
|
||||
Author: [fchollet](https://twitter.com/fchollet)
|
||||
Date created: 2019/03/01
|
||||
Last modified: 2023/06/25
|
||||
Description: Writing low-level training & evaluation loops in TensorFlow.
|
||||
Accelerator: None
|
||||
"""
|
||||
"""
|
||||
## Setup
|
||||
"""
|
||||
|
||||
import time
|
||||
import os
|
||||
|
||||
# This guide can only be run with the TensorFlow backend.
|
||||
os.environ["KERAS_BACKEND"] = "tensorflow"
|
||||
|
||||
import tensorflow as tf
|
||||
import keras_core as keras
|
||||
import numpy as np
|
||||
|
||||
"""
|
||||
## Introduction
|
||||
|
||||
Keras provides default training and evaluation loops, `fit()` and `evaluate()`.
|
||||
Their usage is covered in the guide
|
||||
[Training & evaluation with the built-in methods](https://keras.io/guides/training_with_built_in_methods/).
|
||||
|
||||
If you want to customize the learning algorithm of your model while still leveraging
|
||||
the convenience of `fit()`
|
||||
(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and
|
||||
implement your own `train_step()` method, which
|
||||
is called repeatedly during `fit()`.
|
||||
|
||||
Now, if you want very low-level control over training & evaluation, you should write
|
||||
your own training & evaluation loops from scratch. This is what this guide is about.
|
||||
"""
|
||||
|
||||
"""
|
||||
## A first end-to-end example
|
||||
|
||||
Calling a model inside a `GradientTape` scope enables you to retrieve the gradients of
|
||||
the trainable weights of the layer with respect to a loss value. Using an optimizer
|
||||
instance, you can use these gradients to update these variables (which you can
|
||||
retrieve using `model.trainable_weights`).
|
||||
|
||||
Let's consider a simple MNIST model:
|
||||
"""
|
||||
|
||||
|
||||
def get_model():
    """Build a small fully-connected classifier for flattened MNIST digits.

    Returns:
        A functional `keras.Model` mapping a 784-dimensional input named
        "digits" to the 10 outputs of a `Dense` layer named "predictions".
        That last layer has no activation, so the outputs are logits.
    """
    digits = keras.Input(shape=(784,), name="digits")
    # Two identical hidden layers of 64 ReLU units each.
    hidden = digits
    for _ in range(2):
        hidden = keras.layers.Dense(64, activation="relu")(hidden)
    logits = keras.layers.Dense(10, name="predictions")(hidden)
    return keras.Model(inputs=digits, outputs=logits)
|
||||
|
||||
|
||||
model = get_model()
|
||||
|
||||
"""
|
||||
Let's train it using mini-batch gradient descent with a custom training loop.
|
||||
|
||||
First, we're going to need an optimizer, a loss function, and a dataset:
|
||||
"""
|
||||
|
||||
# Instantiate an optimizer.
|
||||
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
|
||||
# Instantiate a loss function.
|
||||
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
|
||||
# Prepare the training dataset.
|
||||
batch_size = 32
|
||||
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
|
||||
x_train = np.reshape(x_train, (-1, 784))
|
||||
x_test = np.reshape(x_test, (-1, 784))
|
||||
|
||||
# Reserve 10,000 samples for validation.
|
||||
x_val = x_train[-10000:]
|
||||
y_val = y_train[-10000:]
|
||||
x_train = x_train[:-10000]
|
||||
y_train = y_train[:-10000]
|
||||
|
||||
# Prepare the training dataset.
|
||||
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
|
||||
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
|
||||
|
||||
# Prepare the validation dataset.
|
||||
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
|
||||
val_dataset = val_dataset.batch(batch_size)
|
||||
|
||||
"""
|
||||
Here's our training loop, step by step:
|
||||
|
||||
- We open a `for` loop that iterates over epochs
|
||||
- For each epoch, we open a `for` loop that iterates over the dataset, in batches
|
||||
- For each batch, we open a `GradientTape()` scope
|
||||
- Inside this scope, we call the model (forward pass) and compute the loss
|
||||
- Outside the scope, we retrieve the gradients of the weights
|
||||
of the model with regard to the loss
|
||||
- Finally, we use the optimizer to update the weights of the model based on the
|
||||
gradients
|
||||
"""
|
||||
|
||||
epochs = 3
|
||||
for epoch in range(epochs):
|
||||
print(f"\nStart of epoch {epoch}")
|
||||
|
||||
# Iterate over the batches of the dataset.
|
||||
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
|
||||
# Open a GradientTape to record the operations run
|
||||
# during the forward pass, which enables auto-differentiation.
|
||||
with tf.GradientTape() as tape:
|
||||
# Run the forward pass of the layer.
|
||||
# The operations that the layer applies
|
||||
# to its inputs are going to be recorded
|
||||
# on the GradientTape.
|
||||
logits = model(
|
||||
x_batch_train, training=True
|
||||
) # Logits for this minibatch
|
||||
|
||||
# Compute the loss value for this minibatch.
|
||||
loss_value = loss_fn(y_batch_train, logits)
|
||||
|
||||
# Use the gradient tape to automatically retrieve
|
||||
# the gradients of the trainable variables with respect to the loss.
|
||||
grads = tape.gradient(loss_value, model.trainable_weights)
|
||||
|
||||
# Run one step of gradient descent by updating
|
||||
# the value of the variables to minimize the loss.
|
||||
optimizer.apply_gradients(zip(grads, model.trainable_weights))
|
||||
|
||||
# Log every 200 batches.
|
||||
if step % 200 == 0:
|
||||
print(
|
||||
f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
|
||||
)
|
||||
print(f"Seen so far: {(step + 1) * batch_size} samples")
|
||||
|
||||
"""
|
||||
## Low-level handling of metrics
|
||||
|
||||
Let's add metrics monitoring to this basic loop.
|
||||
|
||||
You can readily reuse the built-in metrics (or custom ones you wrote) in such training
|
||||
loops written from scratch. Here's the flow:
|
||||
|
||||
- Instantiate the metric at the start of the loop
|
||||
- Call `metric.update_state()` after each batch
|
||||
- Call `metric.result()` when you need to display the current value of the metric
|
||||
- Call `metric.reset_state()` when you need to clear the state of the metric
|
||||
(typically at the end of an epoch)
|
||||
|
||||
Let's use this knowledge to compute `SparseCategoricalAccuracy` on validation data at
|
||||
the end of each epoch:
|
||||
"""
|
||||
|
||||
# Get a fresh model
|
||||
model = get_model()
|
||||
|
||||
# Instantiate an optimizer to train the model.
|
||||
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
|
||||
# Instantiate a loss function.
|
||||
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
|
||||
# Prepare the metrics.
|
||||
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
|
||||
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()
|
||||
|
||||
"""
|
||||
Here's our training & evaluation loop:
|
||||
"""
|
||||
|
||||
epochs = 2
|
||||
for epoch in range(epochs):
|
||||
print(f"\nStart of epoch {epoch}")
|
||||
start_time = time.time()
|
||||
|
||||
# Iterate over the batches of the dataset.
|
||||
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
|
||||
with tf.GradientTape() as tape:
|
||||
logits = model(x_batch_train, training=True)
|
||||
loss_value = loss_fn(y_batch_train, logits)
|
||||
grads = tape.gradient(loss_value, model.trainable_weights)
|
||||
optimizer.apply_gradients(zip(grads, model.trainable_weights))
|
||||
|
||||
# Update training metric.
|
||||
train_acc_metric.update_state(y_batch_train, logits)
|
||||
|
||||
# Log every 200 batches.
|
||||
if step % 200 == 0:
|
||||
print(
|
||||
f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
|
||||
)
|
||||
print(f"Seen so far: {(step + 1) * batch_size} samples")
|
||||
|
||||
# Display metrics at the end of each epoch.
|
||||
train_acc = train_acc_metric.result()
|
||||
print(f"Training acc over epoch: {float(train_acc):.4f}")
|
||||
|
||||
# Reset training metrics at the end of each epoch
|
||||
train_acc_metric.reset_state()
|
||||
|
||||
# Run a validation loop at the end of each epoch.
|
||||
for x_batch_val, y_batch_val in val_dataset:
|
||||
val_logits = model(x_batch_val, training=False)
|
||||
# Update val metrics
|
||||
val_acc_metric.update_state(y_batch_val, val_logits)
|
||||
val_acc = val_acc_metric.result()
|
||||
val_acc_metric.reset_state()
|
||||
print(f"Validation acc: {float(val_acc):.4f}")
|
||||
print(f"Time taken: {time.time() - start_time:.2f}s")
|
||||
|
||||
"""
|
||||
## Speeding-up your training step with `tf.function`
|
||||
|
||||
The default runtime in TensorFlow is eager execution.
|
||||
As such, our training loop above executes eagerly.
|
||||
|
||||
This is great for debugging, but graph compilation has a definite performance
|
||||
advantage. Describing your computation as a static graph enables the framework
|
||||
to apply global performance optimizations. This is impossible when
|
||||
the framework is constrained to greedily execute one operation after another,
|
||||
with no knowledge of what comes next.
|
||||
|
||||
You can compile into a static graph any function that takes tensors as input.
|
||||
Just add a `@tf.function` decorator on it, like this:
|
||||
"""
|
||||
|
||||
|
||||
@tf.function
def train_step(x, y):
    """Run one compiled optimization step on a single batch.

    Uses the module-level `model`, `loss_fn`, `optimizer` and
    `train_acc_metric`.

    Args:
        x: Batch of model inputs.
        y: Batch of targets passed to `loss_fn` and to the accuracy metric.

    Returns:
        The loss value computed for this batch.
    """
    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)

    weights = model.trainable_weights
    gradients = tape.gradient(loss_value, weights)
    optimizer.apply_gradients(zip(gradients, weights))

    # Accumulate training accuracy for the current epoch.
    train_acc_metric.update_state(y, logits)
    return loss_value
|
||||
|
||||
|
||||
"""
|
||||
Let's do the same with the evaluation step:
|
||||
"""
|
||||
|
||||
|
||||
@tf.function
def test_step(x, y):
    """Run one compiled evaluation step on a single validation batch.

    Calls the module-level `model` with `training=False` and folds the
    predictions into `val_acc_metric`. Returns nothing.

    Args:
        x: Batch of model inputs.
        y: Batch of targets for the accuracy metric.
    """
    predictions = model(x, training=False)
    val_acc_metric.update_state(y, predictions)
|
||||
|
||||
|
||||
"""
|
||||
Now, let's re-run our training loop with this compiled training step:
|
||||
"""
|
||||
|
||||
epochs = 2
|
||||
for epoch in range(epochs):
|
||||
print(f"\nStart of epoch {epoch}")
|
||||
start_time = time.time()
|
||||
|
||||
# Iterate over the batches of the dataset.
|
||||
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
|
||||
loss_value = train_step(x_batch_train, y_batch_train)
|
||||
|
||||
# Log every 200 batches.
|
||||
if step % 200 == 0:
|
||||
print(
|
||||
f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
|
||||
)
|
||||
print(f"Seen so far: {(step + 1) * batch_size} samples")
|
||||
|
||||
# Display metrics at the end of each epoch.
|
||||
train_acc = train_acc_metric.result()
|
||||
print(f"Training acc over epoch: {float(train_acc):.4f}")
|
||||
|
||||
# Reset training metrics at the end of each epoch
|
||||
train_acc_metric.reset_state()
|
||||
|
||||
# Run a validation loop at the end of each epoch.
|
||||
for x_batch_val, y_batch_val in val_dataset:
|
||||
test_step(x_batch_val, y_batch_val)
|
||||
|
||||
val_acc = val_acc_metric.result()
|
||||
val_acc_metric.reset_state()
|
||||
print(f"Validation acc: {float(val_acc):.4f}")
|
||||
print(f"Time taken: {time.time() - start_time:.2f}s")
|
||||
|
||||
"""
|
||||
Much faster, isn't it?
|
||||
"""
|
||||
|
||||
"""
|
||||
## Low-level handling of losses tracked by the model
|
||||
|
||||
Layers & models recursively track any losses created during the forward pass
|
||||
by layers that call `self.add_loss(value)`. The resulting list of scalar loss
|
||||
values is available via the property `model.losses`
|
||||
at the end of the forward pass.
|
||||
|
||||
If you want to use these loss components, you should sum them
|
||||
and add them to the main loss in your training step.
|
||||
|
||||
Consider this layer, that creates an activity regularization loss:
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ActivityRegularizationLayer(keras.layers.Layer):
    """Pass-through layer that registers an activity regularization loss."""

    def call(self, inputs):
        """Register the penalty for `inputs` and return `inputs` unchanged."""
        # Penalty: 1e-2 times the sum of all input activations.
        penalty = 1e-2 * tf.reduce_sum(inputs)
        self.add_loss(penalty)
        return inputs
|
||||
|
||||
|
||||
"""
|
||||
Let's build a really simple model that uses it:
|
||||
"""
|
||||
|
||||
inputs = keras.Input(shape=(784,), name="digits")
|
||||
x = keras.layers.Dense(64, activation="relu")(inputs)
|
||||
# Insert activity regularization as a layer
|
||||
x = ActivityRegularizationLayer()(x)
|
||||
x = keras.layers.Dense(64, activation="relu")(x)
|
||||
outputs = keras.layers.Dense(10, name="predictions")(x)
|
||||
|
||||
model = keras.Model(inputs=inputs, outputs=outputs)
|
||||
|
||||
"""
|
||||
Here's what our training step should look like now:
|
||||
"""
|
||||
|
||||
|
||||
@tf.function
def train_step(x, y):
    """Run one compiled optimization step that includes model-tracked losses.

    Uses the module-level `model`, `loss_fn`, `optimizer` and
    `train_acc_metric`.

    Args:
        x: Batch of model inputs.
        y: Batch of targets passed to `loss_fn` and to the accuracy metric.

    Returns:
        The total loss for this batch: the main loss plus the sum of
        `model.losses`.
    """
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        main_loss = loss_fn(y, logits)
        # Add any extra losses created during the forward pass.
        loss_value = main_loss + sum(model.losses)

    weights = model.trainable_weights
    gradients = tape.gradient(loss_value, weights)
    optimizer.apply_gradients(zip(gradients, weights))

    train_acc_metric.update_state(y, logits)
    return loss_value
|
||||
|
||||
|
||||
"""
|
||||
## Summary
|
||||
|
||||
Now you know everything there is to know about using built-in training loops and
|
||||
writing your own from scratch.
|
||||
|
||||
To conclude, here's a simple end-to-end example that ties together everything
|
||||
you've learned in this guide: a DCGAN trained on MNIST digits.
|
||||
"""
|
||||
|
||||
"""
|
||||
## End-to-end example: a GAN training loop from scratch
|
||||
|
||||
You may be familiar with Generative Adversarial Networks (GANs). GANs can generate new
|
||||
images that look almost real, by learning the latent distribution of a training
|
||||
dataset of images (the "latent space" of the images).
|
||||
|
||||
A GAN is made of two parts: a "generator" model that maps points in the latent
|
||||
space to points in image space, a "discriminator" model, a classifier
|
||||
that can tell the difference between real images (from the training dataset)
|
||||
and fake images (the output of the generator network).
|
||||
|
||||
A GAN training loop looks like this:
|
||||
|
||||
1) Train the discriminator.
|
||||
- Sample a batch of random points in the latent space.
|
||||
- Turn the points into fake images via the "generator" model.
|
||||
- Get a batch of real images and combine them with the generated images.
|
||||
- Train the "discriminator" model to classify generated vs. real images.
|
||||
|
||||
2) Train the generator.
|
||||
- Sample random points in the latent space.
|
||||
- Turn the points into fake images via the "generator" network.
|
||||
- Get a batch of real images and combine them with the generated images.
|
||||
- Train the "generator" model to "fool" the discriminator and classify the fake images
|
||||
as real.
|
||||
|
||||
For a much more detailed overview of how GANs work, see
|
||||
[Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python).
|
||||
|
||||
Let's implement this training loop. First, create the discriminator meant to classify
|
||||
fake vs real digits:
|
||||
"""
|
||||
|
||||
discriminator = keras.Sequential(
|
||||
[
|
||||
keras.Input(shape=(28, 28, 1)),
|
||||
keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
|
||||
keras.layers.LeakyReLU(negative_slope=0.2),
|
||||
keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
|
||||
keras.layers.LeakyReLU(negative_slope=0.2),
|
||||
keras.layers.GlobalMaxPooling2D(),
|
||||
keras.layers.Dense(1),
|
||||
],
|
||||
name="discriminator",
|
||||
)
|
||||
discriminator.summary()
|
||||
|
||||
"""
|
||||
Then let's create a generator network,
|
||||
that turns latent vectors into outputs of shape `(28, 28, 1)` (representing
|
||||
MNIST digits):
|
||||
"""
|
||||
|
||||
latent_dim = 128
|
||||
|
||||
generator = keras.Sequential(
|
||||
[
|
||||
keras.Input(shape=(latent_dim,)),
|
||||
# We want to generate 7 * 7 * 128 coefficients to reshape into a 7x7x128 map
|
||||
keras.layers.Dense(7 * 7 * 128),
|
||||
keras.layers.LeakyReLU(negative_slope=0.2),
|
||||
keras.layers.Reshape((7, 7, 128)),
|
||||
keras.layers.Conv2DTranspose(
|
||||
128, (4, 4), strides=(2, 2), padding="same"
|
||||
),
|
||||
keras.layers.LeakyReLU(negative_slope=0.2),
|
||||
keras.layers.Conv2DTranspose(
|
||||
128, (4, 4), strides=(2, 2), padding="same"
|
||||
),
|
||||
keras.layers.LeakyReLU(negative_slope=0.2),
|
||||
keras.layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
|
||||
],
|
||||
name="generator",
|
||||
)
|
||||
|
||||
"""
|
||||
Here's the key bit: the training loop. As you can see it is quite straightforward. The
|
||||
training step function only takes 17 lines.
|
||||
"""
|
||||
|
||||
# Instantiate one optimizer for the discriminator and another for the generator.
|
||||
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
|
||||
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)
|
||||
|
||||
# Instantiate a loss function.
|
||||
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
|
||||
|
||||
|
||||
@tf.function
def train_step(real_images):
    """Run one discriminator update followed by one generator update.

    Uses the module-level `generator`, `discriminator`, `d_optimizer`,
    `g_optimizer`, `loss_fn`, `batch_size` and `latent_dim`.
    Label convention used below: 1 marks a generated image, 0 a real one.

    Args:
        real_images: Batch of real images, concatenated with the generated
            batch along axis 0 before being fed to the discriminator.

    Returns:
        A tuple `(d_loss, g_loss, generated_images)`, where
        `generated_images` is the fake batch used for the discriminator step.
    """
    # --- Discriminator step ---------------------------------------------
    # Draw latent points and decode them into fake images.
    d_latents = tf.random.normal(shape=(batch_size, latent_dim))
    generated_images = generator(d_latents)
    # Fake images come first, real images second.
    combined_images = tf.concat([generated_images, real_images], axis=0)

    # Labels follow the same order: ones for fakes, zeros for reals.
    fake_labels = tf.ones((batch_size, 1))
    real_labels = tf.zeros((real_images.shape[0], 1))
    labels = tf.concat([fake_labels, real_labels], axis=0)
    # Add random noise to the labels - important trick!
    labels = labels + 0.05 * tf.random.uniform(labels.shape)

    with tf.GradientTape() as d_tape:
        d_predictions = discriminator(combined_images)
        d_loss = loss_fn(labels, d_predictions)
    d_weights = discriminator.trainable_weights
    d_grads = d_tape.gradient(d_loss, d_weights)
    d_optimizer.apply_gradients(zip(d_grads, d_weights))

    # --- Generator step -------------------------------------------------
    # Fresh latent points, labelled as if every decoded image were real (0).
    g_latents = tf.random.normal(shape=(batch_size, latent_dim))
    misleading_labels = tf.zeros((batch_size, 1))

    # Only the generator's weights are differentiated and updated here;
    # the discriminator's weights are left untouched.
    with tf.GradientTape() as g_tape:
        g_predictions = discriminator(generator(g_latents))
        g_loss = loss_fn(misleading_labels, g_predictions)
    g_weights = generator.trainable_weights
    g_grads = g_tape.gradient(g_loss, g_weights)
    g_optimizer.apply_gradients(zip(g_grads, g_weights))

    return d_loss, g_loss, generated_images
|
||||
|
||||
|
||||
"""
|
||||
Let's train our GAN, by repeatedly calling `train_step` on batches of images.
|
||||
|
||||
Since our discriminator and generator are convnets, you're going to want to
|
||||
run this code on a GPU.
|
||||
"""
|
||||
|
||||
# Prepare the dataset. We use both the training & test MNIST digits.
|
||||
batch_size = 64
|
||||
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
|
||||
all_digits = np.concatenate([x_train, x_test])
|
||||
all_digits = all_digits.astype("float32") / 255.0
|
||||
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
|
||||
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
|
||||
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
|
||||
|
||||
epochs = 1 # In practice you need at least 20 epochs to generate nice digits.
|
||||
save_dir = "./"
|
||||
|
||||
for epoch in range(epochs):
|
||||
print(f"\nStart epoch {epoch}")
|
||||
|
||||
for step, real_images in enumerate(dataset):
|
||||
# Train the discriminator & generator on one batch of real images.
|
||||
d_loss, g_loss, generated_images = train_step(real_images)
|
||||
|
||||
# Logging.
|
||||
if step % 200 == 0:
|
||||
# Print metrics
|
||||
print(f"discriminator loss at step {step}: {d_loss:.2f}")
|
||||
print(f"adversarial loss at step {step}: {g_loss:.2f}")
|
||||
|
||||
# Save one generated image
|
||||
img = keras.utils.array_to_img(
|
||||
generated_images[0] * 255.0, scale=False
|
||||
)
|
||||
img.save(os.path.join(save_dir, f"generated_img_{step}.png"))
|
||||
|
||||
# To limit execution time we stop after 10 steps.
|
||||
# Remove the lines below to actually train the model!
|
||||
if step > 10:
|
||||
break
|
||||
|
||||
"""
|
||||
That's it! You'll get nice-looking fake MNIST digits after just ~30s of training on the
|
||||
Colab GPU.
|
||||
"""
|
381
guides/writing_a_custom_training_loop_in_torch.py
Normal file
381
guides/writing_a_custom_training_loop_in_torch.py
Normal file
@ -0,0 +1,381 @@
|
||||
"""
|
||||
Title: Writing a training loop from scratch in PyTorch
|
||||
Author: [fchollet](https://twitter.com/fchollet)
|
||||
Date created: 2023/06/25
|
||||
Last modified: 2023/06/25
|
||||
Description: Writing low-level training & evaluation loops in PyTorch.
|
||||
Accelerator: None
|
||||
"""
|
||||
"""
|
||||
## Setup
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# This guide can only be run with the torch backend.
|
||||
os.environ["KERAS_BACKEND"] = "torch"
|
||||
|
||||
import torch
|
||||
import keras_core as keras
|
||||
import numpy as np
|
||||
|
||||
"""
|
||||
## Introduction
|
||||
|
||||
Keras provides default training and evaluation loops, `fit()` and `evaluate()`.
|
||||
Their usage is covered in the guide
|
||||
[Training & evaluation with the built-in methods](https://keras.io/guides/training_with_built_in_methods/).
|
||||
|
||||
If you want to customize the learning algorithm of your model while still leveraging
|
||||
the convenience of `fit()`
|
||||
(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and
|
||||
implement your own `train_step()` method, which
|
||||
is called repeatedly during `fit()`.
|
||||
|
||||
Now, if you want very low-level control over training & evaluation, you should write
|
||||
your own training & evaluation loops from scratch. This is what this guide is about.
|
||||
"""
|
||||
|
||||
"""
|
||||
## A first end-to-end example
|
||||
|
||||
To write a custom training loop, you need the following ingredients:
|
||||
|
||||
- A model to train, of course.
|
||||
- An optimizer. You could either use a `keras_core.optimizers` optimizer,
|
||||
or a native PyTorch optimizer from `torch.optim`.
|
||||
- A loss function. You could either use a `keras_core.losses` loss,
|
||||
or a native PyTorch loss from `torch.nn`.
|
||||
- A dataset. You could use any format: a `tf.data.Dataset`,
|
||||
a PyTorch `DataLoader`, a Python generator, etc.
|
||||
|
||||
Let's line them up. We'll use torch-native objects in each case --
|
||||
except, of course, for the Keras model.
|
||||
|
||||
First, let's get the model and the MNIST dataset:
|
||||
"""
|
||||
|
||||
|
||||
# Let's consider a simple MNIST model
|
||||
def get_model():
    """Build a small fully-connected classifier for flattened MNIST digits.

    Returns:
        A functional `keras.Model` mapping a 784-dimensional input named
        "digits" to the 10 outputs of a `Dense` layer named "predictions".
        That last layer has no activation, so the outputs are logits.
    """
    digits = keras.Input(shape=(784,), name="digits")
    # Two identical hidden layers of 64 ReLU units each.
    hidden = digits
    for _ in range(2):
        hidden = keras.layers.Dense(64, activation="relu")(hidden)
    logits = keras.layers.Dense(10, name="predictions")(hidden)
    return keras.Model(inputs=digits, outputs=logits)
|
||||
|
||||
|
||||
# Load the MNIST dataset and put it in a torch DataLoader
|
||||
# Prepare the training dataset.
|
||||
batch_size = 32
|
||||
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
|
||||
x_train = np.reshape(x_train, (-1, 784))
|
||||
x_test = np.reshape(x_test, (-1, 784))
|
||||
y_train = keras.utils.to_categorical(y_train)
|
||||
y_test = keras.utils.to_categorical(y_test)
|
||||
|
||||
# Reserve 10,000 samples for validation.
|
||||
x_val = x_train[-10000:]
|
||||
y_val = y_train[-10000:]
|
||||
x_train = x_train[:-10000]
|
||||
y_train = y_train[:-10000]
|
||||
|
||||
# Create torch Datasets
|
||||
train_dataset = torch.utils.data.TensorDataset(
|
||||
torch.from_numpy(x_train), torch.from_numpy(y_train)
|
||||
)
|
||||
val_dataset = torch.utils.data.TensorDataset(
|
||||
torch.from_numpy(x_val), torch.from_numpy(y_val)
|
||||
)
|
||||
|
||||
# Create DataLoaders for the Datasets
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=batch_size, shuffle=True
|
||||
)
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset, batch_size=batch_size, shuffle=False
|
||||
)
|
||||
|
||||
"""
|
||||
Next, here's our PyTorch optimizer and our PyTorch loss function:
|
||||
"""
|
||||
|
||||
# Instantiate a torch optimizer
|
||||
model = get_model()
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
|
||||
# Instantiate a torch loss function
|
||||
loss_fn = torch.nn.CrossEntropyLoss()
|
||||
|
||||
"""
|
||||
Let's train our model using mini-batch gradient descent with a custom training loop.
|
||||
|
||||
Calling `loss.backward()` on a loss tensor triggers backpropagation.
|
||||
Once that's done, your optimizer is magically aware of the gradients for each variable
|
||||
and can update its variables, which is done via `optimizer.step()`.
|
||||
Tensors, variables, optimizers are all interconnected to one another via hidden global state.
|
||||
Also, don't forget to call `model.zero_grad()` before `loss.backward()`, or you won't
|
||||
get the right gradients for your variables.
|
||||
|
||||
Here's our training loop, step by step:
|
||||
|
||||
- We open a `for` loop that iterates over epochs
|
||||
- For each epoch, we open a `for` loop that iterates over the dataset, in batches
|
||||
- For each batch, we call the model on the input data to retrieve the predictions,
|
||||
then we use them to compute a loss value
|
||||
- We call `loss.backward()` to
|
||||
compute the gradients of the weights
|
||||
of the model with regard to the loss
|
||||
- Finally, we use the optimizer to update the weights of the model based on the
|
||||
gradients
|
||||
"""
|
||||
|
||||
# Train with the torch-native optimizer and loss defined above.
epochs = 3
for epoch in range(epochs):
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        # torch losses take (predictions, targets), in that order.
        loss = loss_fn(logits, targets)

        # Backward pass: clear stale gradients, then backpropagate.
        model.zero_grad()
        loss.backward()

        # Optimizer variable updates
        optimizer.step()

        # Log every 200 batches.
        if step % 200 == 0:
            # `loss.item()` returns a Python float whatever device the tensor
            # lives on, whereas `.numpy()` raises for tensors not on the CPU.
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.item():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")
|
||||
|
||||
"""
|
||||
As an alternative, let's look at what the loop looks like when using a Keras optimizer
|
||||
and a Keras loss function.
|
||||
|
||||
Important differences:
|
||||
|
||||
- You retrieve the gradients for the variables via `v.value.grad`,
|
||||
called on each trainable variable.
|
||||
- You update your variables via `optimizer.apply_gradients()`, which must be
|
||||
called in a `torch.no_grad()` scope.
|
||||
|
||||
**Also, a big gotcha:** while all NumPy/TensorFlow/JAX/Keras APIs
|
||||
as well as Python `unittest` APIs use the argument order convention
|
||||
`fn(y_true, y_pred)` (reference values first, predicted values second),
|
||||
PyTorch actually uses `fn(y_pred, y_true)` for its losses.
|
||||
So make sure to invert the order of `logits` and `targets`.
|
||||
"""
|
||||
|
||||
model = get_model()
|
||||
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
|
||||
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
|
||||
|
||||
# Train with the Keras optimizer and Keras loss defined above.
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        # Keras losses take (targets, predictions), in that order.
        loss = loss_fn(targets, logits)

        # Backward pass: clear stale gradients before backpropagating.
        model.zero_grad()
        trainable_weights = list(model.trainable_weights)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply_gradients(zip(gradients, trainable_weights))

        # Log every 200 batches.
        if step % 200 == 0:
            # `loss.item()` returns a Python float whatever device the tensor
            # lives on, whereas `.numpy()` raises for tensors not on the CPU.
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.item():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")
|
||||
|
||||
"""
|
||||
## Low-level handling of metrics
|
||||
|
||||
Let's add metrics monitoring to this basic training loop.
|
||||
|
||||
You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training
|
||||
loops written from scratch. Here's the flow:
|
||||
|
||||
- Instantiate the metric at the start of the loop
|
||||
- Call `metric.update_state()` after each batch
|
||||
- Call `metric.result()` when you need to display the current value of the metric
|
||||
- Call `metric.reset_state()` when you need to clear the state of the metric
|
||||
(typically at the end of an epoch)
|
||||
|
||||
Let's use this knowledge to compute `CategoricalAccuracy` on validation data at
|
||||
the end of each epoch:
|
||||
"""
|
||||
|
||||
# Get a fresh model
|
||||
model = get_model()
|
||||
|
||||
# Instantiate an optimizer to train the model.
|
||||
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
|
||||
# Instantiate a loss function.
|
||||
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
|
||||
|
||||
# Prepare the metrics.
|
||||
train_acc_metric = keras.metrics.CategoricalAccuracy()
|
||||
val_acc_metric = keras.metrics.CategoricalAccuracy()
|
||||
|
||||
"""
|
||||
Here's our training & evaluation loop:
|
||||
"""
|
||||
|
||||
# Train for `epochs` epochs, tracking training accuracy per epoch and running
# a validation pass at the end of each one.
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        # Keras losses take (targets, predictions), in that order.
        loss = loss_fn(targets, logits)

        # Backward pass: clear stale gradients before backpropagating.
        model.zero_grad()
        trainable_weights = list(model.trainable_weights)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply_gradients(zip(gradients, trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(targets, logits)

        # Log every 200 batches.
        if step % 200 == 0:
            # `loss.item()` returns a Python float whatever device the tensor
            # lives on, whereas `.numpy()` raises for tensors not on the CPU.
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.item():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataloader:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
|
||||
|
||||
|
||||
"""
|
||||
## Low-level handling of losses tracked by the model
|
||||
|
||||
Layers & models recursively track any losses created during the forward pass
|
||||
by layers that call `self.add_loss(value)`. The resulting list of scalar loss
|
||||
values is available via the property `model.losses`
|
||||
at the end of the forward pass.
|
||||
|
||||
If you want to use these loss components, you should sum them
|
||||
and add them to the main loss in your training step.
|
||||
|
||||
Consider this layer, that creates an activity regularization loss:
|
||||
"""
|
||||
|
||||
|
||||
class ActivityRegularizationLayer(keras.layers.Layer):
    """Pass-through layer that registers an activity regularization loss.

    Each call adds `1e-2 * sum(inputs)` to the layer's tracked losses via
    `add_loss()` and returns `inputs` unchanged.
    """

    def call(self, inputs):
        # Penalty is the sum of the incoming activations scaled by 1e-2.
        penalty = torch.sum(inputs) * 1e-2
        self.add_loss(penalty)
        return inputs
|
||||
|
||||
|
||||
"""
|
||||
Let's build a really simple model that uses it:
|
||||
"""
|
||||
|
||||
# Functional-API model: 784-dim input -> Dense(64) -> activity regularization
# -> Dense(64) -> 10-unit output layer.
inputs = keras.Input(name="digits", shape=(784,))
x = keras.layers.Dense(units=64, activation="relu")(inputs)
# Insert activity regularization as a layer: it returns its input unchanged
# and records an extra loss on each call.
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(units=64, activation="relu")(x)
outputs = keras.layers.Dense(units=10, name="predictions")(x)

model = keras.Model(outputs=outputs, inputs=inputs)
|
||||
|
||||
"""
|
||||
Here's what our training loop should look like now:
|
||||
"""
|
||||
|
||||
# Train the model built above, which contains `ActivityRegularizationLayer`.
# Calling `get_model()` here would replace it with a different model, so the
# regularization layer defined above would never be trained.

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)
        if model.losses:
            # `model.losses` is a list; add every entry to the main loss.
            # The builtin `sum` handles any number of losses, whereas
            # `torch.sum(*model.losses)` only works for exactly one.
            loss = loss + sum(model.losses)

        # Backward pass: clear gradients left over from the previous step.
        model.zero_grad()
        trainable_weights = list(model.trainable_weights)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights. The update itself must not be tracked by autograd.
        with torch.no_grad():
            optimizer.apply_gradients(zip(gradients, trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(targets, logits)

        # Log every 200 batches.
        if step % 200 == 0:
            # `.cpu()` makes the conversion work when the loss tensor lives
            # on an accelerator; it is a no-op for CPU tensors.
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().cpu().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataloader:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
|
@ -2,7 +2,7 @@
|
||||
Title: Writing your own callbacks
|
||||
Authors: Rick Chao, Francois Chollet
|
||||
Date created: 2019/03/20
|
||||
Last modified: 2020/07/12
|
||||
Last modified: 2023/06/25
|
||||
Description: Complete guide to writing new Keras callbacks.
|
||||
Accelerator: GPU
|
||||
"""
|
||||
|
@ -28,7 +28,7 @@ class Variable(
|
||||
)
|
||||
|
||||
def _direct_assign(self, value):
|
||||
self._value.assign(tf.cast(value, self._value.dtype))
|
||||
self.value.assign(value)
|
||||
|
||||
def _convert_to_tensor(self, value, dtype=None):
|
||||
return convert_to_tensor(value, dtype=dtype)
|
||||
|
@ -129,9 +129,9 @@ def convert_to_tensor(x, dtype=None):
|
||||
# Convert to np in case of any array-like that is not list or tuple.
|
||||
if not isinstance(x, (list, tuple)):
|
||||
x = np.array(x)
|
||||
elif len(x) > 0 and any(isinstance(x1, torch.Tensor) for x1 in x):
|
||||
elif len(x) > 0 and isinstance(x[0], torch.Tensor):
|
||||
# Handle list or tuple of torch tensors
|
||||
return torch.stack([convert_to_tensor(x1) for x1 in x])
|
||||
return torch.stack(x)
|
||||
if isinstance(x, np.ndarray) and x.dtype == np.uint32:
|
||||
# Torch backend does not support uint32.
|
||||
x = x.astype(np.int64)
|
||||
|
@ -239,9 +239,5 @@ class CoreOpsCorrectnessTest(testing.TestCase):
|
||||
self.assertAllEqual(x, (1, 1))
|
||||
self.assertIsInstance(x, np.ndarray)
|
||||
|
||||
# Partially converted.
|
||||
x = ops.convert_to_tensor((1, ops.array(2), 3))
|
||||
self.assertAllEqual(x, (1, 2, 3))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
ops.convert_to_numpy(KerasTensor((2,)))
|
||||
|
@ -5,6 +5,7 @@ import inspect
|
||||
import types
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
@ -161,7 +162,11 @@ def serialize_keras_object(obj):
|
||||
}
|
||||
if isinstance(obj, tf.TensorShape):
|
||||
return obj.as_list() if obj._dims is not None else None
|
||||
if backend.is_tensor(obj):
|
||||
if isinstance(obj, (tf.Tensor, jax.numpy.ndarray)) or hasattr(
|
||||
obj, "device"
|
||||
):
|
||||
# Import torch creates circular dependency, so we use
|
||||
# `hasattr(obj, "device")` to check if obj is a torch tensor.
|
||||
return {
|
||||
"class_name": "__tensor__",
|
||||
"config": {
|
||||
|
@ -27,7 +27,11 @@ class ArrayDataAdapter(DataAdapter):
|
||||
shuffle=False,
|
||||
class_weight=None,
|
||||
):
|
||||
if not can_convert_arrays((x, y, sample_weight)):
|
||||
types_struct = nest.map_structure(lambda x: type(x), x)
|
||||
flat_types = nest.flatten(types_struct)
|
||||
if not all(
|
||||
issubclass(c, data_adapter_utils.ARRAY_TYPES) for c in flat_types
|
||||
):
|
||||
raise ValueError(
|
||||
"Expected all elements of `x` to be array-like. "
|
||||
f"Received invalid types: x={x}"
|
||||
@ -248,28 +252,6 @@ class ArrayDataAdapter(DataAdapter):
|
||||
return self._partial_batch_size or None
|
||||
|
||||
|
||||
def can_convert_arrays(arrays):
|
||||
"""Check if array like-inputs can be handled by `ArrayDataAdapter`
|
||||
|
||||
Args:
|
||||
inputs: Structure of `Tensor`s, NumPy arrays, or tensor-like.
|
||||
|
||||
Returns:
|
||||
`True` if `arrays` can be handled by `ArrayDataAdapter`, `False`
|
||||
otherwise.
|
||||
"""
|
||||
|
||||
def can_convert_single_array(x):
|
||||
is_none = x is None
|
||||
known_type = isinstance(x, data_adapter_utils.ARRAY_TYPES)
|
||||
convertable_type = hasattr(x, "__array__")
|
||||
return is_none or known_type or convertable_type
|
||||
|
||||
return all(
|
||||
tf.nest.flatten(tf.nest.map_structure(can_convert_single_array, arrays))
|
||||
)
|
||||
|
||||
|
||||
def convert_to_arrays(arrays, dtype=None):
|
||||
"""Process array-like inputs.
|
||||
|
||||
@ -280,7 +262,7 @@ def convert_to_arrays(arrays, dtype=None):
|
||||
- Converts `list`s to `tuple`s (for `tf.data` support).
|
||||
|
||||
Args:
|
||||
inputs: Structure of `Tensor`s, NumPy arrays, or tensor-like.
|
||||
inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like.
|
||||
|
||||
Returns:
|
||||
Structure of NumPy `ndarray`s.
|
||||
@ -295,16 +277,15 @@ def convert_to_arrays(arrays, dtype=None):
|
||||
x = np.expand_dims(x.to_numpy(dtype=dtype), axis=-1)
|
||||
elif isinstance(x, pandas.DataFrame):
|
||||
x = x.to_numpy(dtype=dtype)
|
||||
if isinstance(x, (tf.Tensor, tf.Variable)):
|
||||
x = x.numpy()
|
||||
if not isinstance(x, np.ndarray):
|
||||
# Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`,
|
||||
# `torch.Tensor`, as well as any other tensor-like object that has
|
||||
# added numpy support.
|
||||
if hasattr(x, "__array__"):
|
||||
x = np.array(x, dtype=dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Expected a NumPy array, tf.Tensor, jax.np.ndarray, "
|
||||
"torch.Tensor, Pandas Dataframe, or Pandas Series. "
|
||||
"Expected a NumPy array, tf.Tensor, "
|
||||
"Pandas Dataframe, or Pandas Series. "
|
||||
f"Received invalid input: {x} (of type {type(x)})"
|
||||
)
|
||||
if x.dtype == object:
|
||||
|
@ -2,7 +2,6 @@ import jax
|
||||
import numpy as np
|
||||
import pandas
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from absl.testing import parameterized
|
||||
|
||||
from keras_core import backend
|
||||
@ -11,9 +10,7 @@ from keras_core.trainers.data_adapters import array_data_adapter
|
||||
|
||||
|
||||
class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase):
|
||||
@parameterized.parameters(
|
||||
[("np",), ("tf",), ("jax",), ("torch",), ("pandas")]
|
||||
)
|
||||
@parameterized.parameters([("np",), ("tf",), ("pandas")])
|
||||
def test_basic_flow(self, array_type):
|
||||
if array_type == "np":
|
||||
x = np.random.random((34, 4))
|
||||
@ -24,9 +21,6 @@ class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase):
|
||||
elif array_type == "jax":
|
||||
x = jax.numpy.ones((34, 4))
|
||||
y = jax.numpy.ones((34, 2))
|
||||
elif array_type == "torch":
|
||||
x = torch.ones((34, 4))
|
||||
y = torch.ones((34, 2))
|
||||
elif array_type == "pandas":
|
||||
x = pandas.DataFrame(np.random.random((34, 4)))
|
||||
y = pandas.DataFrame(np.random.random((34, 2)))
|
||||
|
@ -1,5 +1,6 @@
|
||||
import math
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
@ -11,12 +12,15 @@ except ImportError:
|
||||
pandas = None
|
||||
|
||||
|
||||
# Leave jax, tf, and torch arrays off this list. Instead we will use
|
||||
# `__array__` to detect these types. Doing so allows us to avoid importing a
|
||||
# backend framework we are not currently using just to do type-checking.
|
||||
ARRAY_TYPES = (np.ndarray,)
|
||||
ARRAY_TYPES = (tf.Tensor, np.ndarray, jax.numpy.ndarray)
|
||||
if pandas:
|
||||
ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame)
|
||||
ARRAY_TYPES = ARRAY_TYPES + (
|
||||
tf.Tensor,
|
||||
np.ndarray,
|
||||
pandas.Series,
|
||||
pandas.DataFrame,
|
||||
)
|
||||
# TODO: support torch tensors?
|
||||
|
||||
|
||||
@keras_core_export("keras_core.utils.unpack_x_y_sample_weight")
|
||||
|
@ -42,8 +42,10 @@ import types
|
||||
import warnings
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow import nest
|
||||
|
||||
from keras_core.trainers.data_adapters import array_data_adapter
|
||||
from keras_core.trainers.data_adapters import data_adapter_utils
|
||||
from keras_core.trainers.data_adapters import generator_data_adapter
|
||||
from keras_core.trainers.data_adapters import py_dataset_adapter
|
||||
from keras_core.trainers.data_adapters import tf_dataset_adapter
|
||||
@ -67,7 +69,8 @@ class EpochIterator:
|
||||
if steps_per_epoch:
|
||||
self._current_iterator = None
|
||||
self._insufficient_data = False
|
||||
if array_data_adapter.can_convert_arrays((x, y, sample_weight)):
|
||||
first_element = next(iter(nest.flatten(x)), None)
|
||||
if isinstance(first_element, data_adapter_utils.ARRAY_TYPES):
|
||||
self.data_adapter = array_data_adapter.ArrayDataAdapter(
|
||||
x,
|
||||
y,
|
||||
|
Loading…
Reference in New Issue
Block a user