Docstring nits.
This commit is contained in:
parent
a49bb48e02
commit
3c428eea92
@ -10,19 +10,19 @@ import keras_core as keras
|
||||
keras.config.disable_traceback_filtering()
|
||||
|
||||
inputs = layers.Input((100,))
|
||||
x = layers.Dense(1024, activation="relu")(inputs)
|
||||
x = layers.Dense(512, activation="relu")(inputs)
|
||||
residual = x
|
||||
x = layers.Dense(1024, activation="relu")(x)
|
||||
x = layers.Dense(1024, activation="relu")(x)
|
||||
x = layers.Dense(512, activation="relu")(x)
|
||||
x = layers.Dense(512, activation="relu")(x)
|
||||
x += residual
|
||||
x = layers.Dense(1024, activation="relu")(x)
|
||||
x = layers.Dense(512, activation="relu")(x)
|
||||
residual = x
|
||||
x = layers.Dense(1024, activation="relu")(x)
|
||||
x = layers.Dense(1024, activation="relu")(x)
|
||||
x = layers.Dense(512, activation="relu")(x)
|
||||
x = layers.Dense(512, activation="relu")(x)
|
||||
x += residual
|
||||
residual = x
|
||||
x = layers.Dense(1024, activation="relu")(x)
|
||||
x = layers.Dense(1024, activation="relu")(x)
|
||||
x = layers.Dense(512, activation="relu")(x)
|
||||
x = layers.Dense(512, activation="relu")(x)
|
||||
x += residual
|
||||
outputs = layers.Dense(16)(x)
|
||||
model = Model(inputs, outputs)
|
||||
@ -43,12 +43,12 @@ model.compile(
|
||||
],
|
||||
)
|
||||
|
||||
print("\nTrain model")
|
||||
history = model.fit(
|
||||
x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
|
||||
)
|
||||
print("\nHistory:")
|
||||
print(history.history)
|
||||
# print("\nTrain model")
|
||||
# history = model.fit(
|
||||
# x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
|
||||
# )
|
||||
# print("\nHistory:")
|
||||
# print(history.history)
|
||||
|
||||
print("\nEvaluate model")
|
||||
scores = model.evaluate(x, y, return_dict=True)
|
||||
|
@ -12,21 +12,22 @@ from keras_core.utils.module_utils import tensorflow as tf
|
||||
def split_dataset(
|
||||
dataset, left_size=None, right_size=None, shuffle=False, seed=None
|
||||
):
|
||||
"""Split a dataset into a left half and a right half (e.g. train / test).
|
||||
"""Splits a dataset into a left half and a right half (e.g. train / test).
|
||||
|
||||
Args:
|
||||
dataset:
|
||||
A `tf.data.Dataset or torch.utils.data.Dataset` object,
|
||||
A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
|
||||
or a list/tuple of arrays with the same length.
|
||||
left_size: If float (in the range `[0, 1]`), it signifies
|
||||
the fraction of the data to pack in the left dataset. If integer, it
|
||||
signifies the number of samples to pack in the left dataset. If
|
||||
`None`, it uses the complement to `right_size`. Defaults to `None`.
|
||||
`None`, defaults to the complement to `right_size`.
|
||||
Defaults to `None`.
|
||||
right_size: If float (in the range `[0, 1]`), it signifies
|
||||
the fraction of the data to pack in the right dataset.
|
||||
If integer, it signifies the number of samples to pack
|
||||
in the right dataset.
|
||||
If `None`, it uses the complement to `left_size`.
|
||||
If `None`, defaults to the complement to `left_size`.
|
||||
Defaults to `None`.
|
||||
shuffle: Boolean, whether to shuffle the data before splitting it.
|
||||
seed: A random seed for shuffling.
|
||||
@ -43,15 +44,14 @@ def split_dataset(
|
||||
800
|
||||
>>> int(right_ds.cardinality())
|
||||
200
|
||||
|
||||
"""
|
||||
dataset_type_spec = _get_type_spec(dataset)
|
||||
|
||||
if dataset_type_spec is None:
|
||||
raise TypeError(
|
||||
"The `dataset` argument must be either"
|
||||
"a `tf.data.Dataset` or `torch.utils.data.Dataset`"
|
||||
"object or a list/tuple of arrays. "
|
||||
"a `tf.data.Dataset`, a `torch.utils.data.Dataset`"
|
||||
"object, or a list/tuple of arrays. "
|
||||
f"Received: dataset={dataset} of type {type(dataset)}"
|
||||
)
|
||||
|
||||
@ -106,23 +106,21 @@ def _convert_dataset_to_list(
|
||||
data_size_warning_flag=True,
|
||||
ensure_shape_similarity=True,
|
||||
):
|
||||
"""Convert `tf.data.Dataset` or `torch.utils.data.Dataset` object
|
||||
or list/tuple of NumPy arrays to a list.
|
||||
"""Convert `dataset` object to a list of samples.
|
||||
|
||||
Args:
|
||||
dataset :
|
||||
A `tf.data.Dataset` or `torch.utils.data.Dataset` object
|
||||
dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
|
||||
or a list/tuple of arrays.
|
||||
dataset_type_spec : the type of the dataset
|
||||
data_size_warning_flag (bool, optional): If set to True, a warning will
|
||||
dataset_type_spec: the type of the dataset.
|
||||
data_size_warning_flag: If set to `True`, a warning will
|
||||
be issued if the dataset takes longer than 10 seconds to iterate.
|
||||
Defaults to `True`.
|
||||
ensure_shape_similarity (bool, optional): If set to True, the shape of
|
||||
ensure_shape_similarity: If set to `True`, the shape of
|
||||
the first sample will be used to validate the shape of rest of the
|
||||
samples. Defaults to `True`.
|
||||
|
||||
Returns:
|
||||
List: A list of tuples/NumPy arrays.
|
||||
List: A list of samples.
|
||||
"""
|
||||
dataset_iterator = _get_data_iterator_from_dataset(
|
||||
dataset, dataset_type_spec
|
||||
@ -154,19 +152,9 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
|
||||
"""Get the iterator from a dataset.
|
||||
|
||||
Args:
|
||||
dataset :
|
||||
A `tf.data.Dataset` or `torch.utils.data.Dataset` object
|
||||
dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
|
||||
or a list/tuple of arrays.
|
||||
dataset_type_spec :
|
||||
the type of the dataset
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
- If the dataset is empty.
|
||||
- If the dataset is not a `tf.data.Dataset` object
|
||||
or a list/tuple of arrays.
|
||||
- If the dataset is a list/tuple of arrays and the
|
||||
length of the list/tuple is not equal to the number
|
||||
dataset_type_spec: The type of the dataset.
|
||||
|
||||
Returns:
|
||||
iterator: An `iterator` object.
|
||||
@ -227,13 +215,12 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
|
||||
if is_batched(dataset):
|
||||
dataset = dataset.unbatch()
|
||||
return iter(dataset)
|
||||
|
||||
# torch dataset iterator might be required to change
|
||||
elif is_torch_dataset(dataset):
|
||||
return iter(dataset)
|
||||
|
||||
elif dataset_type_spec == np.ndarray:
|
||||
return iter(dataset)
|
||||
raise ValueError(f"Invalid dataset_type_spec: {dataset_type_spec}")
|
||||
|
||||
|
||||
def _get_next_sample(
|
||||
@ -242,28 +229,21 @@ def _get_next_sample(
|
||||
data_size_warning_flag,
|
||||
start_time,
|
||||
):
|
||||
""" "Yield data samples from the `dataset_iterator`.
|
||||
"""Yield data samples from the `dataset_iterator`.
|
||||
|
||||
Args:
|
||||
dataset_iterator : An `iterator` object.
|
||||
ensure_shape_similarity (bool, optional): If set to True, the shape of
|
||||
dataset_iterator: An `iterator` object.
|
||||
ensure_shape_similarity: If set to `True`, the shape of
|
||||
the first sample will be used to validate the shape of rest of the
|
||||
samples. Defaults to `True`.
|
||||
data_size_warning_flag (bool, optional): If set to True, a warning will
|
||||
data_size_warning_flag: If set to `True`, a warning will
|
||||
be issued if the dataset takes longer than 10 seconds to iterate.
|
||||
Defaults to `True`.
|
||||
start_time (float): the start time of the dataset iteration. this is
|
||||
used only if `data_size_warning_flag` is set to true.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
- If the dataset is empty.
|
||||
- If `ensure_shape_similarity` is set to True and the
|
||||
shape of the first sample is not equal to the shape of
|
||||
atleast one of the rest of the samples.
|
||||
|
||||
Yields:
|
||||
data_sample: A tuple/list of numpy arrays.
|
||||
data_sample: The next sample.
|
||||
"""
|
||||
try:
|
||||
dataset_iterator = iter(dataset_iterator)
|
||||
@ -278,7 +258,7 @@ def _get_next_sample(
|
||||
yield first_sample
|
||||
except StopIteration:
|
||||
raise ValueError(
|
||||
"Received an empty Dataset. `dataset` must "
|
||||
"Received an empty dataset. Argument `dataset` must "
|
||||
"be a non-empty list/tuple of `numpy.ndarray` objects "
|
||||
"or `tf.data.Dataset` objects."
|
||||
)
|
||||
@ -337,17 +317,12 @@ def _rescale_dataset_split_sizes(left_size, right_size, total_length):
|
||||
the split sizes is equal to the total length of the dataset.
|
||||
|
||||
Args:
|
||||
left_size : The size of the left dataset split.
|
||||
right_size : The size of the right dataset split.
|
||||
total_length : The total length of the dataset.
|
||||
|
||||
Raises:
|
||||
TypeError: - If `left_size` or `right_size` is not an integer or float.
|
||||
ValueError: - If `left_size` or `right_size` is negative or greater
|
||||
than 1 or greater than `total_length`.
|
||||
left_size: The size of the left dataset split.
|
||||
right_size: The size of the right dataset split.
|
||||
total_length: The total length of the dataset.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple of rescaled left_size and right_size
|
||||
tuple: A tuple of rescaled `left_size` and `right_size` integers.
|
||||
"""
|
||||
left_size_type = type(left_size)
|
||||
right_size_type = type(right_size)
|
||||
@ -485,7 +460,7 @@ def _restore_dataset_from_list(
|
||||
|
||||
|
||||
def is_batched(dataset):
|
||||
""" "Check if the `tf.data.Dataset` is batched."""
|
||||
"""Check if the `tf.data.Dataset` is batched."""
|
||||
return hasattr(dataset, "_batch_size")
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user