# keras/keras_core/utils/argument_validation.py
# 2023-05-17 16:06:01 -07:00
#
# 71 lines
# 2.2 KiB
# Python

def _is_integer_like(x):
    """Returns True if `x` can be interpreted as an integer.

    Strings and bytes are rejected even though `int()` can parse them: the
    bounds check in `standardize_tuple` compares each element with 0, which
    would raise a `TypeError` for them. `OverflowError` is caught so that
    infinite floats are rejected instead of crashing.
    """
    if isinstance(x, (str, bytes, bytearray)):
        return False
    try:
        int(x)
    except (ValueError, TypeError, OverflowError):
        return False
    return True


def standardize_tuple(value, n, name, allow_zero=False):
    """Transforms non-negative/positive integer/integers into an integer tuple.

    Args:
        value: int or iterable of ints. The value to validate and convert.
        n: int. The size of the tuple to be returned.
        name: string. The name of the argument being validated, e.g. "strides"
            or "kernel_size". This is only used to format error messages.
        allow_zero: bool, defaults to False. A `ValueError` will be raised if
            zero is received and this param is False.

    Returns:
        A tuple of `n` values. An int `value` is repeated `n` times; the
        elements of an iterable `value` are returned as given (they are not
        cast to `int`).

    Raises:
        ValueError: if `value` is neither an int nor an iterable of length
            `n`, if an element cannot be interpreted as an integer, or if a
            value is negative (or zero, when `allow_zero` is False).
    """
    error_msg = (
        f"The `{name}` argument must be a tuple of {n} integers. "
        f"Received {name}={value}"
    )
    if isinstance(value, int):
        value_tuple = (value,) * n
    else:
        try:
            value_tuple = tuple(value)
        except TypeError:
            # Not iterable: report as an invalid argument, not a TypeError.
            raise ValueError(error_msg) from None
        if len(value_tuple) != n:
            raise ValueError(error_msg)
        for single_value in value_tuple:
            if not _is_integer_like(single_value):
                raise ValueError(
                    f"{error_msg}, including element {single_value} of "
                    f"type {type(single_value)}"
                )

    if allow_zero:
        unqualified_values = {v for v in value_tuple if v < 0}
        req_msg = ">= 0"
    else:
        unqualified_values = {v for v in value_tuple if v <= 0}
        req_msg = "> 0"
    if unqualified_values:
        raise ValueError(
            f"{error_msg}, including values {unqualified_values}"
            f" that do not satisfy `value {req_msg}`"
        )
    return value_tuple
def standardize_padding(value, allow_causal=False):
    """Validates and normalizes a `padding` argument.

    Args:
        value: A list/tuple, which is returned unchanged without further
            validation, or a string naming a padding mode. Strings are
            matched case-insensitively.
        allow_causal: bool, defaults to False. Whether `"causal"` is accepted
            in addition to `"valid"` and `"same"`.

    Returns:
        `value` itself if it is a list/tuple, otherwise the lowercased
        padding mode string.

    Raises:
        ValueError: if `value` is neither a list/tuple nor a string-like
            object, or if its lowercased form is not an allowed mode.
    """
    if isinstance(value, (list, tuple)):
        return value
    if allow_causal:
        allowed_values = {"valid", "same", "causal"}
    else:
        allowed_values = {"valid", "same"}
    try:
        padding = value.lower()
    except AttributeError:
        # Without this, inputs such as `None` or ints would escape as an
        # AttributeError instead of the documented ValueError.
        raise ValueError(
            "The `padding` argument must be a list/tuple or one of "
            f"{allowed_values}. "
            f"Received: {value}"
        ) from None
    if padding not in allowed_values:
        raise ValueError(
            "The `padding` argument must be a list/tuple or one of "
            f"{allowed_values}. "
            f"Received: {padding}"
        )
    return padding