diff --git a/keras_core/backend/jax/nn.py b/keras_core/backend/jax/nn.py index 5ade61409..418f6bdcf 100644 --- a/keras_core/backend/jax/nn.py +++ b/keras_core/backend/jax/nn.py @@ -493,7 +493,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): - target = jnp.array(target, dtype="int64") + target = jnp.array(target, dtype="int32") output = jnp.array(output) if len(target.shape) == len(output.shape) and target.shape[-1] == 1: target = jnp.squeeze(target, axis=-1)