Format code + add tf numpy example

Francois Chollet 2023-06-09 11:49:04 -07:00
parent 7180554f65
commit 1a9e850af2
6 changed files with 401 additions and 18 deletions

@@ -50,9 +50,11 @@ model = keras_core.Sequential(
 ######## Writing a torch training loop for a Keras model ########
 #################################################################

+
 def get_keras_model():
     pass

+
 model = get_keras_model()

 # Instantiate the torch optimizer
@@ -61,6 +63,7 @@ optimizer = optim.Adam(model.parameters(), lr=learning_rate)
 # Instantiate the torch loss function
 loss_fn = nn.CrossEntropyLoss()

+
 def train_step(data):
     x, y = data
     y_pred = model(x)
@@ -70,6 +73,7 @@ def train_step(data):
     optimizer.step()
     return loss

+
 # Create a TensorDataset
 dataset = torch.utils.data.TensorDataset(
     torch.from_numpy(x_train), torch.from_numpy(y_train)
@@ -122,6 +126,7 @@ train(model, train_loader, num_epochs, optimizer, loss_fn)
 ######## Using a Keras model or layer in a torch Module ########
 ################################################################

+
 class MyModel(nn.Module):
     def __init__(self):
         super().__init__()

@@ -24,6 +24,7 @@ learning_rate = 0.01
 batch_size = 128
 num_epochs = 1

+
 def get_data():
     # Load the data and split it between train and test sets
     (x_train, y_train), (x_test, y_test) = keras_core.datasets.mnist.load_data()
@@ -44,6 +45,7 @@ def get_data():
     )
     return dataset

+
 def get_model():
     # Create the Keras model
     model = keras_core.Sequential(
@@ -108,16 +110,17 @@ def train(model, train_loader, num_epochs, optimizer, loss_fn):
             )
             running_loss = 0.0

+
 def setup(current_gpu_index, num_gpu):
     # Device setup
-    os.environ['MASTER_ADDR'] = 'keras-core-torch'
-    os.environ['MASTER_PORT'] = '56492'
+    os.environ["MASTER_ADDR"] = "keras-core-torch"
+    os.environ["MASTER_PORT"] = "56492"
     device = torch.device("cuda:{}".format(current_gpu_index))
     dist.init_process_group(
-        backend='nccl',
-        init_method='env://',
+        backend="nccl",
+        init_method="env://",
         world_size=num_gpu,
-        rank=current_gpu_index
+        rank=current_gpu_index,
     )
     torch.cuda.set_device(device)
@@ -140,6 +143,7 @@ def prepare(dataset, current_gpu_index, num_gpu, batch_size):
     return train_loader

+
 def cleanup():
     # Cleanup
     dist.destroy_process_group()
@@ -167,7 +171,9 @@ def main(current_gpu_index, num_gpu):
     # Put model on device
     model = model.to(current_gpu_index)

-    ddp_model = DDP(model, device_ids=[current_gpu_index], output_device=current_gpu_index)
+    ddp_model = DDP(
+        model, device_ids=[current_gpu_index], output_device=current_gpu_index
+    )

     train(ddp_model, dataloader, num_epochs, optimizer, loss_fn)
@@ -176,7 +182,11 @@ def main(current_gpu_index, num_gpu):
     ################################################################
     torch_module = MyModel().to(current_gpu_index)
-    ddp_torch_module = DDP(torch_module, device_ids=[current_gpu_index], output_device=current_gpu_index)
+    ddp_torch_module = DDP(
+        torch_module,
+        device_ids=[current_gpu_index],
+        output_device=current_gpu_index,
+    )

     # Instantiate the torch optimizer
     optimizer = optim.Adam(torch_module.parameters(), lr=learning_rate)
@@ -192,12 +202,12 @@ def main(current_gpu_index, num_gpu):
 if __name__ == "__main__":
     # GPU parameters
     num_gpu = torch.cuda.device_count()
     print(f"Running on {num_gpu} GPUs")

     torch.multiprocessing.spawn(
         main,
-        args=(num_gpu, ),
+        args=(num_gpu,),
         nprocs=num_gpu,
         join=True,
     )

@@ -54,5 +54,9 @@ You can use the trained model hosted on [Hugging Face Hub](https://huggingface.c
 and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/bidirectional_lstm_imdb).
 """
-model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
-model.fit(x_train, y_train, batch_size=32, epochs=2, validation_data=(x_val, y_val))
+model.compile(
+    optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
+)
+model.fit(
+    x_train, y_train, batch_size=32, epochs=2, validation_data=(x_val, y_val)
+)

@@ -0,0 +1,350 @@
"""
Title: Writing Keras Models With TensorFlow NumPy
Author: [lukewood](https://lukewood.xyz)
Date created: 2021/08/28
Last modified: 2021/08/28
Description: Overview of how to use the TensorFlow NumPy API to write Keras models.
Accelerator: GPU
"""
"""
## Introduction
[NumPy](https://numpy.org/) is a hugely successful Python linear algebra library.
TensorFlow recently launched [tf_numpy](https://www.tensorflow.org/guide/tf_numpy), a
TensorFlow implementation of a large subset of the NumPy API.
Thanks to `tf_numpy`, you can write Keras layers or models in the NumPy style!
The TensorFlow NumPy API has full integration with the TensorFlow ecosystem.
Features such as automatic differentiation, TensorBoard, Keras model callbacks,
TPU distribution and model exporting are all supported.
Let's run through a few examples.
"""
"""
## Setup
TensorFlow NumPy requires TensorFlow 2.5 or later.
"""
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
import keras_core as keras
from keras_core import layers
"""
Optionally, you can call `tnp.experimental_enable_numpy_behavior()` to enable type promotion in TensorFlow.
This allows TNP to more closely follow the NumPy standard.
"""
tnp.experimental_enable_numpy_behavior()
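"""
As a quick illustration (an added aside, not part of the original example):
with NumPy behavior enabled, mixed-dtype arithmetic follows NumPy's type
promotion rules rather than raising an error, as plain TensorFlow would.
"""

# Without `experimental_enable_numpy_behavior()`, adding a Python float to an
# integer array would fail; with it, the result is promoted as in NumPy.
print((tnp.asarray([1, 2, 3]) + 1.5).dtype)  # float64, matching NumPy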
"""
To test our models, we will use the Boston housing prices regression dataset.
"""
(x_train, y_train), (x_test, y_test) = keras.datasets.boston_housing.load_data(
    path="boston_housing.npz", test_split=0.2, seed=113
)
input_dim = x_train.shape[1]
def evaluate_model(model: keras.Model):
    [loss, percent_error] = model.evaluate(x_test, y_test, verbose=0)
    print("Mean absolute percent error before training:", percent_error)
    model.fit(x_train, y_train, epochs=200, verbose=0)
    [loss, percent_error] = model.evaluate(x_test, y_test, verbose=0)
    print("Mean absolute percent error after training:", percent_error)
"""
## Subclassing keras.Model with TNP
The most flexible way to make use of the Keras API is to subclass the
[`keras.Model`](https://keras.io/api/models/model/) class. Subclassing the Model class
gives you the ability to fully customize what occurs in the training loop. This makes
subclassing Model a popular option for researchers.
In this example, we will implement a `Model` subclass that performs regression over the
Boston housing dataset using the TNP API. Note that differentiation and gradient
descent are handled automatically when using the TNP API alongside Keras.

First, let's define a simple `TNPForwardFeedRegressionNetwork` class.
"""
class TNPForwardFeedRegressionNetwork(keras.Model):
    def __init__(self, blocks=None, **kwargs):
        super().__init__(**kwargs)
        if not isinstance(blocks, list):
            raise ValueError(f"blocks must be a list, got blocks={blocks}")
        self.blocks = blocks
        self.block_weights = None
        self.biases = None

    def build(self, input_shape):
        current_shape = input_shape[1]
        self.block_weights = []
        self.biases = []
        for i, block in enumerate(self.blocks):
            self.block_weights.append(
                self.add_weight(
                    shape=(current_shape, block),
                    trainable=True,
                    name=f"block-{i}",
                    initializer="glorot_normal",
                )
            )
            self.biases.append(
                self.add_weight(
                    shape=(block,),
                    trainable=True,
                    name=f"bias-{i}",
                    initializer="zeros",
                )
            )
            current_shape = block
        self.linear_layer = self.add_weight(
            shape=(current_shape, 1),
            name="linear_projector",
            trainable=True,
            initializer="glorot_normal",
        )

    def call(self, inputs):
        activations = inputs
        for w, b in zip(self.block_weights, self.biases):
            activations = tnp.matmul(activations, w) + b
            # ReLU activation function
            activations = tnp.maximum(activations, 0.0)
        return tnp.matmul(activations, self.linear_layer)
"""
Just like with any other Keras model, we can use any supported optimizer, loss,
metric, or callback.
Let's see how the model performs!
"""
model = TNPForwardFeedRegressionNetwork(blocks=[3, 3])
model.compile(
    optimizer="adam",
    loss="mean_squared_error",
    metrics=[keras.metrics.MeanAbsolutePercentageError()],
)
evaluate_model(model)
"""
Great! Our model seems to be effectively learning to solve the problem at hand.
We can also write our own custom loss function using TNP.
"""
def tnp_mse(y_true, y_pred):
    return tnp.mean(tnp.square(y_true - y_pred), axis=0)
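"""
Note that `tnp_mse` averages over the batch axis (`axis=0`), so it returns one
value per output column, which Keras then reduces to a scalar. A quick
illustrative check (an added aside):
"""

print(tnp_mse(tnp.zeros((4, 1)), tnp.ones((4, 1))))  # -> [1.]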
keras.backend.clear_session()
model = TNPForwardFeedRegressionNetwork(blocks=[3, 3])
model.compile(
    optimizer="adam",
    loss=tnp_mse,
    metrics=[keras.metrics.MeanAbsolutePercentageError()],
)
evaluate_model(model)
"""
## Implementing a Keras Layer Based Model with TNP
If desired, TNP can also be used in a layer-oriented Keras code structure. Let's
implement the same model, but using a layered approach!
"""
def tnp_relu(x):
    return tnp.maximum(x, 0)
class TNPDense(keras.layers.Layer):
    def __init__(self, units, activation=None):
        super().__init__()
        self.units = units
        self.activation = activation

    def build(self, input_shape):
        self.w = self.add_weight(
            name="weights",
            shape=(input_shape[1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.bias = self.add_weight(
            name="bias",
            shape=(self.units,),
            initializer="zeros",
            trainable=True,
        )

    def call(self, inputs):
        outputs = tnp.matmul(inputs, self.w) + self.bias
        if self.activation:
            return self.activation(outputs)
        return outputs
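"""
A quick smoke test (an added aside): a single `TNPDense` layer applied to a
batch of ones behaves like a standard dense projection.
"""

demo_layer = TNPDense(4, activation=tnp_relu)
print(demo_layer(tnp.ones((2, input_dim), dtype="float32")).shape)  # (2, 4)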
def create_layered_tnp_model():
    return keras.Sequential(
        [
            TNPDense(3, activation=tnp_relu),
            TNPDense(3, activation=tnp_relu),
            TNPDense(1),
        ]
    )
model = create_layered_tnp_model()
model.compile(
    optimizer="adam",
    loss="mean_squared_error",
    metrics=[keras.metrics.MeanAbsolutePercentageError()],
)
model.build((None, input_dim))
model.summary()
evaluate_model(model)
"""
You can also seamlessly switch between TNP layers and native Keras layers!
"""
def create_mixed_model():
    return keras.Sequential(
        [
            TNPDense(3, activation=tnp_relu),
            # The model will have no issue using a normal Dense layer
            layers.Dense(3, activation="relu"),
            # ... or switching back to tnp layers!
            TNPDense(1),
        ]
    )
model = create_mixed_model()
model.compile(
    optimizer="adam",
    loss="mean_squared_error",
    metrics=[keras.metrics.MeanAbsolutePercentageError()],
)
model.build((None, input_dim))
model.summary()
evaluate_model(model)
"""
The Keras API offers a wide variety of layers. The ability to use them alongside NumPy
code can be a huge time saver in projects.
"""
"""
## Distribution Strategy
TensorFlow NumPy and Keras integrate with
[TensorFlow Distribution Strategies](https://www.tensorflow.org/guide/distributed_training).
This makes it simple to perform distributed training across multiple GPUs,
or even an entire TPU Pod.
"""
gpus = tf.config.list_logical_devices("GPU")
if gpus:
    strategy = tf.distribute.MirroredStrategy(gpus)
else:
    # We can fall back to a no-op CPU strategy.
    strategy = tf.distribute.get_strategy()
print("Running with strategy:", str(strategy.__class__.__name__))
with strategy.scope():
    model = create_layered_tnp_model()
    model.compile(
        optimizer="adam",
        loss="mean_squared_error",
        metrics=[keras.metrics.MeanAbsolutePercentageError()],
    )
    model.build((None, input_dim))
    model.summary()
    evaluate_model(model)
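"""
For completeness, here is a sketch (an added aside, not executed in this guide)
of how the same code would target a TPU Pod; the only change is the strategy
object used inside the scope.
"""

# Hypothetical TPU setup -- requires running on a TPU worker:
# resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
# tf.config.experimental_connect_to_cluster(resolver)
# tf.tpu.experimental.initialize_tpu_system(resolver)
# strategy = tf.distribute.TPUStrategy(resolver)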
"""
## TensorBoard Integration
One of the many benefits of using the Keras API is the ability to monitor training
through TensorBoard. Using the TensorFlow NumPy API alongside Keras allows you to easily
leverage TensorBoard.
"""
keras.backend.clear_session()
"""
To load the TensorBoard extension in a Jupyter notebook, you can run the following magic:
```
%load_ext tensorboard
```
"""
models = [
    (
        TNPForwardFeedRegressionNetwork(blocks=[3, 3]),
        "TNPForwardFeedRegressionNetwork",
    ),
    (create_layered_tnp_model(), "layered_tnp_model"),
    (create_mixed_model(), "mixed_model"),
]
for model, model_name in models:
    model.compile(
        optimizer="adam",
        loss="mean_squared_error",
        metrics=[keras.metrics.MeanAbsolutePercentageError()],
    )
    model.fit(
        x_train,
        y_train,
        epochs=200,
        verbose=0,
        callbacks=[keras.callbacks.TensorBoard(log_dir=f"logs/{model_name}")],
    )
"""
To launch TensorBoard from a Jupyter notebook, you can use the `%tensorboard` magic:
```
%tensorboard --logdir logs
```
TensorBoard lets you monitor metrics and examine the training curves.
![Tensorboard training graph](https://i.imgur.com/wsOuFnz.png)
TensorBoard also allows you to explore the computation graph used in your models.
![Tensorboard graph exploration](https://i.imgur.com/tOrezDL.png)
The ability to introspect into your models can be valuable during debugging.
"""
"""
## Conclusion
Porting existing NumPy code to Keras models using the `tensorflow_numpy` API is easy!
By integrating with Keras, you gain the ability to use existing Keras callbacks, metrics,
and optimizers, easily distribute your training, and use TensorBoard.

Migrating a more complex model, such as a ResNet, to the TensorFlow NumPy API would be a
great follow-up learning exercise. Several open-source NumPy ResNet implementations are
available online.
"""

@@ -70,7 +70,9 @@ def show_collage(examples):
     )
     for row_idx in range(num_rows):
         for col_idx in range(num_cols):
-            array = (np.array(examples[row_idx, col_idx]) * 255).astype(np.uint8)
+            array = (np.array(examples[row_idx, col_idx]) * 255).astype(
+                np.uint8
+            )
             collage.paste(
                 Image.fromarray(array), (col_idx * box_size, row_idx * box_size)
             )
@@ -127,7 +129,9 @@ class AnchorPositivePairs(keras.utils.PyDataset):
         return self.num_batchs

     def __getitem__(self, _idx):
-        x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)
+        x = np.empty(
+            (2, num_classes, height_width, height_width, 3), dtype=np.float32
+        )
         for class_idx in range(num_classes):
             examples_for_class = class_idx_to_train_idxs[class_idx]
             anchor_idx = random.choice(examples_for_class)
@@ -206,7 +210,9 @@ this model is intentionally small.
 """

 inputs = layers.Input(shape=(height_width, height_width, 3))
-x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(inputs)
+x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(
+    inputs
+)
 x = layers.Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")(x)
 x = layers.Conv2D(filters=128, kernel_size=3, strides=2, activation="relu")(x)
 x = layers.GlobalAveragePooling2D()(x)
@@ -243,7 +249,9 @@ near_neighbours_per_example = 10
 embeddings = model.predict(x_test)
 gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
-near_neighbours = np.argsort(gram_matrix.T)[:, -(near_neighbours_per_example + 1) :]
+near_neighbours = np.argsort(gram_matrix.T)[
+    :, -(near_neighbours_per_example + 1) :
+]

 """
 As a visual check of these embeddings we can build a collage of the near neighbours for 5
@@ -308,6 +316,10 @@ labels = [
     "Ship",
     "Truck",
 ]
-disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
-disp.plot(include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical")
+disp = ConfusionMatrixDisplay(
+    confusion_matrix=confusion_matrix, display_labels=labels
+)
+disp.plot(
+    include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical"
+)
 plt.show()

@@ -141,7 +141,9 @@ np.random.RandomState(seed=32).shuffle(negative_images)
 negative_dataset = tf.data.Dataset.from_tensor_slices(negative_images)
 negative_dataset = negative_dataset.shuffle(buffer_size=4096)

-dataset = tf.data.Dataset.zip((anchor_dataset, positive_dataset, negative_dataset))
+dataset = tf.data.Dataset.zip(
+    (anchor_dataset, positive_dataset, negative_dataset)
+)
 dataset = dataset.shuffle(buffer_size=1024)
 dataset = dataset.map(preprocess_triplets)