Make pre-padding the default in sequence tensors
This commit is contained in:
parent
4830b4be27
commit
6329378ca3
@ -4,13 +4,15 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
from six.moves import range
|
from six.moves import range
|
||||||
|
|
||||||
def pad_sequences(sequences, maxlen=None, dtype='int32'):
|
def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre'):
|
||||||
"""
|
"""
|
||||||
Pad each sequence to the same length:
|
Pad each sequence to the same length:
|
||||||
the length of the longest sequence.
|
the length of the longest sequence.
|
||||||
|
|
||||||
If maxlen is provided, any sequence longer
|
If maxlen is provided, any sequence longer
|
||||||
than maxlen is truncated to maxlen.
|
than maxlen is truncated to maxlen.
|
||||||
|
|
||||||
|
Support post-padding and pre-padding (default).
|
||||||
"""
|
"""
|
||||||
lengths = [len(s) for s in sequences]
|
lengths = [len(s) for s in sequences]
|
||||||
|
|
||||||
@ -20,7 +22,10 @@ def pad_sequences(sequences, maxlen=None, dtype='int32'):
|
|||||||
|
|
||||||
x = np.zeros((nb_samples, maxlen)).astype(dtype)
|
x = np.zeros((nb_samples, maxlen)).astype(dtype)
|
||||||
for idx, s in enumerate(sequences):
|
for idx, s in enumerate(sequences):
|
||||||
x[idx, :lengths[idx]] = s[:maxlen]
|
if padding == 'post':
|
||||||
|
x[idx, :lengths[idx]] = s[:maxlen]
|
||||||
|
else:
|
||||||
|
x[idx, -min(maxlen, lengths[idx]):] = s[:maxlen]
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user