Docstring nits.
This commit is contained in:
parent
a49bb48e02
commit
3c428eea92
@ -10,19 +10,19 @@ import keras_core as keras
|
|||||||
keras.config.disable_traceback_filtering()
|
keras.config.disable_traceback_filtering()
|
||||||
|
|
||||||
inputs = layers.Input((100,))
|
inputs = layers.Input((100,))
|
||||||
x = layers.Dense(1024, activation="relu")(inputs)
|
x = layers.Dense(512, activation="relu")(inputs)
|
||||||
residual = x
|
residual = x
|
||||||
x = layers.Dense(1024, activation="relu")(x)
|
x = layers.Dense(512, activation="relu")(x)
|
||||||
x = layers.Dense(1024, activation="relu")(x)
|
x = layers.Dense(512, activation="relu")(x)
|
||||||
x += residual
|
x += residual
|
||||||
x = layers.Dense(1024, activation="relu")(x)
|
x = layers.Dense(512, activation="relu")(x)
|
||||||
residual = x
|
residual = x
|
||||||
x = layers.Dense(1024, activation="relu")(x)
|
x = layers.Dense(512, activation="relu")(x)
|
||||||
x = layers.Dense(1024, activation="relu")(x)
|
x = layers.Dense(512, activation="relu")(x)
|
||||||
x += residual
|
x += residual
|
||||||
residual = x
|
residual = x
|
||||||
x = layers.Dense(1024, activation="relu")(x)
|
x = layers.Dense(512, activation="relu")(x)
|
||||||
x = layers.Dense(1024, activation="relu")(x)
|
x = layers.Dense(512, activation="relu")(x)
|
||||||
x += residual
|
x += residual
|
||||||
outputs = layers.Dense(16)(x)
|
outputs = layers.Dense(16)(x)
|
||||||
model = Model(inputs, outputs)
|
model = Model(inputs, outputs)
|
||||||
@ -43,12 +43,12 @@ model.compile(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
print("\nTrain model")
|
# print("\nTrain model")
|
||||||
history = model.fit(
|
# history = model.fit(
|
||||||
x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
|
# x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
|
||||||
)
|
# )
|
||||||
print("\nHistory:")
|
# print("\nHistory:")
|
||||||
print(history.history)
|
# print(history.history)
|
||||||
|
|
||||||
print("\nEvaluate model")
|
print("\nEvaluate model")
|
||||||
scores = model.evaluate(x, y, return_dict=True)
|
scores = model.evaluate(x, y, return_dict=True)
|
||||||
|
@ -12,21 +12,22 @@ from keras_core.utils.module_utils import tensorflow as tf
|
|||||||
def split_dataset(
|
def split_dataset(
|
||||||
dataset, left_size=None, right_size=None, shuffle=False, seed=None
|
dataset, left_size=None, right_size=None, shuffle=False, seed=None
|
||||||
):
|
):
|
||||||
"""Split a dataset into a left half and a right half (e.g. train / test).
|
"""Splits a dataset into a left half and a right half (e.g. train / test).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset:
|
dataset:
|
||||||
A `tf.data.Dataset or torch.utils.data.Dataset` object,
|
A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
|
||||||
or a list/tuple of arrays with the same length.
|
or a list/tuple of arrays with the same length.
|
||||||
left_size: If float (in the range `[0, 1]`), it signifies
|
left_size: If float (in the range `[0, 1]`), it signifies
|
||||||
the fraction of the data to pack in the left dataset. If integer, it
|
the fraction of the data to pack in the left dataset. If integer, it
|
||||||
signifies the number of samples to pack in the left dataset. If
|
signifies the number of samples to pack in the left dataset. If
|
||||||
`None`, it uses the complement to `right_size`. Defaults to `None`.
|
`None`, defaults to the complement to `right_size`.
|
||||||
|
Defaults to `None`.
|
||||||
right_size: If float (in the range `[0, 1]`), it signifies
|
right_size: If float (in the range `[0, 1]`), it signifies
|
||||||
the fraction of the data to pack in the right dataset.
|
the fraction of the data to pack in the right dataset.
|
||||||
If integer, it signifies the number of samples to pack
|
If integer, it signifies the number of samples to pack
|
||||||
in the right dataset.
|
in the right dataset.
|
||||||
If `None`, it uses the complement to `left_size`.
|
If `None`, defaults to the complement to `left_size`.
|
||||||
Defaults to `None`.
|
Defaults to `None`.
|
||||||
shuffle: Boolean, whether to shuffle the data before splitting it.
|
shuffle: Boolean, whether to shuffle the data before splitting it.
|
||||||
seed: A random seed for shuffling.
|
seed: A random seed for shuffling.
|
||||||
@ -43,15 +44,14 @@ def split_dataset(
|
|||||||
800
|
800
|
||||||
>>> int(right_ds.cardinality())
|
>>> int(right_ds.cardinality())
|
||||||
200
|
200
|
||||||
|
|
||||||
"""
|
"""
|
||||||
dataset_type_spec = _get_type_spec(dataset)
|
dataset_type_spec = _get_type_spec(dataset)
|
||||||
|
|
||||||
if dataset_type_spec is None:
|
if dataset_type_spec is None:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"The `dataset` argument must be either"
|
"The `dataset` argument must be either"
|
||||||
"a `tf.data.Dataset` or `torch.utils.data.Dataset`"
|
"a `tf.data.Dataset`, a `torch.utils.data.Dataset`"
|
||||||
"object or a list/tuple of arrays. "
|
"object, or a list/tuple of arrays. "
|
||||||
f"Received: dataset={dataset} of type {type(dataset)}"
|
f"Received: dataset={dataset} of type {type(dataset)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -106,23 +106,21 @@ def _convert_dataset_to_list(
|
|||||||
data_size_warning_flag=True,
|
data_size_warning_flag=True,
|
||||||
ensure_shape_similarity=True,
|
ensure_shape_similarity=True,
|
||||||
):
|
):
|
||||||
"""Convert `tf.data.Dataset` or `torch.utils.data.Dataset` object
|
"""Convert `dataset` object to a list of samples.
|
||||||
or list/tuple of NumPy arrays to a list.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset :
|
dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
|
||||||
A `tf.data.Dataset` or `torch.utils.data.Dataset` object
|
|
||||||
or a list/tuple of arrays.
|
or a list/tuple of arrays.
|
||||||
dataset_type_spec : the type of the dataset
|
dataset_type_spec: the type of the dataset.
|
||||||
data_size_warning_flag (bool, optional): If set to True, a warning will
|
data_size_warning_flag: If set to `True`, a warning will
|
||||||
be issued if the dataset takes longer than 10 seconds to iterate.
|
be issued if the dataset takes longer than 10 seconds to iterate.
|
||||||
Defaults to `True`.
|
Defaults to `True`.
|
||||||
ensure_shape_similarity (bool, optional): If set to True, the shape of
|
ensure_shape_similarity: If set to `True`, the shape of
|
||||||
the first sample will be used to validate the shape of rest of the
|
the first sample will be used to validate the shape of rest of the
|
||||||
samples. Defaults to `True`.
|
samples. Defaults to `True`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List: A list of tuples/NumPy arrays.
|
List: A list of samples.
|
||||||
"""
|
"""
|
||||||
dataset_iterator = _get_data_iterator_from_dataset(
|
dataset_iterator = _get_data_iterator_from_dataset(
|
||||||
dataset, dataset_type_spec
|
dataset, dataset_type_spec
|
||||||
@ -154,19 +152,9 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
|
|||||||
"""Get the iterator from a dataset.
|
"""Get the iterator from a dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset :
|
dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
|
||||||
A `tf.data.Dataset` or `torch.utils.data.Dataset` object
|
|
||||||
or a list/tuple of arrays.
|
or a list/tuple of arrays.
|
||||||
dataset_type_spec :
|
dataset_type_spec: The type of the dataset.
|
||||||
the type of the dataset
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError:
|
|
||||||
- If the dataset is empty.
|
|
||||||
- If the dataset is not a `tf.data.Dataset` object
|
|
||||||
or a list/tuple of arrays.
|
|
||||||
- If the dataset is a list/tuple of arrays and the
|
|
||||||
length of the list/tuple is not equal to the number
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
iterator: An `iterator` object.
|
iterator: An `iterator` object.
|
||||||
@ -227,13 +215,12 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
|
|||||||
if is_batched(dataset):
|
if is_batched(dataset):
|
||||||
dataset = dataset.unbatch()
|
dataset = dataset.unbatch()
|
||||||
return iter(dataset)
|
return iter(dataset)
|
||||||
|
|
||||||
# torch dataset iterator might be required to change
|
# torch dataset iterator might be required to change
|
||||||
elif is_torch_dataset(dataset):
|
elif is_torch_dataset(dataset):
|
||||||
return iter(dataset)
|
return iter(dataset)
|
||||||
|
|
||||||
elif dataset_type_spec == np.ndarray:
|
elif dataset_type_spec == np.ndarray:
|
||||||
return iter(dataset)
|
return iter(dataset)
|
||||||
|
raise ValueError(f"Invalid dataset_type_spec: {dataset_type_spec}")
|
||||||
|
|
||||||
|
|
||||||
def _get_next_sample(
|
def _get_next_sample(
|
||||||
@ -242,28 +229,21 @@ def _get_next_sample(
|
|||||||
data_size_warning_flag,
|
data_size_warning_flag,
|
||||||
start_time,
|
start_time,
|
||||||
):
|
):
|
||||||
""" "Yield data samples from the `dataset_iterator`.
|
"""Yield data samples from the `dataset_iterator`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_iterator : An `iterator` object.
|
dataset_iterator: An `iterator` object.
|
||||||
ensure_shape_similarity (bool, optional): If set to True, the shape of
|
ensure_shape_similarity: If set to `True`, the shape of
|
||||||
the first sample will be used to validate the shape of rest of the
|
the first sample will be used to validate the shape of rest of the
|
||||||
samples. Defaults to `True`.
|
samples. Defaults to `True`.
|
||||||
data_size_warning_flag (bool, optional): If set to True, a warning will
|
data_size_warning_flag: If set to `True`, a warning will
|
||||||
be issued if the dataset takes longer than 10 seconds to iterate.
|
be issued if the dataset takes longer than 10 seconds to iterate.
|
||||||
Defaults to `True`.
|
Defaults to `True`.
|
||||||
start_time (float): the start time of the dataset iteration. this is
|
start_time (float): the start time of the dataset iteration. this is
|
||||||
used only if `data_size_warning_flag` is set to true.
|
used only if `data_size_warning_flag` is set to true.
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError:
|
|
||||||
- If the dataset is empty.
|
|
||||||
- If `ensure_shape_similarity` is set to True and the
|
|
||||||
shape of the first sample is not equal to the shape of
|
|
||||||
atleast one of the rest of the samples.
|
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
data_sample: A tuple/list of numpy arrays.
|
data_sample: The next sample.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
dataset_iterator = iter(dataset_iterator)
|
dataset_iterator = iter(dataset_iterator)
|
||||||
@ -278,7 +258,7 @@ def _get_next_sample(
|
|||||||
yield first_sample
|
yield first_sample
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Received an empty Dataset. `dataset` must "
|
"Received an empty dataset. Argument `dataset` must "
|
||||||
"be a non-empty list/tuple of `numpy.ndarray` objects "
|
"be a non-empty list/tuple of `numpy.ndarray` objects "
|
||||||
"or `tf.data.Dataset` objects."
|
"or `tf.data.Dataset` objects."
|
||||||
)
|
)
|
||||||
@ -337,17 +317,12 @@ def _rescale_dataset_split_sizes(left_size, right_size, total_length):
|
|||||||
the split sizes is equal to the total length of the dataset.
|
the split sizes is equal to the total length of the dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
left_size : The size of the left dataset split.
|
left_size: The size of the left dataset split.
|
||||||
right_size : The size of the right dataset split.
|
right_size: The size of the right dataset split.
|
||||||
total_length : The total length of the dataset.
|
total_length: The total length of the dataset.
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: - If `left_size` or `right_size` is not an integer or float.
|
|
||||||
ValueError: - If `left_size` or `right_size` is negative or greater
|
|
||||||
than 1 or greater than `total_length`.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: A tuple of rescaled left_size and right_size
|
tuple: A tuple of rescaled `left_size` and `right_size` integers.
|
||||||
"""
|
"""
|
||||||
left_size_type = type(left_size)
|
left_size_type = type(left_size)
|
||||||
right_size_type = type(right_size)
|
right_size_type = type(right_size)
|
||||||
@ -485,7 +460,7 @@ def _restore_dataset_from_list(
|
|||||||
|
|
||||||
|
|
||||||
def is_batched(dataset):
|
def is_batched(dataset):
|
||||||
""" "Check if the `tf.data.Dataset` is batched."""
|
"""Check if the `tf.data.Dataset` is batched."""
|
||||||
return hasattr(dataset, "_batch_size")
|
return hasattr(dataset, "_batch_size")
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user