From 729b60650e9697d261a20785197ec1cb59797c06 Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Mon, 17 Apr 2023 08:56:30 -0700
Subject: [PATCH] Add tests for array data adapter.

---
 .../data_adapters/array_data_adapter.py       |  43 ++--
 .../data_adapters/array_data_adapter_test.py  | 132 +++++++++++
 .../data_adapters/data_adapters_utils.py      | 213 ++++++------------
 3 files changed, 224 insertions(+), 164 deletions(-)
 create mode 100644 keras_core/trainers/data_adapters/array_data_adapter_test.py

diff --git a/keras_core/trainers/data_adapters/array_data_adapter.py b/keras_core/trainers/data_adapters/array_data_adapter.py
index 8db0f3781..575922791 100644
--- a/keras_core/trainers/data_adapters/array_data_adapter.py
+++ b/keras_core/trainers/data_adapters/array_data_adapter.py
@@ -4,6 +4,7 @@ import numpy as np
 import tensorflow as tf
 from tensorflow import nest
 
+from keras_core import backend
 from keras_core.trainers.data_adapters import data_adapters_utils
 from keras_core.trainers.data_adapters.data_adapter import DataAdapter
 
@@ -13,14 +14,14 @@
 except ImportError:
     pandas = None
 
-ARRAY_TYPES = [tf.Tensor, np.ndarray]
+ARRAY_TYPES = (tf.Tensor, np.ndarray)
 if pandas:
-    ARRAY_TYPES.extend([tf.Tensor, np.ndarray, pandas.Series, pandas.DataFrame])
+    ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame)
 # TODO: support torch tensors?
 
 
 class ArrayDataAdapter(DataAdapter):
-    """Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy."""
+    """Adapter that handles array-like objects, e.g. tf.Tensor and NumPy arrays."""
 
     @staticmethod
     def can_handle(x, y=None):
@@ -41,15 +47,6 @@ class ArrayDataAdapter(DataAdapter):
         super().__init__(x, y, **kwargs)
         x, y, sample_weights = convert_to_arrays((x, y, sample_weights))
 
-        # If sample_weights are not specified for an output, use 1.0 as weights.
-        (
-            sample_weights,
-            _,
-            _,
-        ) = data_adapters_utils.handle_partial_sample_weights(
-            y, sample_weights, check_all_flat=True
-        )
-
         inputs = data_adapters_utils.pack_x_y_sample_weight(
             x, y, sample_weights
         )
@@ -77,11 +74,11 @@ class ArrayDataAdapter(DataAdapter):
             start, stop = i * self._batch_size, (i + 1) * self._batch_size
             yield tf.nest.map_structure(lambda x: x[start:stop], self._inputs)
 
-    def get_dataset(self):
+    def get_tf_dataset(self):
         ds = tf.data.Dataset.from_tensor_slices(self._inputs)
         ds = ds.shuffle(self._batch_size * 8)
         ds = ds.batch(self._batch_size)
-        ds = ds.preftech(tf.data.AUTOTUNE)
+        ds = ds.prefetch(tf.data.AUTOTUNE)
         return ds
 
     def get_size(self):
@@ -97,7 +94,7 @@
         return self._partial_batch_size or None
 
 
-def convert_to_arrays(arrays):
+def convert_to_arrays(arrays, dtype=None):
     """Process array-like inputs.
 
     This function:
 
@@ -110,15 +107,27 @@
         inputs: Structure of `Tensor`s, `NumPy` arrays, or tensor-like.
 
     Returns:
-        Structure of `ndarray`.
+        Structure of NumPy `ndarray`s.
     """
+    dtype = dtype or backend.floatx()
 
     def convert_single_array(x):
+        if x is None:
+            return x
         if pandas is not None:
             if isinstance(x, pandas.Series):
-                x = np.expand_dims(x.to_numpy(), axis=-1)
+                x = np.expand_dims(x.to_numpy(dtype=dtype), axis=-1)
+            elif isinstance(x, pandas.DataFrame):
+                x = x.to_numpy(dtype=dtype)
         if isinstance(x, (tf.Tensor, tf.Variable)):
             x = x.numpy()
