upgrade to keras 3.0 for metric learning example (#18701)

* upgrade to keras 3.0 for metric learning example

* address comments and verify a successful run

* update keras backend

* update keras backend

* update import os

* moved changes to tensorflow folder
madhusshivakumar 2023-11-02 12:43:49 -07:00 committed by GitHub
parent 11cbb29b30
commit 2f825fc7de

@@ -23,7 +23,11 @@ For a more detailed overview of metric learning see:
 """
 ## Setup
+Set Keras backend to tensorflow.
 """

+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"

 import random
 import matplotlib.pyplot as plt
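
Keras 3 reads the `KERAS_BACKEND` environment variable once, at first import, which is why the hunk above sets it before any `import keras`. A minimal standalone sketch of the pattern (not part of the commit):

    import os

    # Must be set before keras is imported anywhere in the process;
    # changing it afterwards does not switch the already-loaded backend.
    os.environ["KERAS_BACKEND"] = "tensorflow"

    import keras

    print(keras.backend.backend())  # "tensorflow"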
@@ -32,9 +36,8 @@ import tensorflow as tf
 from collections import defaultdict
 from PIL import Image
 from sklearn.metrics import ConfusionMatrixDisplay
-import keras as keras
+import keras
 from keras import layers
-from keras.datasets import cifar10

 """
 ## Dataset
@@ -43,6 +46,9 @@ For this example we will be using the
 [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.
 """
+
+from keras.datasets import cifar10
+
 (x_train, y_train), (x_test, y_test) = cifar10.load_data()

 x_train = x_train.astype("float32") / 255.0
@@ -70,9 +76,7 @@ def show_collage(examples):
     )
     for row_idx in range(num_rows):
         for col_idx in range(num_cols):
-            array = (np.array(examples[row_idx, col_idx]) * 255).astype(
-                np.uint8
-            )
+            array = (np.array(examples[row_idx, col_idx]) * 255).astype(np.uint8)
             collage.paste(
                 Image.fromarray(array), (col_idx * box_size, row_idx * box_size)
             )
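
For context, the `astype` line above lives inside a PIL collage helper: float images in [0, 1] are scaled back to uint8 and pasted into a grid. A stripped-down sketch of that pattern, with illustrative sizes (CIFAR-10 images are 32x32):

    import numpy as np
    from PIL import Image

    box_size, num_rows, num_cols = 32, 2, 10
    collage = Image.new("RGB", (num_cols * box_size, num_rows * box_size))
    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            # Scale [0, 1] floats to uint8 pixels before converting to a PIL image.
            array = (np.random.rand(box_size, box_size, 3) * 255).astype(np.uint8)
            collage.paste(Image.fromarray(array), (col_idx * box_size, row_idx * box_size))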
@@ -120,18 +124,16 @@ CIFAR-10 this is 10.
 num_classes = 10


-class AnchorPositivePairs(keras.utils.PyDataset):
-    def __init__(self, num_batchs):
+class AnchorPositivePairs(keras.utils.Sequence):
+    def __init__(self, num_batches):
         super().__init__()
-        self.num_batchs = num_batchs
+        self.num_batches = num_batches

     def __len__(self):
-        return self.num_batchs
+        return self.num_batches

     def __getitem__(self, _idx):
-        x = np.empty(
-            (2, num_classes, height_width, height_width, 3), dtype=np.float32
-        )
+        x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)
         for class_idx in range(num_classes):
             examples_for_class = class_idx_to_train_idxs[class_idx]
             anchor_idx = random.choice(examples_for_class)
@@ -148,7 +150,7 @@ We can visualise a batch in another collage. The top row shows randomly chosen anchors
 from the 10 classes, the bottom row shows the corresponding 10 positives.
 """

-examples = next(iter(AnchorPositivePairs(num_batchs=1)))
+examples = next(iter(AnchorPositivePairs(num_batches=1)))

 show_collage(examples)
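
`AnchorPositivePairs` follows the `keras.utils.Sequence` protocol: `__len__` reports how many batches `model.fit()` draws per epoch, and `__getitem__` builds one batch, which is why `next(iter(...))` above yields a single batch. A toy sketch of that contract (random data, not the example's class-based sampler):

    import numpy as np
    import keras


    class ToyPairs(keras.utils.Sequence):
        def __init__(self, num_batches):
            super().__init__()
            self.num_batches = num_batches

        def __len__(self):
            return self.num_batches  # batches per epoch

        def __getitem__(self, _idx):
            # One batch: anchors in row 0, positives in row 1.
            return np.random.rand(2, 10, 32, 32, 3).astype(np.float32)


    print(next(iter(ToyPairs(num_batches=1))).shape)  # (2, 10, 32, 32, 3)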
@@ -174,7 +176,7 @@ class EmbeddingModel(keras.Model):

             # Calculate cosine similarity between anchors and positives. As they have
             # been normalised this is just the pair wise dot products.
-            similarities = tf.einsum(
+            similarities = keras.ops.einsum(
                 "ae,pe->ap", anchor_embeddings, positive_embeddings
             )
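
`keras.ops.einsum` is the backend-agnostic counterpart of `tf.einsum`. With unit-normalized embeddings, `"ae,pe->ap"` (a = anchor, p = positive, e = embedding dim) is just `anchors @ positives.T`: every pairwise cosine similarity in one call. A quick equivalence check, assuming Keras 3 is installed:

    import numpy as np
    import keras

    anchors = np.random.rand(10, 8).astype("float32")
    positives = np.random.rand(10, 8).astype("float32")

    # Row i, column j holds the dot product of anchor i with positive j.
    sims = keras.ops.einsum("ae,pe->ap", anchors, positives)
    assert np.allclose(keras.ops.convert_to_numpy(sims), anchors @ positives.T, atol=1e-5)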
@@ -188,16 +190,15 @@ class EmbeddingModel(keras.Model):
             # want the main diagonal values, which correspond to the anchor/positive
             # pairs, to be high. This loss will move embeddings for the
             # anchor/positive pairs together and move all other pairs apart.
-            sparse_labels = tf.range(num_classes)
-            loss = self.compute_loss(y=sparse_labels, y_pred=similarities)
+            sparse_labels = keras.ops.arange(num_classes)
+            loss = self.compiled_loss(sparse_labels, similarities)

         # Calculate gradients and apply via optimizer.
         gradients = tape.gradient(loss, self.trainable_variables)
         self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

         # Update and return metrics (specifically the one for the loss value).
-        for metric in self.metrics:
-            metric.update_state(sparse_labels, similarities)
+        self.compiled_metrics.update_state(sparse_labels, similarities)
         return {m.name: m.result() for m in self.metrics}
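
The `arange` labels turn the similarity matrix into a classification problem: row i should "predict" column i, so sparse categorical cross-entropy over the logits pushes the diagonal (anchor/positive) similarities up relative to every other pair. A small standalone sketch of the trick:

    import numpy as np
    import keras

    num_classes = 4
    # Toy similarity logits: rows are anchors, columns are positives.
    similarities = np.random.rand(num_classes, num_classes).astype("float32")
    sparse_labels = np.arange(num_classes)  # correct "class" for row i is column i

    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    print(float(loss_fn(sparse_labels, similarities)))  # drops as the diagonal grows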
@@ -210,14 +211,12 @@ this model is intentionally small.
 """

 inputs = layers.Input(shape=(height_width, height_width, 3))
-x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(
-    inputs
-)
+x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(inputs)
 x = layers.Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")(x)
 x = layers.Conv2D(filters=128, kernel_size=3, strides=2, activation="relu")(x)
 x = layers.GlobalAveragePooling2D()(x)
 embeddings = layers.Dense(units=8, activation=None)(x)
-embeddings = keras.layers.UnitNormalization()(embeddings)
+embeddings = layers.UnitNormalization()(embeddings)

 model = EmbeddingModel(inputs, embeddings)
@@ -230,7 +229,7 @@ model.compile(
     loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
 )

-history = model.fit(AnchorPositivePairs(num_batchs=1000), epochs=20)
+history = model.fit(AnchorPositivePairs(num_batches=1000), epochs=20)

 plt.plot(history.history["loss"])
 plt.show()
@@ -249,9 +248,7 @@ near_neighbours_per_example = 10

 embeddings = model.predict(x_test)
 gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
-near_neighbours = np.argsort(gram_matrix.T)[
-    :, -(near_neighbours_per_example + 1) :
-]
+near_neighbours = np.argsort(gram_matrix.T)[:, -(near_neighbours_per_example + 1) :]

 """
 As a visual check of these embeddings we can build a collage of the near neighbours for 5
 random examples.
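
On the `argsort` line: `np.argsort` sorts ascending, so taking the last `near_neighbours_per_example + 1` columns picks the highest similarities per row, and the `+ 1` is there because each unit-normalized example matches itself with similarity 1 and lands in the final slot. A toy check:

    import numpy as np

    embeddings = np.random.rand(6, 8)
    embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)  # unit vectors
    gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)

    k = 2  # neighbours wanted per example
    near_neighbours = np.argsort(gram_matrix.T)[:, -(k + 1) :]
    # Last column is each example itself (self-similarity is maximal).
    assert all(near_neighbours[i, -1] == i for i in range(6))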
@@ -316,10 +313,6 @@ labels = [
     "Ship",
     "Truck",
 ]
-disp = ConfusionMatrixDisplay(
-    confusion_matrix=confusion_matrix, display_labels=labels
-)
-disp.plot(
-    include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical"
-)
-plt.show()
+disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
+disp.plot(include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical")
+plt.show()
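
`ConfusionMatrixDisplay` here only renders a precomputed matrix; nothing is recomputed at plot time. A minimal standalone sketch of the same call pattern, with toy counts:

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.metrics import ConfusionMatrixDisplay

    cm = np.array([[5, 1], [2, 7]])  # toy 2-class confusion counts
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["cat", "dog"])
    disp.plot(include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical")
    plt.show()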