Update Perceiver image classification example (#369)
* added perceiver image classifier * removed tfa dependency and replaced tfa LAMB optimizer and Input names * modify comment in tutorial
This commit is contained in:
parent
8d453471d0
commit
6b0cb5598a
@ -45,7 +45,6 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
import keras_core as keras
|
||||
from keras_core import layers
|
||||
import tensorflow_addons as tfa
|
||||
|
||||
"""
|
||||
## Prepare the data
|
||||
@ -206,9 +205,9 @@ def create_cross_attention_module(
|
||||
):
|
||||
inputs = {
|
||||
# Recieve the latent array as an input of shape [1, latent_dim, projection_dim].
|
||||
"latent_array": layers.Input(shape=(latent_dim, projection_dim)),
|
||||
"latent_array": layers.Input(shape=(latent_dim, projection_dim), name="latent_array"),
|
||||
# Recieve the data_array (encoded image) as an input of shape [batch_size, data_dim, projection_dim].
|
||||
"data_array": layers.Input(shape=(data_dim, projection_dim)),
|
||||
"data_array": layers.Input(shape=(data_dim, projection_dim), name="data_array"),
|
||||
}
|
||||
|
||||
# Apply layer norm to the inputs
|
||||
@ -399,10 +398,10 @@ class Perceiver(keras.Model):
|
||||
|
||||
|
||||
def run_experiment(model):
|
||||
# Create LAMB optimizer with weight decay.
|
||||
optimizer = tfa.optimizers.LAMB(
|
||||
# Create Adam optimizer with weight decay.
|
||||
optimizer = keras.optimizers.Adam(
|
||||
learning_rate=learning_rate,
|
||||
weight_decay_rate=weight_decay,
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
|
||||
# Compile the model.
|
||||
|
Loading…
Reference in New Issue
Block a user