upgrade to keras 3.0 for metric learning example (#18701)

* upgrade to keras 3.0 for metric learning example
* address comments and verify on successful run
* update keras backend
* update keras backend
* update import os
* moved changes to tensorflow folder
commit 2f825fc7de (parent 11cbb29b30)
@@ -23,7 +23,11 @@ For a more detailed overview of metric learning see:
 
 """
 ## Setup
+
+Set Keras backend to tensorflow.
 """
+import os
+os.environ["KERAS_BACKEND"] = "tensorflow"
 
 import random
 import matplotlib.pyplot as plt
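Keras 3 reads KERAS_BACKEND only once, when keras is first imported, which is why the diff sets it before any Keras import. A minimal sketch of the pattern (the value could equally be "jax" or "torch"):

import os

# Must be set before the first `import keras`; changing it later has no effect.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras

print(keras.backend.backend())  # -> "tensorflow"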
@@ -32,9 +36,8 @@ import tensorflow as tf
 from collections import defaultdict
 from PIL import Image
 from sklearn.metrics import ConfusionMatrixDisplay
-import keras as keras
+import keras
 from keras import layers
-from keras.datasets import cifar10
 
 """
 ## Dataset
@@ -43,6 +46,9 @@ For this example we will be using the
 [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.
 """
 
+from keras.datasets import cifar10
+
+
 (x_train, y_train), (x_test, y_test) = cifar10.load_data()
 
 x_train = x_train.astype("float32") / 255.0
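For reference, a quick sanity check of what cifar10.load_data() returns after the / 255.0 scaling (shapes and dtypes as documented for CIFAR-10):

from keras.datasets import cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype("float32") / 255.0

print(x_train.shape, x_train.dtype)  # (50000, 32, 32, 3) float32
print(y_train.shape)                 # (50000, 1), integer labels 0-9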
@@ -70,9 +76,7 @@ def show_collage(examples):
     )
     for row_idx in range(num_rows):
         for col_idx in range(num_cols):
-            array = (np.array(examples[row_idx, col_idx]) * 255).astype(
-                np.uint8
-            )
+            array = (np.array(examples[row_idx, col_idx]) * 255).astype(np.uint8)
             collage.paste(
                 Image.fromarray(array), (col_idx * box_size, row_idx * box_size)
             )
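This hunk only changes formatting; the logic, converting a float image in [0, 1] back to uint8 so PIL can render it, is unchanged. A standalone sketch with hypothetical toy data:

import numpy as np
from PIL import Image

# A random 32x32 float image in [0, 1], like one CIFAR-10 example after scaling.
float_img = np.random.rand(32, 32, 3)
array = (float_img * 255).astype(np.uint8)  # back to 0-255 for PIL
Image.fromarray(array).save("example.png")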
@@ -120,18 +124,16 @@ CIFAR-10 this is 10.
 num_classes = 10
 
 
-class AnchorPositivePairs(keras.utils.Sequence):
-    def __init__(self, num_batchs):
+class AnchorPositivePairs(keras.utils.PyDataset):
+    def __init__(self, num_batches):
         super().__init__()
-        self.num_batchs = num_batchs
+        self.num_batches = num_batches
 
     def __len__(self):
-        return self.num_batchs
+        return self.num_batches
 
     def __getitem__(self, _idx):
-        x = np.empty(
-            (2, num_classes, height_width, height_width, 3), dtype=np.float32
-        )
+        x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)
         for class_idx in range(num_classes):
             examples_for_class = class_idx_to_train_idxs[class_idx]
             anchor_idx = random.choice(examples_for_class)
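In Keras 3, keras.utils.PyDataset supersedes keras.utils.Sequence; it expects super().__init__() to be called (it accepts options such as workers) and uses the same __len__/__getitem__ protocol. A minimal sketch of the pattern, with a hypothetical RandomPairs class:

import numpy as np
import keras


class RandomPairs(keras.utils.PyDataset):
    """Yields `num_batches` batches of random (x, y) pairs per epoch."""

    def __init__(self, num_batches, **kwargs):
        super().__init__(**kwargs)  # PyDataset handles e.g. `workers`
        self.num_batches = num_batches

    def __len__(self):
        return self.num_batches

    def __getitem__(self, _idx):
        return np.random.rand(8, 4).astype("float32"), np.zeros(8, "int32")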
@@ -148,7 +150,7 @@ We can visualise a batch in another collage. The top row shows randomly chosen a
 from the 10 classes, the bottom row shows the corresponding 10 positives.
 """
 
-examples = next(iter(AnchorPositivePairs(num_batchs=1)))
+examples = next(iter(AnchorPositivePairs(num_batches=1)))
 
 show_collage(examples)
 
@@ -174,7 +176,7 @@ class EmbeddingModel(keras.Model):
 
             # Calculate cosine similarity between anchors and positives. As they have
             # been normalised this is just the pair wise dot products.
-            similarities = tf.einsum(
+            similarities = keras.ops.einsum(
                 "ae,pe->ap", anchor_embeddings, positive_embeddings
             )
 
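keras.ops.einsum mirrors np.einsum but runs on whichever backend is active. The "ae,pe->ap" contraction yields every anchor-positive dot product; with unit-normalised embeddings these are cosine similarities. A small check with hypothetical toy inputs:

import numpy as np
import keras

anchors = np.eye(3, 4, dtype="float32")    # 3 unit vectors of dimension 4
positives = np.eye(3, 4, dtype="float32")  # identical, so cosine sim = identity

similarities = keras.ops.einsum("ae,pe->ap", anchors, positives)
print(keras.ops.convert_to_numpy(similarities))  # 3x3 identity matrix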
@@ -188,15 +190,16 @@ class EmbeddingModel(keras.Model):
             # want the main diagonal values, which correspond to the anchor/positive
             # pairs, to be high. This loss will move embeddings for the
             # anchor/positive pairs together and move all other pairs apart.
-            sparse_labels = tf.range(num_classes)
-            loss = self.compiled_loss(sparse_labels, similarities)
+            sparse_labels = keras.ops.arange(num_classes)
+            loss = self.compute_loss(y=sparse_labels, y_pred=similarities)
 
         # Calculate gradients and apply via optimizer.
         gradients = tape.gradient(loss, self.trainable_variables)
         self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
 
         # Update and return metrics (specifically the one for the loss value).
-        self.compiled_metrics.update_state(sparse_labels, similarities)
+        for metric in self.metrics:
+            metric.update_state(sparse_labels, similarities)
         return {m.name: m.result() for m in self.metrics}
 
 
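Keras 3 removes self.compiled_loss and self.compiled_metrics; a custom train_step now calls self.compute_loss(y=..., y_pred=...) and updates self.metrics explicitly. A stripped-down sketch of that shape (hypothetical model class; TensorFlow backend assumed because of tf.GradientTape; the name check on the loss tracker follows the Keras 3 docs and is a refinement over the hunk above, which updates every metric identically):

import tensorflow as tf
import keras


class CustomStepModel(keras.Model):
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Uses the loss passed to compile() and tracks it automatically.
            loss = self.compute_loss(y=y, y_pred=y_pred)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}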
@@ -210,14 +211,12 @@ this model is intentionally small.
 """
 
 inputs = layers.Input(shape=(height_width, height_width, 3))
-x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(
-    inputs
-)
+x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(inputs)
 x = layers.Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")(x)
 x = layers.Conv2D(filters=128, kernel_size=3, strides=2, activation="relu")(x)
 x = layers.GlobalAveragePooling2D()(x)
 embeddings = layers.Dense(units=8, activation=None)(x)
-embeddings = keras.layers.UnitNormalization()(embeddings)
+embeddings = layers.UnitNormalization()(embeddings)
 
 model = EmbeddingModel(inputs, embeddings)
 
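UnitNormalization L2-normalises each row of its input, so downstream dot products between embeddings are cosine similarities. A quick check with a hypothetical toy vector:

import numpy as np
import keras

layer = keras.layers.UnitNormalization()  # L2-normalises along the last axis
out = layer(np.array([[3.0, 4.0]]))       # norm 5 -> [0.6, 0.8]
print(keras.ops.convert_to_numpy(out))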
@@ -230,7 +229,7 @@ model.compile(
     loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
 )
 
-history = model.fit(AnchorPositivePairs(num_batchs=1000), epochs=20)
+history = model.fit(AnchorPositivePairs(num_batches=1000), epochs=20)
 
 plt.plot(history.history["loss"])
 plt.show()
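The compile step's SparseCategoricalCrossentropy(from_logits=True) treats each row of the similarity matrix as logits over the num_classes classes, with label i for row i (the diagonal anchor/positive pair). A toy illustration with hypothetical values:

import numpy as np
import keras

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# 3x3 "similarities": high on the diagonal = each anchor matches its positive.
similarities = np.array([[5.0, 0.0, 0.0],
                         [0.0, 5.0, 0.0],
                         [0.0, 0.0, 5.0]])
labels = np.arange(3)  # row i should predict class i
print(float(loss_fn(labels, similarities)))  # small loss, ~0.013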
@@ -249,9 +248,7 @@ near_neighbours_per_example = 10
 
 embeddings = model.predict(x_test)
 gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
-near_neighbours = np.argsort(gram_matrix.T)[
-    :, -(near_neighbours_per_example + 1) :
-]
+near_neighbours = np.argsort(gram_matrix.T)[:, -(near_neighbours_per_example + 1) :]
 
 """
 As a visual check of these embeddings we can build a collage of the near neighbours for 5
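Because every embedding has similarity 1.0 with itself, the highest-scoring "neighbour" in each row of the gram matrix is the example itself; the slice therefore keeps near_neighbours_per_example + 1 entries. A small sketch of the indexing with a hypothetical 3x3 gram matrix:

import numpy as np

gram = np.array([[1.0, 0.9, 0.1],
                 [0.9, 1.0, 0.2],
                 [0.1, 0.2, 1.0]])
k = 1
# Last column is each row's own index; the one before it is the nearest neighbour.
near = np.argsort(gram.T)[:, -(k + 1):]
print(near)  # [[1 0], [0 1], [1 2]]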
@@ -316,10 +313,6 @@ labels = [
     "Ship",
     "Truck",
 ]
-disp = ConfusionMatrixDisplay(
-    confusion_matrix=confusion_matrix, display_labels=labels
-)
-disp.plot(
-    include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical"
-)
+disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
+disp.plot(include_values=True, cmap="viridis", xticks_rotation="vertical")
+plt.show()
 plt.show()
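sklearn's ConfusionMatrixDisplay plots a precomputed matrix directly, no fitted classifier needed. A minimal standalone sketch with a hypothetical 2-class matrix:

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay

cm = np.array([[8, 2], [1, 9]])  # rows = true class, columns = predicted
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["cat", "dog"])
disp.plot(include_values=True, cmap="viridis", xticks_rotation="vertical")
plt.show()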