Fix torch backend random.categorical (#389)
Confusingly, torch.multinomial takes in unnormalized probabilities (they are not even asserted to be positive), rather than logits. Hell if you can tell that from the docs.
This commit is contained in:
parent
664cf6f4f5
commit
556d605214
@@ -30,8 +30,9 @@ def categorical(logits, num_samples, dtype="int32", seed=None):
|
||||
logits = convert_to_tensor(logits)
|
||||
dtype = to_torch_dtype(dtype)
|
||||
generator = torch_seed_generator(seed, device=get_device())
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
return torch.multinomial(
|
||||
logits,
|
||||
probs,
|
||||
num_samples,
|
||||
replacement=True,
|
||||
generator=generator,
|
||||
|
@@ -48,8 +48,9 @@ class RandomTest(testing.TestCase, parameterized.TestCase):
|
||||
)
|
||||
def test_categorical(self, seed, num_samples, batch_size):
|
||||
np.random.seed(seed)
|
||||
# Definitively favor the batch index.
|
||||
logits = np.eye(batch_size) * 1e9
|
||||
# Create logits that definitely favors the batch index after a softmax
|
||||
# is applied. Without a softmax, this would be close to random.
|
||||
logits = np.eye(batch_size) * 1e5 + 1e6
|
||||
res = random.categorical(logits, num_samples, seed=seed)
|
||||
# Outputs should have shape `(batch_size, num_samples)`, where each
|
||||
# output index matches the batch index.
|
||||
|
Loading…
Reference in New Issue
Block a user