One hot op (#3353)
* One hot op * tf too * Update theano_backend.py * Use built-in theano op * Update theano_backend.py * Add test * Update test_backends.py * Update test_backends.py * Generalize for nD tensors * Fix docstring on TF backend * Update theano_backend.py * Update theano_backend.py
This commit is contained in:
parent
ad3231c29a
commit
984ad34a61
@ -840,6 +840,13 @@ def spatial_3d_padding(x, padding=(1, 1, 1), dim_ordering='th'):
|
||||
def pack(x):
|
||||
return tf.pack(x)
|
||||
|
||||
def one_hot(indices, nb_classes):
|
||||
'''
|
||||
Input: nD integer tensor of shape (batch_size, dim1, dim2, ... dim(n-1))
|
||||
Output: (n + 1)D one hot representation of the input with shape (batch_size, dim1, dim2, ... dim(n-1), nb_classes)
|
||||
'''
|
||||
return tf.one_hot(indices, depth=nb_classes, axis=-1)
|
||||
|
||||
|
||||
# VALUE MANIPULATION
|
||||
|
||||
|
@ -595,6 +595,19 @@ def spatial_3d_padding(x, padding=(1, 1, 1), dim_ordering='th'):
|
||||
def pack(x):
|
||||
return T.stack(*x)
|
||||
|
||||
|
||||
def one_hot(indices, nb_classes):
|
||||
'''
|
||||
Input: nD integer tensor of shape (batch_size, dim1, dim2, ... dim(n-1))
|
||||
Output: (n + 1)D one hot representation of the input with shape (batch_size, dim1, dim2, ... dim(n-1), nb_classes)
|
||||
'''
|
||||
input_shape = tuple((indices.shape[i] for i in range(indices.ndim)))
|
||||
indices = T.flatten(indices)
|
||||
oh = T.extra_ops.to_one_hot(indices, nb_classes)
|
||||
oh = T.reshape(oh, input_shape + (nb_classes,))
|
||||
return oh
|
||||
|
||||
|
||||
# VALUE MANIPULATION
|
||||
|
||||
|
||||
|
@ -580,6 +580,16 @@ class TestBackend(object):
|
||||
assert(np.max(rand) == 1)
|
||||
assert(np.min(rand) == 0)
|
||||
|
||||
def test_one_hot(self):
|
||||
input_length = 10
|
||||
nb_classes = 20
|
||||
batch_size = 30
|
||||
indices = np.random.randint(0, nb_classes, size=(batch_size, input_length))
|
||||
oh = np.eye(nb_classes)[indices]
|
||||
for K in [KTH, KTF]:
|
||||
koh = K.eval(K.one_hot(K.variable(indices, dtype='int32'), nb_classes))
|
||||
assert np.all(koh == oh)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
Loading…
Reference in New Issue
Block a user