One hot op (#3353)

* One hot op

* tf too

* Update theano_backend.py

* Use built-in theano op

* Update theano_backend.py

* Add test

* Update test_backends.py

* Update test_backends.py

* Generalize for nD tensors

* Fix docstring on TF backend

* Update theano_backend.py

* Update theano_backend.py
This commit is contained in:
Fariz Rahman 2016-08-05 02:46:36 +05:30 committed by François Chollet
parent ad3231c29a
commit 984ad34a61
3 changed files with 30 additions and 0 deletions

@ -840,6 +840,13 @@ def spatial_3d_padding(x, padding=(1, 1, 1), dim_ordering='th'):
def pack(x):
return tf.pack(x)
def one_hot(indices, nb_classes):
'''
Input: nD integer tensor of shape (batch_size, dim1, dim2, ... dim(n-1))
Output: (n + 1)D one hot representation of the input with shape (batch_size, dim1, dim2, ... dim(n-1), nb_classes)
'''
return tf.one_hot(indices, depth=nb_classes, axis=-1)
# VALUE MANIPULATION

@ -595,6 +595,19 @@ def spatial_3d_padding(x, padding=(1, 1, 1), dim_ordering='th'):
def pack(x):
return T.stack(*x)
def one_hot(indices, nb_classes):
'''
Input: nD integer tensor of shape (batch_size, dim1, dim2, ... dim(n-1))
Output: (n + 1)D one hot representation of the input with shape (batch_size, dim1, dim2, ... dim(n-1), nb_classes)
'''
input_shape = tuple((indices.shape[i] for i in range(indices.ndim)))
indices = T.flatten(indices)
oh = T.extra_ops.to_one_hot(indices, nb_classes)
oh = T.reshape(oh, input_shape + (nb_classes,))
return oh
# VALUE MANIPULATION

@ -580,6 +580,16 @@ class TestBackend(object):
assert(np.max(rand) == 1)
assert(np.min(rand) == 0)
def test_one_hot(self):
input_length = 10
nb_classes = 20
batch_size = 30
indices = np.random.randint(0, nb_classes, size=(batch_size, input_length))
oh = np.eye(nb_classes)[indices]
for K in [KTH, KTF]:
koh = K.eval(K.one_hot(K.variable(indices, dtype='int32'), nb_classes))
assert np.all(koh == oh)
if __name__ == '__main__':
pytest.main([__file__])