Fix a warning on the jax backend (#198)

int64 is not a supported type in jax bye default. Trying to use it gets
the following error.

UserWarning: Explicitly requested dtype int64 requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  target = jnp.array(target, dtype="int64")

We can stick to int32 (which is what will be used anyway).
This commit is contained in:
Matt Watson 2023-05-19 21:45:22 -07:00 committed by Francois Chollet
parent 3b115c1073
commit 2f92f7ce22

@ -493,7 +493,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
target = jnp.array(target, dtype="int64")
target = jnp.array(target, dtype="int32")
output = jnp.array(output)
if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
target = jnp.squeeze(target, axis=-1)