Switch to using take for embedding layers (#229)
If benchmarking reveals that we are still missing compilation edge cases for TF, we can consider an `nn.embedding_lookup` path, potentially baked into the `take` op for the TF backend.
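For reference, a minimal sketch of what such a TF-backend `take` could look like, assuming a hypothetical standalone function; it uses only public TF APIs and is not part of this commit:

import tensorflow as tf

def take(x, indices, axis=None):
    # Hypothetical TF-backend take; the axis=0 special case is an assumption.
    x = tf.convert_to_tensor(x)
    indices = tf.convert_to_tensor(indices)
    if axis == 0:
        # embedding_lookup gathers rows and has a well-exercised
        # compilation path, which may avoid tf.gather edge cases.
        return tf.nn.embedding_lookup(x, indices)
    # Fall back to the NumPy-compatible take for every other axis.
    return tf.experimental.numpy.take(x, indices, axis=axis)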
parent 7f600f067c
commit c0fdc93daa
@@ -457,6 +457,8 @@ def swapaxes(x, axis1, axis2):
 def take(x, indices, axis=None):
+    x = convert_to_tensor(x)
+    indices = convert_to_tensor(indices)
     return jnp.take(x, indices, axis=axis)
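For context, `jnp.take` with `axis=0` gathers rows of the first argument by index, which is exactly an embedding lookup. A tiny illustration with made-up values:

import jax.numpy as jnp

table = jnp.arange(12.0).reshape(4, 3)  # 4 embeddings of width 3
ids = jnp.array([2, 0, 2])
rows = jnp.take(table, ids, axis=0)     # rows 2, 0, 2 -> shape (3, 3)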
@@ -90,9 +90,8 @@ class Embedding(Layer):
     def call(self, inputs):
         if inputs.dtype != "int32" and inputs.dtype != "int64":
             inputs = ops.cast(inputs, "int32")
-        one_hot_data = ops.one_hot(inputs, num_classes=self.input_dim)
-        out = ops.matmul(one_hot_data, self.embeddings)
-        return ops.cast(out, dtype=self.compute_dtype)
+        outputs = ops.take(self.embeddings, inputs, axis=0)
+        return ops.cast(outputs, dtype=self.compute_dtype)
 
     def compute_mask(self, inputs, mask=None):
         if not self.mask_zero:
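A quick sanity check, written against plain NumPy with illustrative values, that the new take-based lookup computes the same result as the removed one-hot matmul path:

import numpy as np

embeddings = np.random.rand(10, 4)    # (input_dim, output_dim)
inputs = np.array([[1, 3], [7, 0]])   # integer token ids, shape (2, 2)

one_hot = np.eye(10)[inputs]                    # old path: (2, 2, 10)
via_matmul = one_hot @ embeddings               # (2, 2, 4)
via_take = np.take(embeddings, inputs, axis=0)  # new path: (2, 2, 4)

assert np.allclose(via_matmul, via_take)

The take-based path also avoids materializing the (batch, input_dim) one-hot tensor, which is the practical motivation for the switch.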