Docstring nits.

Francois Chollet 2023-07-28 09:54:33 -07:00
parent a49bb48e02
commit 3c428eea92
2 changed files with 41 additions and 66 deletions

@@ -10,19 +10,19 @@ import keras_core as keras
 keras.config.disable_traceback_filtering()
 inputs = layers.Input((100,))
-x = layers.Dense(1024, activation="relu")(inputs)
+x = layers.Dense(512, activation="relu")(inputs)
 residual = x
-x = layers.Dense(1024, activation="relu")(x)
-x = layers.Dense(1024, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
 x += residual
-x = layers.Dense(1024, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
 residual = x
-x = layers.Dense(1024, activation="relu")(x)
-x = layers.Dense(1024, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
 x += residual
 residual = x
-x = layers.Dense(1024, activation="relu")(x)
-x = layers.Dense(1024, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
 x += residual
 outputs = layers.Dense(16)(x)
 model = Model(inputs, outputs)
@@ -43,12 +43,12 @@ model.compile(
     ],
 )
 
-print("\nTrain model")
-history = model.fit(
-    x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
-)
-print("\nHistory:")
-print(history.history)
+# print("\nTrain model")
+# history = model.fit(
+#     x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
+# )
+# print("\nHistory:")
+# print(history.history)
 
 print("\nEvaluate model")
 scores = model.evaluate(x, y, return_dict=True)

@@ -12,21 +12,22 @@ from keras_core.utils.module_utils import tensorflow as tf
 def split_dataset(
     dataset, left_size=None, right_size=None, shuffle=False, seed=None
 ):
-    """Split a dataset into a left half and a right half (e.g. train / test).
+    """Splits a dataset into a left half and a right half (e.g. train / test).
 
     Args:
         dataset:
-            A `tf.data.Dataset or torch.utils.data.Dataset` object,
+            A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
             or a list/tuple of arrays with the same length.
         left_size: If float (in the range `[0, 1]`), it signifies
             the fraction of the data to pack in the left dataset. If integer, it
             signifies the number of samples to pack in the left dataset. If
-            `None`, it uses the complement to `right_size`. Defaults to `None`.
+            `None`, defaults to the complement to `right_size`.
+            Defaults to `None`.
         right_size: If float (in the range `[0, 1]`), it signifies
             the fraction of the data to pack in the right dataset.
             If integer, it signifies the number of samples to pack
             in the right dataset.
-            If `None`, it uses the complement to `left_size`.
+            If `None`, defaults to the complement to `left_size`.
             Defaults to `None`.
         shuffle: Boolean, whether to shuffle the data before splitting it.
         seed: A random seed for shuffling.
@@ -43,15 +44,14 @@ def split_dataset(
     800
     >>> int(right_ds.cardinality())
     200
     """
     dataset_type_spec = _get_type_spec(dataset)
 
     if dataset_type_spec is None:
         raise TypeError(
             "The `dataset` argument must be either"
-            "a `tf.data.Dataset` or `torch.utils.data.Dataset`"
-            "object or a list/tuple of arrays. "
+            "a `tf.data.Dataset`, a `torch.utils.data.Dataset`"
+            "object, or a list/tuple of arrays. "
             f"Received: dataset={dataset} of type {type(dataset)}"
         )
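As background for the docstring being edited here, a minimal usage sketch of the public `split_dataset` API (not part of this commit; it assumes the function is exported as `keras_core.utils.split_dataset` and that a NumPy array is passed in):

import numpy as np
from keras_core.utils import split_dataset

data = np.random.rand(1000, 4)
# Float sizes are fractions of the data: 800 samples left, 200 right.
left_ds, right_ds = split_dataset(data, left_size=0.8)
# Integer sizes are absolute sample counts; right_size is inferred
# as the complement (900 samples) because it is left as None.
left_ds, right_ds = split_dataset(data, left_size=100)
print(int(left_ds.cardinality()))   # 100
print(int(right_ds.cardinality()))  # 900

Both returned values are `tf.data.Dataset` objects, which is why the doctest above can call `.cardinality()` on them.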
@@ -106,23 +106,21 @@ def _convert_dataset_to_list(
     data_size_warning_flag=True,
     ensure_shape_similarity=True,
 ):
-    """Convert `tf.data.Dataset` or `torch.utils.data.Dataset` object
-    or list/tuple of NumPy arrays to a list.
+    """Convert `dataset` object to a list of samples.
 
     Args:
-        dataset :
-            A `tf.data.Dataset` or `torch.utils.data.Dataset` object
+        dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
             or a list/tuple of arrays.
-        dataset_type_spec : the type of the dataset
-        data_size_warning_flag (bool, optional): If set to True, a warning will
+        dataset_type_spec: the type of the dataset.
+        data_size_warning_flag: If set to `True`, a warning will
             be issued if the dataset takes longer than 10 seconds to iterate.
             Defaults to `True`.
-        ensure_shape_similarity (bool, optional): If set to True, the shape of
+        ensure_shape_similarity: If set to `True`, the shape of
            the first sample will be used to validate the shape of rest of the
            samples. Defaults to `True`.
 
     Returns:
-        List: A list of tuples/NumPy arrays.
+        List: A list of samples.
     """
     dataset_iterator = _get_data_iterator_from_dataset(
         dataset, dataset_type_spec
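A rough sketch (mine, not library code) of what this conversion amounts to in the `tf.data.Dataset` case: a batched dataset is first unbatched, as the diff further down shows, so that iteration yields one sample at a time.

import numpy as np
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(np.arange(8)).batch(2)
if hasattr(ds, "_batch_size"):  # the is_batched() check used below
    ds = ds.unbatch()
samples = list(iter(ds))  # one tensor per original sample
print(len(samples))  # 8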
@@ -154,19 +152,9 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
     """Get the iterator from a dataset.
 
     Args:
-        dataset :
-            A `tf.data.Dataset` or `torch.utils.data.Dataset` object
+        dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
             or a list/tuple of arrays.
-        dataset_type_spec :
-            the type of the dataset
-
-    Raises:
-        ValueError:
-            - If the dataset is empty.
-            - If the dataset is not a `tf.data.Dataset` object
-              or a list/tuple of arrays.
-            - If the dataset is a list/tuple of arrays and the
-              length of the list/tuple is not equal to the number
+        dataset_type_spec: The type of the dataset.
 
     Returns:
         iterator: An `iterator` object.
@@ -227,13 +215,12 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
         if is_batched(dataset):
             dataset = dataset.unbatch()
         return iter(dataset)
     # torch dataset iterator might be required to change
     elif is_torch_dataset(dataset):
         return iter(dataset)
     elif dataset_type_spec == np.ndarray:
         return iter(dataset)
-    raise ValueError(f"Invalid dataset_type_spec: {dataset_type_spec}")
 
 
 def _get_next_sample(
@@ -242,28 +229,21 @@ def _get_next_sample(
     data_size_warning_flag,
     start_time,
 ):
-    """ "Yield data samples from the `dataset_iterator`.
+    """Yield data samples from the `dataset_iterator`.
 
     Args:
-        dataset_iterator : An `iterator` object.
-        ensure_shape_similarity (bool, optional): If set to True, the shape of
+        dataset_iterator: An `iterator` object.
+        ensure_shape_similarity: If set to `True`, the shape of
             the first sample will be used to validate the shape of rest of the
             samples. Defaults to `True`.
-        data_size_warning_flag (bool, optional): If set to True, a warning will
+        data_size_warning_flag: If set to `True`, a warning will
             be issued if the dataset takes longer than 10 seconds to iterate.
             Defaults to `True`.
         start_time (float): the start time of the dataset iteration. this is
             used only if `data_size_warning_flag` is set to true.
 
-    Raises:
-        ValueError:
-            - If the dataset is empty.
-            - If `ensure_shape_similarity` is set to True and the
-              shape of the first sample is not equal to the shape of
-              atleast one of the rest of the samples.
-
     Yields:
-        data_sample: A tuple/list of numpy arrays.
+        data_sample: The next sample.
     """
     try:
         dataset_iterator = iter(dataset_iterator)
@@ -278,7 +258,7 @@
             yield first_sample
     except StopIteration:
         raise ValueError(
-            "Received an empty Dataset. `dataset` must "
+            "Received an empty dataset. Argument `dataset` must "
             "be a non-empty list/tuple of `numpy.ndarray` objects "
             "or `tf.data.Dataset` objects."
         )
@@ -337,17 +317,12 @@ def _rescale_dataset_split_sizes(left_size, right_size, total_length):
     the split sizes is equal to the total length of the dataset.
 
     Args:
-        left_size : The size of the left dataset split.
-        right_size : The size of the right dataset split.
-        total_length : The total length of the dataset.
-
-    Raises:
-        TypeError: - If `left_size` or `right_size` is not an integer or float.
-        ValueError: - If `left_size` or `right_size` is negative or greater
-                      than 1 or greater than `total_length`.
+        left_size: The size of the left dataset split.
+        right_size: The size of the right dataset split.
+        total_length: The total length of the dataset.
 
     Returns:
-        tuple: A tuple of rescaled left_size and right_size
+        tuple: A tuple of rescaled `left_size` and `right_size` integers.
     """
     left_size_type = type(left_size)
     right_size_type = type(right_size)
@@ -485,7 +460,7 @@ def _restore_dataset_from_list(
 def is_batched(dataset):
-    """ "Check if the `tf.data.Dataset` is batched."""
+    """Check if the `tf.data.Dataset` is batched."""
     return hasattr(dataset, "_batch_size")
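One observation on this last helper (mine, not part of the commit): it leans on the private `_batch_size` attribute that TensorFlow's `BatchDataset` happens to carry, so the check behaves roughly like:

import tensorflow as tf

print(hasattr(tf.data.Dataset.range(10).batch(2), "_batch_size"))  # True
print(hasattr(tf.data.Dataset.range(10), "_batch_size"))           # False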