Fix a warning on the jax backend (#198)
int64 is not a supported type in jax bye default. Trying to use it gets the following error. UserWarning: Explicitly requested dtype int64 requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. target = jnp.array(target, dtype="int64") We can stick to int32 (which is what will be used anyway).
This commit is contained in:
parent
3b115c1073
commit
2f92f7ce22
@ -493,7 +493,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
|
|||||||
|
|
||||||
|
|
||||||
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
|
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
|
||||||
target = jnp.array(target, dtype="int64")
|
target = jnp.array(target, dtype="int32")
|
||||||
output = jnp.array(output)
|
output = jnp.array(output)
|
||||||
if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
|
if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
|
||||||
target = jnp.squeeze(target, axis=-1)
|
target = jnp.squeeze(target, axis=-1)
|
||||||
|
Loading…
Reference in New Issue
Block a user