updrage to keras 3.0 for metric learning example (#18701)
* updrage to keras 3.0 for metric learning example * address comments and verify on successful run * update keras backend * update keras backend * update import os * moved changes to tensorflow folder
This commit is contained in:
parent
11cbb29b30
commit
2f825fc7de
@ -23,7 +23,11 @@ For a more detailed overview of metric learning see:
|
||||
|
||||
"""
|
||||
## Setup
|
||||
|
||||
Set Keras backend to tensorflow.
|
||||
"""
|
||||
import os
|
||||
os.environ["KERAS_BACKEND"] = "tensorflow"
|
||||
|
||||
import random
|
||||
import matplotlib.pyplot as plt
|
||||
@ -32,9 +36,8 @@ import tensorflow as tf
|
||||
from collections import defaultdict
|
||||
from PIL import Image
|
||||
from sklearn.metrics import ConfusionMatrixDisplay
|
||||
import keras as keras
|
||||
import keras
|
||||
from keras import layers
|
||||
from keras.datasets import cifar10
|
||||
|
||||
"""
|
||||
## Dataset
|
||||
@ -43,6 +46,9 @@ For this example we will be using the
|
||||
[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.
|
||||
"""
|
||||
|
||||
from keras.datasets import cifar10
|
||||
|
||||
|
||||
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
|
||||
|
||||
x_train = x_train.astype("float32") / 255.0
|
||||
@ -70,9 +76,7 @@ def show_collage(examples):
|
||||
)
|
||||
for row_idx in range(num_rows):
|
||||
for col_idx in range(num_cols):
|
||||
array = (np.array(examples[row_idx, col_idx]) * 255).astype(
|
||||
np.uint8
|
||||
)
|
||||
array = (np.array(examples[row_idx, col_idx]) * 255).astype(np.uint8)
|
||||
collage.paste(
|
||||
Image.fromarray(array), (col_idx * box_size, row_idx * box_size)
|
||||
)
|
||||
@ -120,18 +124,16 @@ CIFAR-10 this is 10.
|
||||
num_classes = 10
|
||||
|
||||
|
||||
class AnchorPositivePairs(keras.utils.PyDataset):
|
||||
def __init__(self, num_batchs):
|
||||
class AnchorPositivePairs(keras.utils.Sequence):
|
||||
def __init__(self, num_batches):
|
||||
super().__init__()
|
||||
self.num_batchs = num_batchs
|
||||
self.num_batches = num_batches
|
||||
|
||||
def __len__(self):
|
||||
return self.num_batchs
|
||||
return self.num_batches
|
||||
|
||||
def __getitem__(self, _idx):
|
||||
x = np.empty(
|
||||
(2, num_classes, height_width, height_width, 3), dtype=np.float32
|
||||
)
|
||||
x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)
|
||||
for class_idx in range(num_classes):
|
||||
examples_for_class = class_idx_to_train_idxs[class_idx]
|
||||
anchor_idx = random.choice(examples_for_class)
|
||||
@ -148,7 +150,7 @@ We can visualise a batch in another collage. The top row shows randomly chosen a
|
||||
from the 10 classes, the bottom row shows the corresponding 10 positives.
|
||||
"""
|
||||
|
||||
examples = next(iter(AnchorPositivePairs(num_batchs=1)))
|
||||
examples = next(iter(AnchorPositivePairs(num_batches=1)))
|
||||
|
||||
show_collage(examples)
|
||||
|
||||
@ -174,7 +176,7 @@ class EmbeddingModel(keras.Model):
|
||||
|
||||
# Calculate cosine similarity between anchors and positives. As they have
|
||||
# been normalised this is just the pair wise dot products.
|
||||
similarities = tf.einsum(
|
||||
similarities = keras.ops.einsum(
|
||||
"ae,pe->ap", anchor_embeddings, positive_embeddings
|
||||
)
|
||||
|
||||
@ -188,16 +190,15 @@ class EmbeddingModel(keras.Model):
|
||||
# want the main diagonal values, which correspond to the anchor/positive
|
||||
# pairs, to be high. This loss will move embeddings for the
|
||||
# anchor/positive pairs together and move all other pairs apart.
|
||||
sparse_labels = tf.range(num_classes)
|
||||
loss = self.compute_loss(y=sparse_labels, y_pred=similarities)
|
||||
sparse_labels = keras.ops.arange(num_classes)
|
||||
loss = self.compiled_loss(sparse_labels, similarities)
|
||||
|
||||
# Calculate gradients and apply via optimizer.
|
||||
gradients = tape.gradient(loss, self.trainable_variables)
|
||||
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
|
||||
|
||||
# Update and return metrics (specifically the one for the loss value).
|
||||
for metric in self.metrics:
|
||||
metric.update_state(sparse_labels, similarities)
|
||||
self.compiled_metrics.update_state(sparse_labels, similarities)
|
||||
return {m.name: m.result() for m in self.metrics}
|
||||
|
||||
|
||||
@ -210,14 +211,12 @@ this model is intentionally small.
|
||||
"""
|
||||
|
||||
inputs = layers.Input(shape=(height_width, height_width, 3))
|
||||
x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(
|
||||
inputs
|
||||
)
|
||||
x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(inputs)
|
||||
x = layers.Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")(x)
|
||||
x = layers.Conv2D(filters=128, kernel_size=3, strides=2, activation="relu")(x)
|
||||
x = layers.GlobalAveragePooling2D()(x)
|
||||
embeddings = layers.Dense(units=8, activation=None)(x)
|
||||
embeddings = keras.layers.UnitNormalization()(embeddings)
|
||||
embeddings = layers.UnitNormalization()(embeddings)
|
||||
|
||||
model = EmbeddingModel(inputs, embeddings)
|
||||
|
||||
@ -230,7 +229,7 @@ model.compile(
|
||||
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
|
||||
)
|
||||
|
||||
history = model.fit(AnchorPositivePairs(num_batchs=1000), epochs=20)
|
||||
history = model.fit(AnchorPositivePairs(num_batches=1000), epochs=20)
|
||||
|
||||
plt.plot(history.history["loss"])
|
||||
plt.show()
|
||||
@ -249,9 +248,7 @@ near_neighbours_per_example = 10
|
||||
|
||||
embeddings = model.predict(x_test)
|
||||
gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
|
||||
near_neighbours = np.argsort(gram_matrix.T)[
|
||||
:, -(near_neighbours_per_example + 1) :
|
||||
]
|
||||
near_neighbours = np.argsort(gram_matrix.T)[:, -(near_neighbours_per_example + 1) :]
|
||||
|
||||
"""
|
||||
As a visual check of these embeddings we can build a collage of the near neighbours for 5
|
||||
@ -316,10 +313,6 @@ labels = [
|
||||
"Ship",
|
||||
"Truck",
|
||||
]
|
||||
disp = ConfusionMatrixDisplay(
|
||||
confusion_matrix=confusion_matrix, display_labels=labels
|
||||
)
|
||||
disp.plot(
|
||||
include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical"
|
||||
)
|
||||
plt.show()
|
||||
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
|
||||
disp.plot(include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical")
|
||||
plt.show()
|
Loading…
Reference in New Issue
Block a user