Fix torch backend random.categorical (#389)

Confusingly, torch.multinomial takes in unnormalized probabilities (they
are not even asserted to be positive), rather then logits. Hell if you
can tell that from the docs.
This commit is contained in:
Matt Watson 2023-06-21 19:33:48 -07:00 committed by Francois Chollet
parent 664cf6f4f5
commit 556d605214
2 changed files with 5 additions and 3 deletions

@ -30,8 +30,9 @@ def categorical(logits, num_samples, dtype="int32", seed=None):
logits = convert_to_tensor(logits)
dtype = to_torch_dtype(dtype)
generator = torch_seed_generator(seed, device=get_device())
probs = torch.softmax(logits, dim=-1)
return torch.multinomial(
logits,
probs,
num_samples,
replacement=True,
generator=generator,

@ -48,8 +48,9 @@ class RandomTest(testing.TestCase, parameterized.TestCase):
)
def test_categorical(self, seed, num_samples, batch_size):
np.random.seed(seed)
# Definitively favor the batch index.
logits = np.eye(batch_size) * 1e9
# Create logits that definitely favors the batch index after a softmax
# is applied. Without a softmax, this would be close to random.
logits = np.eye(batch_size) * 1e5 + 1e6
res = random.categorical(logits, num_samples, seed=seed)
# Outputs should have shape `(batch_size, num_samples)`, where each
# output index matches the batch index.