From 3c428eea924ac412e6bdf7f433720ab23e69b299 Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Fri, 28 Jul 2023 09:54:33 -0700
Subject: [PATCH] Docstring nits.

---
 examples/demo_functional.py       | 28 +++++------
 keras_core/utils/dataset_utils.py | 79 +++++++++++--------------------
 2 files changed, 41 insertions(+), 66 deletions(-)

diff --git a/examples/demo_functional.py b/examples/demo_functional.py
index 930130113..bf09bfaff 100644
--- a/examples/demo_functional.py
+++ b/examples/demo_functional.py
@@ -10,19 +10,19 @@ import keras_core as keras
 keras.config.disable_traceback_filtering()
 
 inputs = layers.Input((100,))
-x = layers.Dense(1024, activation="relu")(inputs)
+x = layers.Dense(512, activation="relu")(inputs)
 residual = x
-x = layers.Dense(1024, activation="relu")(x)
-x = layers.Dense(1024, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
 x += residual
-x = layers.Dense(1024, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
 residual = x
-x = layers.Dense(1024, activation="relu")(x)
-x = layers.Dense(1024, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
 x += residual
 residual = x
-x = layers.Dense(1024, activation="relu")(x)
-x = layers.Dense(1024, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
 x += residual
 outputs = layers.Dense(16)(x)
 model = Model(inputs, outputs)
@@ -43,12 +43,12 @@ model.compile(
     ],
 )
 
-print("\nTrain model")
-history = model.fit(
-    x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
-)
-print("\nHistory:")
-print(history.history)
+# print("\nTrain model")
+# history = model.fit(
+#     x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
+# )
+# print("\nHistory:")
+# print(history.history)
 
 print("\nEvaluate model")
 scores = model.evaluate(x, y, return_dict=True)
diff --git a/keras_core/utils/dataset_utils.py b/keras_core/utils/dataset_utils.py
index 7cc5bfb8c..b8b9506d7 100644
--- a/keras_core/utils/dataset_utils.py
+++ b/keras_core/utils/dataset_utils.py
@@ -12,21 +12,22 @@ from keras_core.utils.module_utils import tensorflow as tf
 def split_dataset(
     dataset, left_size=None, right_size=None, shuffle=False, seed=None
 ):
-    """Split a dataset into a left half and a right half (e.g. train / test).
+    """Splits a dataset into a left half and a right half (e.g. train / test).
 
     Args:
         dataset:
-            A `tf.data.Dataset or torch.utils.data.Dataset` object,
+            A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
             or a list/tuple of arrays with the same length.
         left_size: If float (in the range `[0, 1]`), it signifies
             the fraction of the data to pack in the left dataset. If
             integer, it signifies the number of samples to pack in
             the left dataset. If
-            `None`, it uses the complement to `right_size`. Defaults to `None`.
+            `None`, defaults to the complement to `right_size`.
+            Defaults to `None`.
         right_size: If float (in the range `[0, 1]`), it signifies the
             fraction of the data to pack in the right dataset. If integer,
             it signifies the number of samples to pack in the right dataset.
-            If `None`, it uses the complement to `left_size`.
+            If `None`, defaults to the complement to `left_size`.
             Defaults to `None`.
         shuffle: Boolean, whether to shuffle the data before splitting it.
         seed: A random seed for shuffling.
@@ -43,15 +44,14 @@ def split_dataset(
     800
     >>> int(right_ds.cardinality())
     200
-
     """
     dataset_type_spec = _get_type_spec(dataset)
 
     if dataset_type_spec is None:
         raise TypeError(
             "The `dataset` argument must be either"
-            "a `tf.data.Dataset` or `torch.utils.data.Dataset`"
-            "object or a list/tuple of arrays. "
+            "a `tf.data.Dataset`, a `torch.utils.data.Dataset`"
+            "object, or a list/tuple of arrays. "
             f"Received: dataset={dataset} of type {type(dataset)}"
         )
 
@@ -106,23 +106,21 @@ def _convert_dataset_to_list(
     data_size_warning_flag=True,
     ensure_shape_similarity=True,
 ):
-    """Convert `tf.data.Dataset` or `torch.utils.data.Dataset` object
-    or list/tuple of NumPy arrays to a list.
+    """Convert `dataset` object to a list of samples.
 
     Args:
-        dataset :
-            A `tf.data.Dataset` or `torch.utils.data.Dataset` object
+        dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
             or a list/tuple of arrays.
-        dataset_type_spec : the type of the dataset
-        data_size_warning_flag (bool, optional): If set to True, a warning will
+        dataset_type_spec: the type of the dataset.
+        data_size_warning_flag: If set to `True`, a warning will
             be issued if the dataset takes longer than 10 seconds to iterate.
             Defaults to `True`.
-        ensure_shape_similarity (bool, optional): If set to True, the shape of
+        ensure_shape_similarity: If set to `True`, the shape of
             the first sample will be used to validate the shape of rest of
             the samples. Defaults to `True`.
 
     Returns:
-        List: A list of tuples/NumPy arrays.
+        List: A list of samples.
     """
     dataset_iterator = _get_data_iterator_from_dataset(
         dataset, dataset_type_spec
@@ -154,19 +152,9 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
     """Get the iterator from a dataset.
 
     Args:
-        dataset :
-            A `tf.data.Dataset` or `torch.utils.data.Dataset` object
+        dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
             or a list/tuple of arrays.
-        dataset_type_spec :
-            the type of the dataset
-
-    Raises:
-        ValueError:
-            - If the dataset is empty.
-            - If the dataset is not a `tf.data.Dataset` object
-              or a list/tuple of arrays.
-            - If the dataset is a list/tuple of arrays and the
-              length of the list/tuple is not equal to the number
+        dataset_type_spec: The type of the dataset.
 
     Returns:
         iterator: An `iterator` object.
@@ -227,13 +215,12 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
         if is_batched(dataset):
             dataset = dataset.unbatch()
         return iter(dataset)
-    # torch dataset iterator might be required to change
     elif is_torch_dataset(dataset):
         return iter(dataset)
-
     elif dataset_type_spec == np.ndarray:
         return iter(dataset)
+    raise ValueError(f"Invalid dataset_type_spec: {dataset_type_spec}")
 
 
 def _get_next_sample(
@@ -242,28 +229,21 @@
     ensure_shape_similarity,
     data_size_warning_flag,
     start_time,
 ):
-    """ "Yield data samples from the `dataset_iterator`.
+    """Yield data samples from the `dataset_iterator`.
 
     Args:
-        dataset_iterator : An `iterator` object.
-        ensure_shape_similarity (bool, optional): If set to True, the shape of
+        dataset_iterator: An `iterator` object.
+        ensure_shape_similarity: If set to `True`, the shape of
            the first sample will be used to validate the shape of rest of
            the samples. Defaults to `True`.
-        data_size_warning_flag (bool, optional): If set to True, a warning will
+        data_size_warning_flag: If set to `True`, a warning will
            be issued if the dataset takes longer than 10 seconds to iterate.
            Defaults to `True`.
        start_time (float): the start time of the dataset iteration.
            this is used only if `data_size_warning_flag` is set to true.
 
-    Raises:
-        ValueError:
-            - If the dataset is empty.
-            - If `ensure_shape_similarity` is set to True and the
-              shape of the first sample is not equal to the shape of
-              atleast one of the rest of the samples.
-
     Yields:
-        data_sample: A tuple/list of numpy arrays.
+        data_sample: The next sample.
     """
     try:
         dataset_iterator = iter(dataset_iterator)
@@ -278,7 +258,7 @@
             yield first_sample
     except StopIteration:
         raise ValueError(
-            "Received an empty Dataset. `dataset` must "
+            "Received an empty dataset. Argument `dataset` must "
             "be a non-empty list/tuple of `numpy.ndarray` objects "
             "or `tf.data.Dataset` objects."
         )
@@ -337,17 +317,12 @@ def _rescale_dataset_split_sizes(left_size, right_size, total_length):
     the split sizes is equal to the total length of the dataset.
 
     Args:
-        left_size : The size of the left dataset split.
-        right_size : The size of the right dataset split.
-        total_length : The total length of the dataset.
-
-    Raises:
-        TypeError: - If `left_size` or `right_size` is not an integer or float.
-        ValueError: - If `left_size` or `right_size` is negative or greater
-            than 1 or greater than `total_length`.
+        left_size: The size of the left dataset split.
+        right_size: The size of the right dataset split.
+        total_length: The total length of the dataset.
 
     Returns:
-        tuple: A tuple of rescaled left_size and right_size
+        tuple: A tuple of rescaled `left_size` and `right_size` integers.
     """
     left_size_type = type(left_size)
     right_size_type = type(right_size)
@@ -485,7 +460,7 @@ def _restore_dataset_from_list(
 
 
 def is_batched(dataset):
-    """ "Check if the `tf.data.Dataset` is batched."""
+    """Check if the `tf.data.Dataset` is batched."""
     return hasattr(dataset, "_batch_size")
 
 
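
A minimal usage sketch of the `split_dataset` utility whose docstring is touched up above, assuming `keras_core` is installed with the TensorFlow backend available; the 80/20 split mirrors the doctest shown in the docstring:

    import numpy as np
    import keras_core as keras

    # 1000 samples with 4 features each, split 80% / 20% (as in the docstring example).
    data = np.random.random(size=(1000, 4))
    left_ds, right_ds = keras.utils.split_dataset(
        data, left_size=0.8, shuffle=True, seed=42
    )
    print(int(left_ds.cardinality()))   # 800
    print(int(right_ds.cardinality()))  # 200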