diff --git a/keras/preprocessing/sequence.py b/keras/preprocessing/sequence.py index dccb1751d..197288c75 100644 --- a/keras/preprocessing/sequence.py +++ b/keras/preprocessing/sequence.py @@ -4,13 +4,15 @@ import numpy as np import random from six.moves import range -def pad_sequences(sequences, maxlen=None, dtype='int32'): +def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre'): """ Pad each sequence to the same length: the length of the longuest sequence. If maxlen is provided, any sequence longer than maxlen is truncated to maxlen. + + Support post-padding and pre-padding (default). """ lengths = [len(s) for s in sequences] @@ -20,7 +22,10 @@ def pad_sequences(sequences, maxlen=None, dtype='int32'): x = np.zeros((nb_samples, maxlen)).astype(dtype) for idx, s in enumerate(sequences): - x[idx, :lengths[idx]] = s[:maxlen] + if padding == 'post': + x[idx, :lengths[idx]] = s[:maxlen] + else: + x[idx, -min(maxlen, lengths[idx]):] = s[:maxlen] return x