From 556d6052142b28be5f262a2a5a06c8e838e9912b Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Wed, 21 Jun 2023 19:33:48 -0700 Subject: [PATCH] Fix torch backend random.categorical (#389) Confusingly, torch.multinomial takes in unnormalized probabilities (they are not even asserted to be positive), rather than logits. Hell if you can tell that from the docs. --- keras_core/backend/torch/random.py | 3 ++- keras_core/random/random_test.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/keras_core/backend/torch/random.py b/keras_core/backend/torch/random.py index 77e024411..5b5dfda74 100644 --- a/keras_core/backend/torch/random.py +++ b/keras_core/backend/torch/random.py @@ -30,8 +30,9 @@ def categorical(logits, num_samples, dtype="int32", seed=None): logits = convert_to_tensor(logits) dtype = to_torch_dtype(dtype) generator = torch_seed_generator(seed, device=get_device()) + probs = torch.softmax(logits, dim=-1) return torch.multinomial( - logits, + probs, num_samples, replacement=True, generator=generator, diff --git a/keras_core/random/random_test.py b/keras_core/random/random_test.py index 3267a7a79..dd8b45af0 100644 --- a/keras_core/random/random_test.py +++ b/keras_core/random/random_test.py @@ -48,8 +48,9 @@ class RandomTest(testing.TestCase, parameterized.TestCase): ) def test_categorical(self, seed, num_samples, batch_size): np.random.seed(seed) - # Definitively favor the batch index. - logits = np.eye(batch_size) * 1e9 + # Create logits that definitely favors the batch index after a softmax + # is applied. Without a softmax, this would be close to random. + logits = np.eye(batch_size) * 1e5 + 1e6 res = random.categorical(logits, num_samples, seed=seed) # Outputs should have shape `(batch_size, num_samples)`, where each # output index matches the batch index.