diff --git a/keras_core/backend/jax/numpy.py b/keras_core/backend/jax/numpy.py index a4221bfe6..57f14a0a2 100644 --- a/keras_core/backend/jax/numpy.py +++ b/keras_core/backend/jax/numpy.py @@ -457,6 +457,8 @@ def swapaxes(x, axis1, axis2): def take(x, indices, axis=None): + x = convert_to_tensor(x) + indices = convert_to_tensor(indices) return jnp.take(x, indices, axis=axis) diff --git a/keras_core/layers/core/embedding.py b/keras_core/layers/core/embedding.py index 7e2b17073..535586488 100644 --- a/keras_core/layers/core/embedding.py +++ b/keras_core/layers/core/embedding.py @@ -90,9 +90,8 @@ class Embedding(Layer): def call(self, inputs): if inputs.dtype != "int32" and inputs.dtype != "int64": inputs = ops.cast(inputs, "int32") - one_hot_data = ops.one_hot(inputs, num_classes=self.input_dim) - out = ops.matmul(one_hot_data, self.embeddings) - return ops.cast(out, dtype=self.compute_dtype) + outputs = ops.take(self.embeddings, inputs, axis=0) + return ops.cast(outputs, dtype=self.compute_dtype) def compute_mask(self, inputs, mask=None): if not self.mask_zero: