Add tests for array data adapter.

parent 5a88d09189
commit 729b60650e
@@ -4,6 +4,7 @@ import numpy as np
 import tensorflow as tf
 from tensorflow import nest
 
+from keras_core import backend
 from keras_core.trainers.data_adapters import data_adapters_utils
 from keras_core.trainers.data_adapters.data_adapter import DataAdapter
 
@@ -13,14 +14,19 @@ except ImportError:
     pandas = None
 
 
-ARRAY_TYPES = [tf.Tensor, np.ndarray]
+ARRAY_TYPES = (tf.Tensor, np.ndarray)
 if pandas:
-    ARRAY_TYPES.extend([tf.Tensor, np.ndarray, pandas.Series, pandas.DataFrame])
+    ARRAY_TYPES = ARRAY_TYPES + (
+        tf.Tensor,
+        np.ndarray,
+        pandas.Series,
+        pandas.DataFrame,
+    )
 # TODO: support torch tensors?
 
 
 class ArrayDataAdapter(DataAdapter):
-    """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""
+    """Adapter that handles array-like objects, e.g. tf.Tensor and NumPy arrays."""
 
     @staticmethod
     def can_handle(x, y=None):
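The list-to-tuple change is not just stylistic: `isinstance` accepts a tuple of types but raises `TypeError` for a list, and tuple concatenation avoids mutating a module-level list in place. A minimal sketch of the kind of check `can_handle` performs (the helper below is illustrative, not the adapter's exact code):

    import numpy as np
    import tensorflow as tf

    ARRAY_TYPES = (tf.Tensor, np.ndarray)

    def looks_like_array(value):
        # A tuple of classes is a valid second argument to isinstance;
        # a plain list here would raise TypeError.
        return isinstance(value, ARRAY_TYPES)

    print(looks_like_array(np.zeros((2, 2))))  # True
    print(looks_like_array([1, 2, 3]))         # False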
@@ -41,15 +47,6 @@ class ArrayDataAdapter(DataAdapter):
         super().__init__(x, y, **kwargs)
         x, y, sample_weights = convert_to_arrays((x, y, sample_weights))
 
-        # If sample_weights are not specified for an output, use 1.0 as weights.
-        (
-            sample_weights,
-            _,
-            _,
-        ) = data_adapters_utils.handle_partial_sample_weights(
-            y, sample_weights, check_all_flat=True
-        )
-
         inputs = data_adapters_utils.pack_x_y_sample_weight(
             x, y, sample_weights
         )
@@ -77,11 +74,11 @@ class ArrayDataAdapter(DataAdapter):
             start, stop = i * self._batch_size, (i + 1) * self._batch_size
             yield tf.nest.map_structure(lambda x: x[start:stop], self._inputs)
 
-    def get_dataset(self):
+    def get_tf_dataset(self):
         ds = tf.data.Dataset.from_tensor_slices(self._inputs)
         ds = ds.shuffle(self._batch_size * 8)
         ds = ds.batch(self._batch_size)
-        ds = ds.preftech(tf.data.AUTOTUNE)
+        ds = ds.prefetch(tf.data.AUTOTUNE)
         return ds
 
     def get_size(self):
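Besides fixing the `preftech` typo, the renamed `get_tf_dataset` builds a standard input pipeline: slice the packed arrays into per-sample elements, shuffle with a buffer of eight batches, batch, then prefetch so host-side batching overlaps device compute. A self-contained sketch of the same pipeline shape (array sizes borrowed from the tests below):

    import tensorflow as tf

    x = tf.random.normal((34, 4))
    batch_size = 16

    ds = tf.data.Dataset.from_tensor_slices(x)  # 34 elements of shape (4,)
    ds = ds.shuffle(batch_size * 8)             # sample-level shuffle buffer
    ds = ds.batch(batch_size)                   # -> (16, 4), (16, 4), (2, 4)
    ds = ds.prefetch(tf.data.AUTOTUNE)          # overlap prep with compute

    for batch in ds:
        print(batch.shape)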
@@ -97,7 +94,7 @@ class ArrayDataAdapter(DataAdapter):
         return self._partial_batch_size or None
 
 
-def convert_to_arrays(arrays):
+def convert_to_arrays(arrays, dtype=None):
     """Process array-like inputs.
 
     This function:
@@ -110,15 +107,27 @@
         inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like.
 
     Returns:
-        Structure of `ndarray`.
+        Structure of NumPy `ndarray`s.
     """
+    dtype = dtype or backend.floatx()
 
     def convert_single_array(x):
         if x is None:
             return x
         if pandas is not None:
             if isinstance(x, pandas.Series):
-                x = np.expand_dims(x.to_numpy(), axis=-1)
+                x = np.expand_dims(x.to_numpy(dtype=dtype), axis=-1)
             elif isinstance(x, pandas.DataFrame):
                 x = x.to_numpy(dtype=dtype)
         if isinstance(x, (tf.Tensor, tf.Variable)):
             x = x.numpy()
         if not isinstance(x, np.ndarray):
             raise ValueError(
                 "Expected a NumPy array, tf.Tensor, Pandas Dataframe or Pandas Series. "
                 f"Received invalid input: {x} (of type {type(x)})"
             )
+        if not str(x.dtype) == str(dtype):
+            x = x.astype(dtype)
         return x
 
     arrays = tf.nest.map_structure(convert_single_array, arrays)
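With the new `dtype` argument, every array in the structure comes out as a NumPy `ndarray` of a single dtype, defaulting to `backend.floatx()`; pandas objects are converted up front, and a 1-D `Series` gains a trailing axis. A rough standalone illustration of that coercion, using plain NumPy with "float32" standing in for `backend.floatx()`:

    import numpy as np

    dtype = "float32"  # stand-in for backend.floatx()

    x = np.arange(6, dtype="int64").reshape((3, 2))
    if str(x.dtype) != str(dtype):
        x = x.astype(dtype)
    print(x.dtype)  # float32

    # A 1-D pandas Series would additionally get a trailing axis:
    s = np.arange(3, dtype="float32")        # stand-in for series.to_numpy()
    print(np.expand_dims(s, axis=-1).shape)  # (3, 1)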
keras_core/trainers/data_adapters/array_data_adapter_test.py (new file, 132 lines)
@@ -0,0 +1,132 @@
+import numpy as np
+import tensorflow as tf
+
+from keras_core import backend
+from keras_core import testing
+from keras_core.trainers.data_adapters import array_data_adapter
+
+
+class TestArrayDataAdapter(testing.TestCase):
+    def _test_basic_flow(self, array_type="np"):
+        if array_type == "np":
+            x = np.random.random((34, 4))
+            y = np.random.random((34, 2))
+        elif array_type == "tf":
+            x = tf.random.normal((34, 4))
+            y = tf.random.normal((34, 2))
+        elif array_type == "pandas":
+            # TODO
+            raise ValueError("TODO")
+        adapter = array_data_adapter.ArrayDataAdapter(
+            x,
+            y=y,
+            sample_weights=None,
+            batch_size=16,
+            steps=None,
+            shuffle=False,
+        )
+        gen = adapter.get_numpy_iterator()
+        for i, batch in enumerate(gen):
+            self.assertEqual(len(batch), 2)
+            bx, by = batch
+            self.assertTrue(isinstance(bx, np.ndarray))
+            self.assertTrue(isinstance(by, np.ndarray))
+            self.assertEqual(bx.dtype, by.dtype)
+            self.assertEqual(bx.dtype, backend.floatx())
+            if i < 2:
+                self.assertEqual(bx.shape, (16, 4))
+                self.assertEqual(by.shape, (16, 2))
+            else:
+                self.assertEqual(bx.shape, (2, 4))
+                self.assertEqual(by.shape, (2, 2))
+        ds = adapter.get_tf_dataset()
+        for i, batch in enumerate(ds):
+            self.assertEqual(len(batch), 2)
+            bx, by = batch
+            self.assertTrue(isinstance(bx, tf.Tensor))
+            self.assertTrue(isinstance(by, tf.Tensor))
+            self.assertEqual(bx.dtype, by.dtype)
+            self.assertEqual(bx.dtype, backend.floatx())
+            if i < 2:
+                self.assertEqual(tuple(bx.shape), (16, 4))
+                self.assertEqual(tuple(by.shape), (16, 2))
+            else:
+                self.assertEqual(tuple(bx.shape), (2, 4))
+                self.assertEqual(tuple(by.shape), (2, 2))
+
+    def test_basic_flow_np(self):
+        self._test_basic_flow("np")
+
+    def test_basic_flow_tf(self):
+        self._test_basic_flow("tf")
+
+    def test_multi_inputs_and_outputs(self):
+        x1 = np.random.random((34, 1))
+        x2 = np.random.random((34, 2))
+        y1 = np.random.random((34, 3))
+        y2 = np.random.random((34, 4))
+        sw = np.random.random((34,))
+        adapter = array_data_adapter.ArrayDataAdapter(
+            x={"x1": x1, "x2": x2},
+            y=[y1, y2],
+            sample_weights=sw,
+            batch_size=16,
+            steps=None,
+            shuffle=False,
+        )
+        gen = adapter.get_numpy_iterator()
+        for i, batch in enumerate(gen):
+            self.assertEqual(len(batch), 3)
+            bx, by, bw = batch
+            self.assertTrue(isinstance(bx, dict))
+            # NOTE: the y list was converted to a tuple for tf.data compatibility.
+            self.assertTrue(isinstance(by, tuple))
+            self.assertTrue(isinstance(bw, np.ndarray))
+
+            self.assertTrue(isinstance(bx["x1"], np.ndarray))
+            self.assertTrue(isinstance(bx["x2"], np.ndarray))
+            self.assertTrue(isinstance(by[0], np.ndarray))
+            self.assertTrue(isinstance(by[1], np.ndarray))
+
+            self.assertEqual(bx["x1"].dtype, by[0].dtype)
+            self.assertEqual(bx["x1"].dtype, backend.floatx())
+            if i < 2:
+                self.assertEqual(bx["x1"].shape, (16, 1))
+                self.assertEqual(bx["x2"].shape, (16, 2))
+                self.assertEqual(by[0].shape, (16, 3))
+                self.assertEqual(by[1].shape, (16, 4))
+                self.assertEqual(bw.shape, (16,))
+            else:
+                self.assertEqual(bx["x1"].shape, (2, 1))
+                self.assertEqual(by[0].shape, (2, 3))
+                self.assertEqual(bw.shape, (2,))
+        ds = adapter.get_tf_dataset()
+        for i, batch in enumerate(ds):
+            self.assertEqual(len(batch), 3)
+            bx, by, bw = batch
+            self.assertTrue(isinstance(bx, dict))
+            # NOTE: the y list was converted to a tuple for tf.data compatibility.
+            self.assertTrue(isinstance(by, tuple))
+            self.assertTrue(isinstance(bw, tf.Tensor))
+
+            self.assertTrue(isinstance(bx["x1"], tf.Tensor))
+            self.assertTrue(isinstance(bx["x2"], tf.Tensor))
+            self.assertTrue(isinstance(by[0], tf.Tensor))
+            self.assertTrue(isinstance(by[1], tf.Tensor))
+
+            self.assertEqual(bx["x1"].dtype, by[0].dtype)
+            self.assertEqual(bx["x1"].dtype, backend.floatx())
+            if i < 2:
+                self.assertEqual(tuple(bx["x1"].shape), (16, 1))
+                self.assertEqual(tuple(bx["x2"].shape), (16, 2))
+                self.assertEqual(tuple(by[0].shape), (16, 3))
+                self.assertEqual(tuple(by[1].shape), (16, 4))
+                self.assertEqual(tuple(bw.shape), (16,))
+            else:
+                self.assertEqual(tuple(bx["x1"].shape), (2, 1))
+                self.assertEqual(tuple(by[0].shape), (2, 3))
+                self.assertEqual(tuple(bw.shape), (2,))
+
+    def test_sample_weights(self):
+        # TODO
+        pass
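The shape assertions encode the partial-batch arithmetic: 34 samples at `batch_size=16` give two full batches plus a final partial batch of 34 - 2 * 16 = 2, matching the `partial_batch_size` property seen earlier. As a quick standalone check:

    import math

    num_samples, batch_size = 34, 16
    num_full = num_samples // batch_size               # 2 full batches
    partial = num_samples % batch_size                 # 2 samples left over
    num_batches = math.ceil(num_samples / batch_size)  # 3 batches total
    print(num_full, partial, num_batches)              # 2 2 3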
@@ -2,177 +2,96 @@ import numpy as np
 import tensorflow as tf
 
 
-def handle_partial_sample_weights(
-    outputs, sample_weights, sample_weight_modes, check_all_flat=False
-):
-    """Adds 1.0 as sample weights for the outputs for which there is no weight.
+def unpack_x_y_sample_weight(data):
+    """Unpacks user-provided data tuple.
+
+    This is a convenience utility to be used when overriding
+    `Model.train_step`, `Model.test_step`, or `Model.predict_step`.
+    This utility makes it easy to support data of the form `(x,)`,
+    `(x, y)`, or `(x, y, sample_weight)`.
+
+    Standalone usage:
+
+    >>> features_batch = tf.ones((10, 5))
+    >>> labels_batch = tf.zeros((10, 5))
+    >>> data = (features_batch, labels_batch)
+    >>> # `y` and `sample_weight` will default to `None` if not provided.
+    >>> x, y, sample_weight = unpack_x_y_sample_weight(data)
+    >>> sample_weight is None
+    True
 
     Args:
-        outputs: List of model outputs.
-        sample_weights: List of sample weight inputs.
-        sample_weight_modes: List of sample weight modes or None.
-        check_all_flat: Ensure that inputs are not nested structures. This is not
-            a free check, so we may not want to run it eagerly every iteration.
+        data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.
 
     Returns:
-        Tuple of sample weights, one sample weight for every output, and booleans
-        describing the raw sample weights.
+        The unpacked tuple, with `None`s for `y` and `sample_weight` if they are
+        not provided.
     """
-    if not isinstance(sample_weights, (list, tuple)):
-        any_sample_weight = sample_weights is not None
-        partial_sample_weight = any_sample_weight and sample_weights is None
-    else:
-        any_sample_weight = sample_weights is not None and any(
-            w is not None for w in sample_weights
-        )
-        partial_sample_weight = any_sample_weight and any(
-            w is None for w in sample_weights
-        )
-
-    if not any_sample_weight:
-        return None, any_sample_weight, partial_sample_weight
-
-    if not partial_sample_weight:
-        return sample_weights, any_sample_weight, partial_sample_weight
-
-    if check_all_flat:
-        tf.nest.assert_same_structure(
-            list_to_tuple(sample_weights),
-            list_to_tuple(tf.nest.flatten(sample_weights)),
-        )
-        tf.nest.assert_same_structure(
-            list_to_tuple(outputs), list_to_tuple(tf.nest.flatten(outputs))
-        )
-        if sample_weight_modes is not None:
-            tf.nest.assert_same_structure(
-                sample_weight_modes, tf.nest.flatten(sample_weight_modes)
-            )
-
-    new_sample_weights = []
-    for i, sw in enumerate(sample_weights):
-        if sw is None:
-            as_numpy = isinstance(outputs[i], np.ndarray)
-            output = outputs[i]
-            output_shape = output.shape if as_numpy else tf.shape(output)
-
-            is_temporal = (
-                sample_weight_modes is not None
-                and sample_weight_modes[i] == "temporal"
-            )
-            sw_shape = (
-                (output_shape[0], output_shape[1])
-                if is_temporal
-                else (output_shape[0],)
-            )
-
-            new_sample_weights.append(
-                np.ones(sw_shape) if as_numpy else tf.ones(sw_shape)
-            )
-        else:
-            new_sample_weights.append(sw)
-    return (
-        list_to_tuple(new_sample_weights),
-        any_sample_weight,
-        partial_sample_weight,
-    )
+    if isinstance(data, list):
+        data = tuple(data)
+    if not isinstance(data, tuple):
+        return (data, None, None)
+    elif len(data) == 1:
+        return (data[0], None, None)
+    elif len(data) == 2:
+        return (data[0], data[1], None)
+    elif len(data) == 3:
+        return (data[0], data[1], data[2])
+    error_msg = (
+        "Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
+        f"or `(x, y, sample_weight)`, found: {data}"
+    )
+    raise ValueError(error_msg)
 
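One path the doctest leaves out: when `data` is neither a tuple nor a list (a lone array, or a dict of named features), it is treated as `x` and passed through unchanged. Reading the code above, with a hypothetical dict input:

    data = {"a": [1, 2], "b": [3, 4]}  # hypothetical dict of features
    x, y, sample_weight = unpack_x_y_sample_weight(data)
    assert x is data and y is None and sample_weight is None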
-def slice_tf_tensors(arrays, indices, contiguous=True):
-    """Slices batches out of provided arrays (workaround for eager TF tensors).
-
-    Unfortunately eager tensors don't have the same slicing behavior as
-    Numpy arrays (they follow the same slicing behavior as symbolic TF tensors),
-    hence we cannot use `generic_utils.slice_arrays` directly
-    and we have to implement this workaround based on `concat`. This has a
-    performance cost.
+def pack_x_y_sample_weight(x, y=None, sample_weight=None):
+    """Packs user-provided data into a tuple.
+
+    This is a convenience utility for packing data into the tuple formats
+    that `Model.fit` uses.
+
+    Standalone usage:
+
+    >>> x = tf.ones((10, 1))
+    >>> data = pack_x_y_sample_weight(x)
+    >>> isinstance(data, tf.Tensor)
+    True
+    >>> y = tf.ones((10, 1))
+    >>> data = pack_x_y_sample_weight(x, y)
+    >>> isinstance(data, tuple)
+    True
+    >>> x, y = data
 
     Args:
-        arrays: Single array or list of arrays.
-        indices: List of indices in the array that should be included in the
-            output batch.
-        contiguous: Boolean flag indicating whether the indices are contiguous.
+        x: Features to pass to `Model`.
+        y: Ground-truth targets to pass to `Model`.
+        sample_weight: Sample weight for each element.
 
     Returns:
-        Slice of data (either single array or list of arrays).
+        Tuple in the format used in `Model.fit`.
     """
-    converted_to_list = False
-    if not isinstance(arrays, list):
-        converted_to_list = True
-        arrays = [arrays]
-    if any(tf.is_tensor(x) for x in arrays):
-        if not contiguous:
-            entries = [[x[i : i + 1] for i in indices] for x in arrays]
-            slices = [tf.concat(x, axis=0) for x in entries]
-        else:
-            slices = [x[indices[0] : indices[-1] + 1] for x in arrays]
-    else:
-        slices = slice_arrays(arrays, indices)
-
-    if converted_to_list:
-        slices = slices[0]
-    return slices
+    if y is None:
+        # For single x-input, we do no tuple wrapping since in this case
+        # there is no ambiguity. This also makes NumPy and Dataset
+        # consistent in that the user does not have to wrap their Dataset
+        # data in an unnecessary tuple.
+        if not isinstance(x, tuple or list):
+            return x
+        return (x,)
+    elif sample_weight is None:
+        return (x, y)
+    else:
+        return (x, y, sample_weight)
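One Python subtlety worth knowing when reading `pack_x_y_sample_weight`: `tuple or list` is evaluated as an ordinary expression and short-circuits to just `tuple`, so the `isinstance` check never matches lists, and a bare list `x` is returned as-is rather than wrapped as `(x,)`:

    print(tuple or list)                      # <class 'tuple'>
    print(isinstance([1, 2], tuple or list))  # False (lists slip through)
    print(isinstance((1, 2), tuple or list))  # True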
 
 
 def list_to_tuple(maybe_list):
-    """Datasets will stack the list of tensor, so switch them to tuples."""
+    """Datasets will stack any list of tensors, so we convert them to tuples."""
     if isinstance(maybe_list, list):
         return tuple(maybe_list)
     return maybe_list
 
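`list_to_tuple` exists because `tf.data` interprets a list as a single tensor to stack, while a tuple is a structure of separate components; this is also why the tests above assert that a `y` passed as a list comes back as a tuple. A small demonstration of the difference:

    import tensorflow as tf

    a = tf.zeros((4, 2))
    b = tf.ones((4, 2))

    # List: stacked into one (2, 4, 2) tensor, sliced into 2 elements of (4, 2).
    print(tf.data.Dataset.from_tensor_slices([a, b]).element_spec)
    # TensorSpec(shape=(4, 2), dtype=tf.float32, name=None)

    # Tuple: a structure of two components, sliced into 4 paired elements.
    print(tf.data.Dataset.from_tensor_slices((a, b)).element_spec)
    # (TensorSpec(shape=(2,), ...), TensorSpec(shape=(2,), ...))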
 def slice_arrays(arrays, start=None, stop=None):
     """Slice an array or list of arrays.
 
     This takes an array-like, or a list of
     array-likes, and outputs:
     - arrays[start:stop] if `arrays` is an array-like
     - [x[start:stop] for x in arrays] if `arrays` is a list
 
     Can also work on list/array of indices: `slice_arrays(x, indices)`
 
     Args:
         arrays: Single array or list of arrays.
         start: can be an integer index (start index) or a list/array of indices
         stop: integer (stop index); should be None if `start` was a list.
 
     Returns:
         A slice of the array(s).
 
     Raises:
         ValueError: If the value of start is a list and stop is not None.
     """
     if arrays is None:
         return [None]
     if isinstance(start, list) and stop is not None:
         raise ValueError(
             "The stop argument has to be None if the value of start "
             f"is a list. Received start={start}, stop={stop}"
         )
     elif isinstance(arrays, list):
         if hasattr(start, "__len__"):
             # hdf5 datasets only support list objects as indices
             if hasattr(start, "shape"):
                 start = start.tolist()
             return [None if x is None else x[start] for x in arrays]
         return [
             None
             if x is None
             else None
             if not hasattr(x, "__getitem__")
             else x[start:stop]
             for x in arrays
         ]
     else:
         if hasattr(start, "__len__"):
             if hasattr(start, "shape"):
                 start = start.tolist()
             return arrays[start]
         if hasattr(start, "__getitem__"):
             return arrays[start:stop]
         return [None]
 
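`slice_arrays` overloads `start`: it is either an integer start index paired with `stop`, or a list/array of indices used alone (in which case `stop` must stay `None`). A short sketch of both modes on a list of arrays, assuming the `slice_arrays` above is in scope:

    import numpy as np

    x = np.arange(10)

    # Contiguous mode: integer start/stop applied to each array in the list.
    print(slice_arrays([x], 2, 5))       # [array([2, 3, 4])]

    # Index-list mode: start is a list of indices, stop stays None.
    print(slice_arrays([x], [1, 3, 5]))  # [array([1, 3, 5])]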
 def check_data_cardinality(data):
     num_samples = set(int(i.shape[0]) for i in tf.nest.flatten(data))
     if len(num_samples) > 1: