diff --git a/keras_core/backend/torch/random.py b/keras_core/backend/torch/random.py
index 77e024411..5b5dfda74 100644
--- a/keras_core/backend/torch/random.py
+++ b/keras_core/backend/torch/random.py
@@ -30,8 +30,9 @@ def categorical(logits, num_samples, dtype="int32", seed=None):
     logits = convert_to_tensor(logits)
     dtype = to_torch_dtype(dtype)
     generator = torch_seed_generator(seed, device=get_device())
+    probs = torch.softmax(logits, dim=-1)
     return torch.multinomial(
-        logits,
+        probs,
         num_samples,
         replacement=True,
         generator=generator,
diff --git a/keras_core/random/random_test.py b/keras_core/random/random_test.py
index 3267a7a79..dd8b45af0 100644
--- a/keras_core/random/random_test.py
+++ b/keras_core/random/random_test.py
@@ -48,8 +48,9 @@ class RandomTest(testing.TestCase, parameterized.TestCase):
     )
     def test_categorical(self, seed, num_samples, batch_size):
         np.random.seed(seed)
-        # Definitively favor the batch index.
-        logits = np.eye(batch_size) * 1e9
+        # Create logits that definitely favor the batch index after a softmax
+        # is applied. Without a softmax, this would be close to random.
+        logits = np.eye(batch_size) * 1e5 + 1e6
         res = random.categorical(logits, num_samples, seed=seed)
         # Outputs should have shape `(batch_size, num_samples)`, where each
         # output index matches the batch index.
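
Note (not part of the patch): torch.multinomial interprets its input as
unnormalized, non-negative sampling weights rather than logits, which is why
the softmax is needed here. Below is a minimal sketch against stock PyTorch
illustrating the failure mode the updated test guards against: with logits in
the style of the new test (around 1e6 everywhere, with the favored class
leading by 1e5), raw weights sample almost uniformly, while softmax
probabilities collapse onto the favored class.

import torch

# Logits in the style of the updated test: the favored class leads by 1e5,
# but every entry is a large positive number around 1e6.
logits = torch.tensor([[1.1e6, 1.0e6]])

# Treated as raw multinomial weights (the old, buggy behavior), the classes
# are weighted 1.1e6 vs 1.0e6, so sampling is close to a coin flip.
raw = torch.multinomial(logits, num_samples=10_000, replacement=True)
print(raw.float().mean().item())  # ~0.48: nearly uniform over both classes

# After a softmax, the 1e5 gap puts essentially all probability mass on
# class 0, matching the intended `categorical` semantics.
probs = torch.softmax(logits, dim=-1)
fixed = torch.multinomial(probs, num_samples=10_000, replacement=True)
print(fixed.float().mean().item())  # 0.0: class 0 is drawn every time

The old test logits (np.eye(batch_size) * 1e9) masked the bug because their
off-diagonal weights were exactly zero, so sampling the raw logits still
favored the batch index; torch.multinomial also raises a RuntimeError on
negative weights, so typical real-world logits would have crashed outright.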