Make pre-padding the default in sequence tensors
This commit is contained in:
parent
4830b4be27
commit
6329378ca3
@ -4,13 +4,15 @@ import numpy as np
|
||||
import random
|
||||
from six.moves import range
|
||||
|
||||
def pad_sequences(sequences, maxlen=None, dtype='int32'):
|
||||
def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre'):
|
||||
"""
|
||||
Pad each sequence to the same length:
|
||||
the length of the longuest sequence.
|
||||
|
||||
If maxlen is provided, any sequence longer
|
||||
than maxlen is truncated to maxlen.
|
||||
|
||||
Support post-padding and pre-padding (default).
|
||||
"""
|
||||
lengths = [len(s) for s in sequences]
|
||||
|
||||
@ -20,7 +22,10 @@ def pad_sequences(sequences, maxlen=None, dtype='int32'):
|
||||
|
||||
x = np.zeros((nb_samples, maxlen)).astype(dtype)
|
||||
for idx, s in enumerate(sequences):
|
||||
x[idx, :lengths[idx]] = s[:maxlen]
|
||||
if padding == 'post':
|
||||
x[idx, :lengths[idx]] = s[:maxlen]
|
||||
else:
|
||||
x[idx, -min(maxlen, lengths[idx]):] = s[:maxlen]
|
||||
return x
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user