+        if not isinstance(x, np.ndarray):
+            raise ValueError(
+                "Expected a NumPy array, tf.Tensor, pandas DataFrame or "
" + f"Received invalid input: {x} (of type {type(x)})" + ) + if not str(x.dtype) == str(dtype): + x = x.astype(dtype) return x arrays = tf.nest.map_structure(convert_single_array, arrays) diff --git a/keras_core/trainers/data_adapters/array_data_adapter_test.py b/keras_core/trainers/data_adapters/array_data_adapter_test.py new file mode 100644 index 000000000..06a1fb76d --- /dev/null +++ b/keras_core/trainers/data_adapters/array_data_adapter_test.py @@ -0,0 +1,132 @@ +import numpy as np +import tensorflow as tf + +from keras_core import backend +from keras_core import testing +from keras_core.trainers.data_adapters import array_data_adapter + + +class TestArrayDataAdapter(testing.TestCase): + def _test_basic_flow(self, array_type="np"): + if array_type == "np": + x = np.random.random((34, 4)) + y = np.random.random((34, 2)) + elif array_type == "tf": + x = tf.random.normal((34, 4)) + y = tf.random.normal((34, 2)) + elif array_type == "pandas": + # TODO + raise ValueError("TODO") + adapter = array_data_adapter.ArrayDataAdapter( + x, + y=y, + sample_weights=None, + batch_size=16, + steps=None, + shuffle=False, + ) + gen = adapter.get_numpy_iterator() + for i, batch in enumerate(gen): + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertTrue(isinstance(bx, np.ndarray)) + self.assertTrue(isinstance(by, np.ndarray)) + self.assertEqual(bx.dtype, by.dtype) + self.assertEqual(bx.dtype, backend.floatx()) + if i < 2: + self.assertEqual(bx.shape, (16, 4)) + self.assertEqual(by.shape, (16, 2)) + else: + self.assertEqual(bx.shape, (2, 4)) + self.assertEqual(by.shape, (2, 2)) + ds = adapter.get_tf_dataset() + for i, batch in enumerate(ds): + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertTrue(isinstance(bx, tf.Tensor)) + self.assertTrue(isinstance(by, tf.Tensor)) + self.assertEqual(bx.dtype, by.dtype) + self.assertEqual(bx.dtype, backend.floatx()) + if i < 2: + self.assertEqual(tuple(bx.shape), (16, 4)) + self.assertEqual(tuple(by.shape), (16, 2)) + else: + self.assertEqual(tuple(bx.shape), (2, 4)) + self.assertEqual(tuple(by.shape), (2, 2)) + + def test_basic_flow_np(self): + self._test_basic_flow("np") + + def test_basic_flow_tf(self): + self._test_basic_flow("tf") + + def test_multi_inputs_and_outputs(self): + x1 = np.random.random((34, 1)) + x2 = np.random.random((34, 2)) + y1 = np.random.random((34, 3)) + y2 = np.random.random((34, 4)) + sw = np.random.random((34,)) + adapter = array_data_adapter.ArrayDataAdapter( + x={"x1": x1, "x2": x2}, + y=[y1, y2], + sample_weights=sw, + batch_size=16, + steps=None, + shuffle=False, + ) + gen = adapter.get_numpy_iterator() + for i, batch in enumerate(gen): + self.assertEqual(len(batch), 3) + bx, by, bw = batch + self.assertTrue(isinstance(bx, dict)) + # NOTE: the y list was converted to a tuple for tf.data compatibility. 
+            self.assertTrue(isinstance(by, tuple))
+            self.assertTrue(isinstance(bw, np.ndarray))
+
+            self.assertTrue(isinstance(bx["x1"], np.ndarray))
+            self.assertTrue(isinstance(bx["x2"], np.ndarray))
+            self.assertTrue(isinstance(by[0], np.ndarray))
+            self.assertTrue(isinstance(by[1], np.ndarray))
+
+            self.assertEqual(bx["x1"].dtype, by[0].dtype)
+            self.assertEqual(bx["x1"].dtype, backend.floatx())
+            if i < 2:
+                self.assertEqual(bx["x1"].shape, (16, 1))
+                self.assertEqual(bx["x2"].shape, (16, 2))
+                self.assertEqual(by[0].shape, (16, 3))
+                self.assertEqual(by[1].shape, (16, 4))
+                self.assertEqual(bw.shape, (16,))
+            else:
+                self.assertEqual(bx["x1"].shape, (2, 1))
+                self.assertEqual(by[0].shape, (2, 3))
+                self.assertEqual(bw.shape, (2,))
+        ds = adapter.get_tf_dataset()
+        for i, batch in enumerate(ds):
+            self.assertEqual(len(batch), 3)
+            bx, by, bw = batch
+            self.assertTrue(isinstance(bx, dict))
+            # NOTE: the y list was converted to a tuple for tf.data compatibility.
+            self.assertTrue(isinstance(by, tuple))
+            self.assertTrue(isinstance(bw, tf.Tensor))
+
+            self.assertTrue(isinstance(bx["x1"], tf.Tensor))
+            self.assertTrue(isinstance(bx["x2"], tf.Tensor))
+            self.assertTrue(isinstance(by[0], tf.Tensor))
+            self.assertTrue(isinstance(by[1], tf.Tensor))
+
+            self.assertEqual(bx["x1"].dtype, by[0].dtype)
+            self.assertEqual(bx["x1"].dtype, backend.floatx())
+            if i < 2:
+                self.assertEqual(tuple(bx["x1"].shape), (16, 1))
+                self.assertEqual(tuple(bx["x2"].shape), (16, 2))
+                self.assertEqual(tuple(by[0].shape), (16, 3))
+                self.assertEqual(tuple(by[1].shape), (16, 4))
+                self.assertEqual(tuple(bw.shape), (16,))
+            else:
+                self.assertEqual(tuple(bx["x1"].shape), (2, 1))
+                self.assertEqual(tuple(by[0].shape), (2, 3))
+                self.assertEqual(tuple(bw.shape), (2,))
+
+    def test_sample_weights(self):
+        # TODO
+        pass
diff --git a/keras_core/trainers/data_adapters/data_adapters_utils.py b/keras_core/trainers/data_adapters/data_adapters_utils.py
index 889879e67..5b2e10088 100644
--- a/keras_core/trainers/data_adapters/data_adapters_utils.py
+++ b/keras_core/trainers/data_adapters/data_adapters_utils.py
@@ -2,177 +2,96 @@ import numpy as np
 import tensorflow as tf
 
 
-def handle_partial_sample_weights(
-    outputs, sample_weights, sample_weight_modes, check_all_flat=False
-):
-    """Adds 1.0 as sample weights for the outputs for which there is no weight.
+def unpack_x_y_sample_weight(data):
+    """Unpacks user-provided data tuple.
+
+    This is a convenience utility to be used when overriding
+    `Model.train_step`, `Model.test_step`, or `Model.predict_step`.
+    This utility makes it easy to support data of the form `(x,)`,
+    `(x, y)`, or `(x, y, sample_weight)`.
+
+    Standalone usage:
+
+    >>> features_batch = tf.ones((10, 5))
+    >>> labels_batch = tf.zeros((10, 5))
+    >>> data = (features_batch, labels_batch)
+    >>> # `y` and `sample_weight` will default to `None` if not provided.
+    >>> x, y, sample_weight = unpack_x_y_sample_weight(data)
+    >>> sample_weight is None
+    True
 
     Args:
-        outputs: List of model outputs.
-        sample_weights: List of sample weight inputs.
-        sample_weight_modes: List of sample weight modes or None.
-        check_all_flat: Ensure that inputs are not nested structures. This is not
-            a free check, so we may not want to run it eagerly every iteration.
+        data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.
 
     Returns:
-        Tuple of sample weights, one sample weight for every output, and booleans
-        describing the raw sample weights.
+        The unpacked tuple, with `None`s for `y` and `sample_weight` if they are
+        not provided.
""" - if not isinstance(sample_weights, (list, tuple)): - any_sample_weight = sample_weights is not None - partial_sample_weight = any_sample_weight and sample_weights is None - else: - any_sample_weight = sample_weights is not None and any( - w is not None for w in sample_weights - ) - partial_sample_weight = any_sample_weight and any( - w is None for w in sample_weights - ) - - if not any_sample_weight: - return None, any_sample_weight, partial_sample_weight - - if not partial_sample_weight: - return sample_weights, any_sample_weight, partial_sample_weight - - if check_all_flat: - tf.nest.assert_same_structure( - list_to_tuple(sample_weights), - list_to_tuple(tf.nest.flatten(sample_weights)), - ) - tf.nest.assert_same_structure( - list_to_tuple(outputs), list_to_tuple(tf.nest.flatten(outputs)) - ) - if sample_weight_modes is not None: - tf.nest.assert_same_structure( - sample_weight_modes, tf.nest.flatten(sample_weight_modes) - ) - - new_sample_weights = [] - for i, sw in enumerate(sample_weights): - if sw is None: - as_numpy = isinstance(outputs[i], np.ndarray) - output = outputs[i] - output_shape = output.shape if as_numpy else tf.shape(output) - - is_temporal = ( - sample_weight_modes is not None - and sample_weight_modes[i] == "temporal" - ) - sw_shape = ( - (output_shape[0], output_shape[1]) - if is_temporal - else (output_shape[0],) - ) - - new_sample_weights.append( - np.ones(sw_shape) if as_numpy else tf.ones(sw_shape) - ) - - else: - new_sample_weights.append(sw) - return ( - list_to_tuple(new_sample_weights), - any_sample_weight, - partial_sample_weight, + if isinstance(data, list): + data = tuple(data) + if not isinstance(data, tuple): + return (data, None, None) + elif len(data) == 1: + return (data[0], None, None) + elif len(data) == 2: + return (data[0], data[1], None) + elif len(data) == 3: + return (data[0], data[1], data[2]) + error_msg = ( + "Data is expected to be in format `x`, `(x,)`, `(x, y)`, " + f"or `(x, y, sample_weight)`, found: {data}" ) + raise ValueError(error_msg) -def slice_tf_tensors(arrays, indices, contiguous=True): - """Slices batches out of provided arrays (workaround for eager TF tensors). +def pack_x_y_sample_weight(x, y=None, sample_weight=None): + """Packs user-provided data into a tuple. - Unfortunately eager tensors don't have the same slicing behavior as - Numpy arrays (they follow the same slicing behavior as symbolic TF tensors), - hence we cannot use `generic_utils.slice_arrays` directly - and we have to implement this workaround based on `concat`. This has a - performance cost. + This is a convenience utility for packing data into the tuple formats + that `Model.fit` uses. + + Standalone usage: + + >>> x = tf.ones((10, 1)) + >>> data = pack_x_y_sample_weight(x) + >>> isinstance(data, tf.Tensor) + True + >>> y = tf.ones((10, 1)) + >>> data = pack_x_y_sample_weight(x, y) + >>> isinstance(data, tuple) + True + >>> x, y = data Args: - arrays: Single array or list of arrays. - indices: List of indices in the array that should be included in the - output batch. - contiguous: Boolean flag indicating whether the indices are contiguous. + x: Features to pass to `Model`. + y: Ground-truth targets to pass to `Model`. + sample_weight: Sample weight for each element. Returns: - Slice of data (either single array or list of arrays). + Tuple in the format used in `Model.fit`. 
""" - converted_to_list = False - if not isinstance(arrays, list): - converted_to_list = True - arrays = [arrays] - if any(tf.is_tensor(x) for x in arrays): - if not contiguous: - entries = [[x[i : i + 1] for i in indices] for x in arrays] - slices = [tf.concat(x, axis=0) for x in entries] + if y is None: + # For single x-input, we do no tuple wrapping since in this case + # there is no ambiguity. This also makes NumPy and Dataset + # consistent in that the user does not have to wrap their Dataset + # data in an unnecessary tuple. + if not isinstance(x, tuple or list): + return x else: - slices = [x[indices[0] : indices[-1] + 1] for x in arrays] + return (x,) + elif sample_weight is None: + return (x, y) else: - slices = slice_arrays(arrays, indices) - - if converted_to_list: - slices = slices[0] - return slices + return (x, y, sample_weight) def list_to_tuple(maybe_list): - """Datasets will stack the list of tensor, so switch them to tuples.""" + """Datasets will stack any list of tensors, so we convert them to tuples.""" if isinstance(maybe_list, list): return tuple(maybe_list) return maybe_list -def slice_arrays(arrays, start=None, stop=None): - """Slice an array or list of arrays. - - This takes an array-like, or a list of - array-likes, and outputs: - - arrays[start:stop] if `arrays` is an array-like - - [x[start:stop] for x in arrays] if `arrays` is a list - - Can also work on list/array of indices: `slice_arrays(x, indices)` - - Args: - arrays: Single array or list of arrays. - start: can be an integer index (start index) or a list/array of indices - stop: integer (stop index); should be None if `start` was a list. - - Returns: - A slice of the array(s). - - Raises: - ValueError: If the value of start is a list and stop is not None. - """ - if arrays is None: - return [None] - if isinstance(start, list) and stop is not None: - raise ValueError( - "The stop argument has to be None if the value of start " - f"is a list. Received start={start}, stop={stop}" - ) - elif isinstance(arrays, list): - if hasattr(start, "__len__"): - # hdf5 datasets only support list objects as indices - if hasattr(start, "shape"): - start = start.tolist() - return [None if x is None else x[start] for x in arrays] - return [ - None - if x is None - else None - if not hasattr(x, "__getitem__") - else x[start:stop] - for x in arrays - ] - else: - if hasattr(start, "__len__"): - if hasattr(start, "shape"): - start = start.tolist() - return arrays[start] - if hasattr(start, "__getitem__"): - return arrays[start:stop] - return [None] - - def check_data_cardinality(data): num_samples = set(int(i.shape[0]) for i in tf.nest.flatten(data)) if len(num_samples) > 1: