Update Perceiver image classification example (#369)

* added perceiver image classifier

* removed tfa dependency and replaced tfa LAMB optimizer and Input names

* modify comment in tutorial
This commit is contained in:
divyasreepat 2023-06-16 16:13:37 -07:00 committed by Francois Chollet
parent 8d453471d0
commit 6b0cb5598a

@ -45,7 +45,6 @@ import numpy as np
import tensorflow as tf
import keras_core as keras
from keras_core import layers
import tensorflow_addons as tfa
"""
## Prepare the data
@ -206,9 +205,9 @@ def create_cross_attention_module(
):
inputs = {
# Recieve the latent array as an input of shape [1, latent_dim, projection_dim].
"latent_array": layers.Input(shape=(latent_dim, projection_dim)),
"latent_array": layers.Input(shape=(latent_dim, projection_dim), name="latent_array"),
# Recieve the data_array (encoded image) as an input of shape [batch_size, data_dim, projection_dim].
"data_array": layers.Input(shape=(data_dim, projection_dim)),
"data_array": layers.Input(shape=(data_dim, projection_dim), name="data_array"),
}
# Apply layer norm to the inputs
@ -399,10 +398,10 @@ class Perceiver(keras.Model):
def run_experiment(model):
# Create LAMB optimizer with weight decay.
optimizer = tfa.optimizers.LAMB(
# Create Adam optimizer with weight decay.
optimizer = keras.optimizers.Adam(
learning_rate=learning_rate,
weight_decay_rate=weight_decay,
weight_decay=weight_decay,
)
# Compile the model.