From 2f92f7ce22283e40ecdddf7bce3faef651927590 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Fri, 19 May 2023 21:45:22 -0700 Subject: [PATCH] Fix a warning on the jax backend (#198) int64 is not a supported type in jax bye default. Trying to use it gets the following error. UserWarning: Explicitly requested dtype int64 requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. target = jnp.array(target, dtype="int64") We can stick to int32 (which is what will be used anyway). --- keras_core/backend/jax/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_core/backend/jax/nn.py b/keras_core/backend/jax/nn.py index 5ade61409..418f6bdcf 100644 --- a/keras_core/backend/jax/nn.py +++ b/keras_core/backend/jax/nn.py @@ -493,7 +493,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): - target = jnp.array(target, dtype="int64") + target = jnp.array(target, dtype="int32") output = jnp.array(output) if len(target.shape) == len(output.shape) and target.shape[-1] == 1: target = jnp.squeeze(target, axis=-1)