Switch to using take for embedding layers (#229)

If benchmarking shows that we are still missing compilation edge cases for tf,
we can consider using `nn.embedding_lookup`, potentially baked into the take op
for the tf backend.
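
As a rough sketch of that fallback idea (not part of this commit), a tf backend `take` could route gathers along axis 0 through `tf.nn.embedding_lookup` and otherwise defer to the generic numpy-style take; the axis-0 dispatch condition here is an assumption, not the committed implementation:

```python
import tensorflow as tf


def take(x, indices, axis=None):
    # Hypothetical sketch: embedding-style gathers along axis 0 go through
    # tf.nn.embedding_lookup, which lowers to a single gather; anything else
    # falls back to the generic numpy-style take.
    x = tf.convert_to_tensor(x)
    indices = tf.convert_to_tensor(indices)
    if axis == 0:
        return tf.nn.embedding_lookup(x, indices)
    return tf.experimental.numpy.take(x, indices, axis=axis)
```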
Matt Watson 2023-05-30 16:32:46 -07:00 committed by Francois Chollet
parent 7f600f067c
commit c0fdc93daa
2 changed files with 4 additions and 3 deletions

@@ -457,6 +457,8 @@ def swapaxes(x, axis1, axis2):
 def take(x, indices, axis=None):
+    x = convert_to_tensor(x)
+    indices = convert_to_tensor(indices)
     return jnp.take(x, indices, axis=axis)

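For reference, the take-based lookup produces the same result as the old one-hot + matmul path; a minimal equivalence check with JAX, using toy shapes that are not from the commit:

```python
import jax
import jax.numpy as jnp

embeddings = jnp.arange(12.0).reshape(4, 3)  # (input_dim=4, output_dim=3)
ids = jnp.array([[1, 3], [0, 2]])            # integer token ids, any batch shape

# Old path: one-hot encode the ids, then matmul with the embedding matrix.
old = jax.nn.one_hot(ids, 4) @ embeddings    # shape (2, 2, 3)

# New path: a single gather along axis 0 of the embedding matrix.
new = jnp.take(embeddings, ids, axis=0)      # shape (2, 2, 3)

assert jnp.allclose(old, new)
```
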
@@ -90,9 +90,8 @@ class Embedding(Layer):
     def call(self, inputs):
         if inputs.dtype != "int32" and inputs.dtype != "int64":
             inputs = ops.cast(inputs, "int32")
-        one_hot_data = ops.one_hot(inputs, num_classes=self.input_dim)
-        out = ops.matmul(one_hot_data, self.embeddings)
-        return ops.cast(out, dtype=self.compute_dtype)
+        outputs = ops.take(self.embeddings, inputs, axis=0)
+        return ops.cast(outputs, dtype=self.compute_dtype)
 
     def compute_mask(self, inputs, mask=None):
         if not self.mask_zero: