Docstring nits.

Francois Chollet 2023-07-28 09:54:33 -07:00
parent a49bb48e02
commit 3c428eea92
2 changed files with 41 additions and 66 deletions

@@ -10,19 +10,19 @@ import keras_core as keras
 keras.config.disable_traceback_filtering()
 
 inputs = layers.Input((100,))
-x = layers.Dense(1024, activation="relu")(inputs)
+x = layers.Dense(512, activation="relu")(inputs)
 residual = x
-x = layers.Dense(1024, activation="relu")(x)
-x = layers.Dense(1024, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
 x += residual
-x = layers.Dense(1024, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
 residual = x
-x = layers.Dense(1024, activation="relu")(x)
-x = layers.Dense(1024, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
 x += residual
 residual = x
-x = layers.Dense(1024, activation="relu")(x)
-x = layers.Dense(1024, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
+x = layers.Dense(512, activation="relu")(x)
 x += residual
 outputs = layers.Dense(16)(x)
 model = Model(inputs, outputs)
@@ -43,12 +43,12 @@ model.compile(
     ],
 )
 
-print("\nTrain model")
-history = model.fit(
-    x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
-)
-print("\nHistory:")
-print(history.history)
+# print("\nTrain model")
+# history = model.fit(
+#     x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
+# )
+# print("\nHistory:")
+# print(history.history)
 
 print("\nEvaluate model")
 scores = model.evaluate(x, y, return_dict=True)

@@ -12,21 +12,22 @@ from keras_core.utils.module_utils import tensorflow as tf
 def split_dataset(
     dataset, left_size=None, right_size=None, shuffle=False, seed=None
 ):
-    """Split a dataset into a left half and a right half (e.g. train / test).
+    """Splits a dataset into a left half and a right half (e.g. train / test).
 
     Args:
         dataset:
-            A `tf.data.Dataset or torch.utils.data.Dataset` object,
+            A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
             or a list/tuple of arrays with the same length.
         left_size: If float (in the range `[0, 1]`), it signifies
             the fraction of the data to pack in the left dataset. If integer, it
             signifies the number of samples to pack in the left dataset. If
-            `None`, it uses the complement to `right_size`. Defaults to `None`.
+            `None`, defaults to the complement to `right_size`.
+            Defaults to `None`.
         right_size: If float (in the range `[0, 1]`), it signifies
             the fraction of the data to pack in the right dataset.
             If integer, it signifies the number of samples to pack
             in the right dataset.
-            If `None`, it uses the complement to `left_size`.
+            If `None`, defaults to the complement to `left_size`.
             Defaults to `None`.
         shuffle: Boolean, whether to shuffle the data before splitting it.
         seed: A random seed for shuffling.
@@ -43,15 +44,14 @@ def split_dataset(
     800
     >>> int(right_ds.cardinality())
     200
-
     """
     dataset_type_spec = _get_type_spec(dataset)
 
     if dataset_type_spec is None:
         raise TypeError(
             "The `dataset` argument must be either"
-            "a `tf.data.Dataset` or `torch.utils.data.Dataset`"
-            "object or a list/tuple of arrays. "
+            "a `tf.data.Dataset`, a `torch.utils.data.Dataset`"
+            "object, or a list/tuple of arrays. "
             f"Received: dataset={dataset} of type {type(dataset)}"
         )
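
Note: a minimal usage sketch of the semantics the reworded docstring describes, assuming two NumPy arrays of equal length (the 800/200 numbers mirror the doctest above):

    import numpy as np
    from keras_core.utils.dataset_utils import split_dataset

    # 1000 samples; left_size=0.8 packs 80% into the left split, and
    # right_size, left as None, defaults to the complement (20%).
    data = np.random.random((1000, 4))
    labels = np.random.random((1000, 1))
    left_ds, right_ds = split_dataset((data, labels), left_size=0.8)
    print(int(left_ds.cardinality()))   # 800
    print(int(right_ds.cardinality()))  # 200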
@@ -106,23 +106,21 @@ def _convert_dataset_to_list(
     data_size_warning_flag=True,
     ensure_shape_similarity=True,
 ):
-    """Convert `tf.data.Dataset` or `torch.utils.data.Dataset` object
-    or list/tuple of NumPy arrays to a list.
+    """Convert `dataset` object to a list of samples.
 
     Args:
-        dataset :
-            A `tf.data.Dataset` or `torch.utils.data.Dataset` object
+        dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
             or a list/tuple of arrays.
-        dataset_type_spec : the type of the dataset
-        data_size_warning_flag (bool, optional): If set to True, a warning will
+        dataset_type_spec: the type of the dataset.
+        data_size_warning_flag: If set to `True`, a warning will
             be issued if the dataset takes longer than 10 seconds to iterate.
             Defaults to `True`.
-        ensure_shape_similarity (bool, optional): If set to True, the shape of
+        ensure_shape_similarity: If set to `True`, the shape of
             the first sample will be used to validate the shape of rest of the
             samples. Defaults to `True`.
 
     Returns:
-        List: A list of tuples/NumPy arrays.
+        List: A list of samples.
     """
     dataset_iterator = _get_data_iterator_from_dataset(
         dataset, dataset_type_spec
@@ -154,19 +152,9 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
     """Get the iterator from a dataset.
 
     Args:
-        dataset :
-            A `tf.data.Dataset` or `torch.utils.data.Dataset` object
+        dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
             or a list/tuple of arrays.
-        dataset_type_spec :
-            the type of the dataset
-
-    Raises:
-        ValueError:
-            - If the dataset is empty.
-            - If the dataset is not a `tf.data.Dataset` object
-              or a list/tuple of arrays.
-            - If the dataset is a list/tuple of arrays and the
-              length of the list/tuple is not equal to the number
+        dataset_type_spec: The type of the dataset.
 
     Returns:
         iterator: An `iterator` object.
@@ -227,13 +215,12 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
         if is_batched(dataset):
             dataset = dataset.unbatch()
         return iter(dataset)
-    # torch dataset iterator might be required to change
     elif is_torch_dataset(dataset):
         return iter(dataset)
     elif dataset_type_spec == np.ndarray:
         return iter(dataset)
 
     raise ValueError(f"Invalid dataset_type_spec: {dataset_type_spec}")
 
 
 def _get_next_sample(
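
Note: a small sketch of the behavior the dispatch above relies on, assuming TensorFlow is installed. A batched `tf.data.Dataset` carries the private `_batch_size` attribute that `is_batched()` tests (see the last hunk below), and `unbatch()` restores per-sample iteration:

    import tensorflow as tf

    ds = tf.data.Dataset.range(10).batch(5)
    print(hasattr(ds, "_batch_size"))  # True for a batched dataset
    for sample in ds.unbatch():        # yields the 10 elements one at a time
        print(int(sample))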
@@ -242,28 +229,21 @@ def _get_next_sample(
     data_size_warning_flag,
     start_time,
 ):
-    """ "Yield data samples from the `dataset_iterator`.
+    """Yield data samples from the `dataset_iterator`.
 
     Args:
         dataset_iterator: An `iterator` object.
-        ensure_shape_similarity (bool, optional): If set to True, the shape of
+        ensure_shape_similarity: If set to `True`, the shape of
             the first sample will be used to validate the shape of rest of the
             samples. Defaults to `True`.
-        data_size_warning_flag (bool, optional): If set to True, a warning will
+        data_size_warning_flag: If set to `True`, a warning will
             be issued if the dataset takes longer than 10 seconds to iterate.
             Defaults to `True`.
         start_time (float): the start time of the dataset iteration. this is
             used only if `data_size_warning_flag` is set to true.
 
-    Raises:
-        ValueError:
-            - If the dataset is empty.
-            - If `ensure_shape_similarity` is set to True and the
-              shape of the first sample is not equal to the shape of
-              atleast one of the rest of the samples.
-
     Yields:
-        data_sample: A tuple/list of numpy arrays.
+        data_sample: The next sample.
     """
     try:
         dataset_iterator = iter(dataset_iterator)
@@ -278,7 +258,7 @@ def _get_next_sample(
         yield first_sample
     except StopIteration:
         raise ValueError(
-            "Received an empty Dataset. `dataset` must "
+            "Received an empty dataset. Argument `dataset` must "
             "be a non-empty list/tuple of `numpy.ndarray` objects "
             "or `tf.data.Dataset` objects."
         )
@@ -341,13 +321,8 @@ def _rescale_dataset_split_sizes(left_size, right_size, total_length):
         right_size: The size of the right dataset split.
         total_length: The total length of the dataset.
 
-    Raises:
-        TypeError: - If `left_size` or `right_size` is not an integer or float.
-        ValueError: - If `left_size` or `right_size` is negative or greater
-            than 1 or greater than `total_length`.
-
     Returns:
-        tuple: A tuple of rescaled left_size and right_size
+        tuple: A tuple of rescaled `left_size` and `right_size` integers.
     """
     left_size_type = type(left_size)
     right_size_type = type(right_size)
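
Note: a rough illustration of the contract the trimmed docstring keeps, with hypothetical values (this is not the function body):

    # A float left_size is rescaled against the dataset length, and a None
    # right_size becomes the complement, so both come back as integers.
    total_length = 1000
    left_size = round(0.8 * total_length)  # 800
    right_size = total_length - left_size  # 200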
@@ -485,7 +460,7 @@ def _restore_dataset_from_list(
 def is_batched(dataset):
-    """ "Check if the `tf.data.Dataset` is batched."""
+    """Check if the `tf.data.Dataset` is batched."""
     return hasattr(dataset, "_batch_size")