From 6b0cb5598abbd1df1b82430d1fc1aabc9ea60d41 Mon Sep 17 00:00:00 2001 From: divyasreepat Date: Fri, 16 Jun 2023 16:13:37 -0700 Subject: [PATCH] Update Perceiver image classification example (#369) * added perceiver image classifier * removed tfa dependency and replaced tfa LAMB optimizer and Input names * modify comment in tutorial --- .../vision/perceiver_image_classification.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/keras_io/tensorflow/vision/perceiver_image_classification.py b/examples/keras_io/tensorflow/vision/perceiver_image_classification.py index a33f9c945..a6e223436 100644 --- a/examples/keras_io/tensorflow/vision/perceiver_image_classification.py +++ b/examples/keras_io/tensorflow/vision/perceiver_image_classification.py @@ -45,7 +45,6 @@ import numpy as np import tensorflow as tf import keras_core as keras from keras_core import layers -import tensorflow_addons as tfa """ ## Prepare the data @@ -206,9 +205,9 @@ def create_cross_attention_module( ): inputs = { # Recieve the latent array as an input of shape [1, latent_dim, projection_dim]. - "latent_array": layers.Input(shape=(latent_dim, projection_dim)), + "latent_array": layers.Input(shape=(latent_dim, projection_dim), name="latent_array"), # Recieve the data_array (encoded image) as an input of shape [batch_size, data_dim, projection_dim]. - "data_array": layers.Input(shape=(data_dim, projection_dim)), + "data_array": layers.Input(shape=(data_dim, projection_dim), name="data_array"), } # Apply layer norm to the inputs @@ -399,10 +398,10 @@ class Perceiver(keras.Model): def run_experiment(model): - # Create LAMB optimizer with weight decay. - optimizer = tfa.optimizers.LAMB( + # Create Adam optimizer with weight decay. + optimizer = keras.optimizers.Adam( learning_rate=learning_rate, - weight_decay_rate=weight_decay, + weight_decay=weight_decay, ) # Compile the model.