Add pad_sequences util.
This commit is contained in:
parent
e15cc82694
commit
e1e9ec5b91
@ -41,6 +41,14 @@ class TestCase(unittest.TestCase):
|
||||
def assertAlmostEqual(self, x1, x2, decimal=3, msg=None):
|
||||
np.testing.assert_almost_equal(x1, x2, decimal=decimal)
|
||||
|
||||
def assertAllEqual(self, x1, x2, msg=None):
|
||||
self.assertEqual(len(x1), len(x2), msg=msg)
|
||||
for e1, e2 in zip(x1, x2):
|
||||
if isinstance(e1, (list, tuple)) or isinstance(e2, (list, tuple)):
|
||||
self.assertAllEqual(e1, e2, msg=msg)
|
||||
else:
|
||||
self.assertEqual(e1, e2, msg=msg)
|
||||
|
||||
def assertLen(self, iterable, expected_len, msg=None):
|
||||
self.assertEqual(len(iterable), expected_len, msg=msg)
|
||||
|
||||
|
139
keras_core/utils/sequence_utils.py
Normal file
139
keras_core/utils/sequence_utils.py
Normal file
@ -0,0 +1,139 @@
|
||||
import numpy as np
|
||||
|
||||
from keras_core.api_export import keras_core_export
|
||||
|
||||
|
||||
@keras_core_export(
|
||||
[
|
||||
"keras_core.utils.pad_sequences",
|
||||
"keras_core.preprocessing.sequence.pad_sequences",
|
||||
]
|
||||
)
|
||||
def pad_sequences(
|
||||
sequences,
|
||||
maxlen=None,
|
||||
dtype="int32",
|
||||
padding="pre",
|
||||
truncating="pre",
|
||||
value=0.0,
|
||||
):
|
||||
"""Pads sequences to the same length.
|
||||
|
||||
This function transforms a list (of length `num_samples`)
|
||||
of sequences (lists of integers)
|
||||
into a 2D NumPy array of shape `(num_samples, num_timesteps)`.
|
||||
`num_timesteps` is either the `maxlen` argument if provided,
|
||||
or the length of the longest sequence in the list.
|
||||
|
||||
Sequences that are shorter than `num_timesteps`
|
||||
are padded with `value` until they are `num_timesteps` long.
|
||||
|
||||
Sequences longer than `num_timesteps` are truncated
|
||||
so that they fit the desired length.
|
||||
|
||||
The position where padding or truncation happens is determined by
|
||||
the arguments `padding` and `truncating`, respectively.
|
||||
Pre-padding or removing values from the beginning of the sequence is the
|
||||
default.
|
||||
|
||||
>>> sequence = [[1], [2, 3], [4, 5, 6]]
|
||||
>>> keras_core.utils.pad_sequences(sequence)
|
||||
array([[0, 0, 1],
|
||||
[0, 2, 3],
|
||||
[4, 5, 6]], dtype=int32)
|
||||
|
||||
>>> keras_core.utils.pad_sequences(sequence, value=-1)
|
||||
array([[-1, -1, 1],
|
||||
[-1, 2, 3],
|
||||
[ 4, 5, 6]], dtype=int32)
|
||||
|
||||
>>> keras_core.utils.pad_sequences(sequence, padding='post')
|
||||
array([[1, 0, 0],
|
||||
[2, 3, 0],
|
||||
[4, 5, 6]], dtype=int32)
|
||||
|
||||
>>> keras_core.utils.pad_sequences(sequence, maxlen=2)
|
||||
array([[0, 1],
|
||||
[2, 3],
|
||||
[5, 6]], dtype=int32)
|
||||
|
||||
Args:
|
||||
sequences: List of sequences (each sequence is a list of integers).
|
||||
maxlen: Optional Int, maximum length of all sequences. If not provided,
|
||||
sequences will be padded to the length of the longest individual
|
||||
sequence.
|
||||
dtype: (Optional, defaults to `"int32"`). Type of the output sequences.
|
||||
To pad sequences with variable length strings, you can use `object`.
|
||||
padding: String, "pre" or "post" (optional, defaults to `"pre"`):
|
||||
pad either before or after each sequence.
|
||||
truncating: String, "pre" or "post" (optional, defaults to `"pre"`):
|
||||
remove values from sequences larger than
|
||||
`maxlen`, either at the beginning or at the end of the sequences.
|
||||
value: Float or String, padding value. (Optional, defaults to 0.)
|
||||
|
||||
Returns:
|
||||
NumPy array with shape `(len(sequences), maxlen)`
|
||||
"""
|
||||
if not hasattr(sequences, "__len__"):
|
||||
raise ValueError("`sequences` must be iterable.")
|
||||
num_samples = len(sequences)
|
||||
|
||||
lengths = []
|
||||
sample_shape = ()
|
||||
flag = True
|
||||
|
||||
# take the sample shape from the first non empty sequence
|
||||
# checking for consistency in the main loop below.
|
||||
|
||||
for x in sequences:
|
||||
try:
|
||||
lengths.append(len(x))
|
||||
if flag and len(x):
|
||||
sample_shape = np.asarray(x).shape[1:]
|
||||
flag = False
|
||||
except TypeError as e:
|
||||
raise ValueError(
|
||||
"`sequences` must be a list of iterables. "
|
||||
f"Found non-iterable: {str(x)}"
|
||||
) from e
|
||||
|
||||
if maxlen is None:
|
||||
maxlen = np.max(lengths)
|
||||
|
||||
is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(
|
||||
dtype, np.unicode_
|
||||
)
|
||||
if isinstance(value, str) and dtype != object and not is_dtype_str:
|
||||
raise ValueError(
|
||||
f"`dtype` {dtype} is not compatible with `value`'s type: "
|
||||
f"{type(value)}\nYou should set `dtype=object` for variable length "
|
||||
"strings."
|
||||
)
|
||||
|
||||
x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
|
||||
for idx, s in enumerate(sequences):
|
||||
if not len(s):
|
||||
continue # empty list/array was found
|
||||
if truncating == "pre":
|
||||
trunc = s[-maxlen:]
|
||||
elif truncating == "post":
|
||||
trunc = s[:maxlen]
|
||||
else:
|
||||
raise ValueError(f'Truncating type "{truncating}" not understood')
|
||||
|
||||
# check `trunc` has expected shape
|
||||
trunc = np.asarray(trunc, dtype=dtype)
|
||||
if trunc.shape[1:] != sample_shape:
|
||||
raise ValueError(
|
||||
f"Shape of sample {trunc.shape[1:]} of sequence at "
|
||||
f"position {idx} is different from expected shape "
|
||||
f"{sample_shape}"
|
||||
)
|
||||
|
||||
if padding == "post":
|
||||
x[idx, : len(trunc)] = trunc
|
||||
elif padding == "pre":
|
||||
x[idx, -len(trunc) :] = trunc
|
||||
else:
|
||||
raise ValueError(f'Padding type "{padding}" not understood')
|
||||
return x
|
102
keras_core/utils/sequence_utils_test.py
Normal file
102
keras_core/utils/sequence_utils_test.py
Normal file
@ -0,0 +1,102 @@
|
||||
from keras_core import testing
|
||||
from keras_core.utils import sequence_utils
|
||||
|
||||
|
||||
class PadSequencesTest(testing.TestCase):
|
||||
def test_pad_sequences(self):
|
||||
a = [[1], [1, 2], [1, 2, 3]]
|
||||
|
||||
# test padding
|
||||
b = sequence_utils.pad_sequences(a, maxlen=3, padding="pre")
|
||||
self.assertAllClose(b, [[0, 0, 1], [0, 1, 2], [1, 2, 3]])
|
||||
b = sequence_utils.pad_sequences(a, maxlen=3, padding="post")
|
||||
self.assertAllClose(b, [[1, 0, 0], [1, 2, 0], [1, 2, 3]])
|
||||
|
||||
# test truncating
|
||||
b = sequence_utils.pad_sequences(a, maxlen=2, truncating="pre")
|
||||
self.assertAllClose(b, [[0, 1], [1, 2], [2, 3]])
|
||||
b = sequence_utils.pad_sequences(a, maxlen=2, truncating="post")
|
||||
self.assertAllClose(b, [[0, 1], [1, 2], [1, 2]])
|
||||
|
||||
# test value
|
||||
b = sequence_utils.pad_sequences(a, maxlen=3, value=1)
|
||||
self.assertAllClose(b, [[1, 1, 1], [1, 1, 2], [1, 2, 3]])
|
||||
|
||||
def test_pad_sequences_str(self):
|
||||
a = [["1"], ["1", "2"], ["1", "2", "3"]]
|
||||
|
||||
# test padding
|
||||
b = sequence_utils.pad_sequences(
|
||||
a, maxlen=3, padding="pre", value="pad", dtype=object
|
||||
)
|
||||
self.assertAllEqual(
|
||||
b, [["pad", "pad", "1"], ["pad", "1", "2"], ["1", "2", "3"]]
|
||||
)
|
||||
b = sequence_utils.pad_sequences(
|
||||
a, maxlen=3, padding="post", value="pad", dtype="<U3"
|
||||
)
|
||||
self.assertAllEqual(
|
||||
b, [["1", "pad", "pad"], ["1", "2", "pad"], ["1", "2", "3"]]
|
||||
)
|
||||
|
||||
# test truncating
|
||||
b = sequence_utils.pad_sequences(
|
||||
a, maxlen=2, truncating="pre", value="pad", dtype=object
|
||||
)
|
||||
self.assertAllEqual(b, [["pad", "1"], ["1", "2"], ["2", "3"]])
|
||||
b = sequence_utils.pad_sequences(
|
||||
a, maxlen=2, truncating="post", value="pad", dtype="<U3"
|
||||
)
|
||||
self.assertAllEqual(b, [["pad", "1"], ["1", "2"], ["1", "2"]])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "`dtype` int32 is not compatible with "
|
||||
):
|
||||
sequence_utils.pad_sequences(
|
||||
a, maxlen=2, truncating="post", value="pad"
|
||||
)
|
||||
|
||||
def test_pad_sequences_vector(self):
|
||||
a = [[[1, 1]], [[2, 1], [2, 2]], [[3, 1], [3, 2], [3, 3]]]
|
||||
|
||||
# test padding
|
||||
b = sequence_utils.pad_sequences(a, maxlen=3, padding="pre")
|
||||
self.assertAllClose(
|
||||
b,
|
||||
[
|
||||
[[0, 0], [0, 0], [1, 1]],
|
||||
[[0, 0], [2, 1], [2, 2]],
|
||||
[[3, 1], [3, 2], [3, 3]],
|
||||
],
|
||||
)
|
||||
b = sequence_utils.pad_sequences(a, maxlen=3, padding="post")
|
||||
self.assertAllClose(
|
||||
b,
|
||||
[
|
||||
[[1, 1], [0, 0], [0, 0]],
|
||||
[[2, 1], [2, 2], [0, 0]],
|
||||
[[3, 1], [3, 2], [3, 3]],
|
||||
],
|
||||
)
|
||||
|
||||
# test truncating
|
||||
b = sequence_utils.pad_sequences(a, maxlen=2, truncating="pre")
|
||||
self.assertAllClose(
|
||||
b, [[[0, 0], [1, 1]], [[2, 1], [2, 2]], [[3, 2], [3, 3]]]
|
||||
)
|
||||
|
||||
b = sequence_utils.pad_sequences(a, maxlen=2, truncating="post")
|
||||
self.assertAllClose(
|
||||
b, [[[0, 0], [1, 1]], [[2, 1], [2, 2]], [[3, 1], [3, 2]]]
|
||||
)
|
||||
|
||||
# test value
|
||||
b = sequence_utils.pad_sequences(a, maxlen=3, value=1)
|
||||
self.assertAllClose(
|
||||
b,
|
||||
[
|
||||
[[1, 1], [1, 1], [1, 1]],
|
||||
[[1, 1], [2, 1], [2, 2]],
|
||||
[[3, 1], [3, 2], [3, 3]],
|
||||
],
|
||||
)
|
Loading…
Reference in New Issue
Block a user