71 lines
2.2 KiB
Python
71 lines
2.2 KiB
Python
def standardize_tuple(value, n, name, allow_zero=False):
    """Transforms non-negative/positive integer/integers into an integer tuple.

    Args:
        value: int or iterable of ints. The value to validate and convert.
        n: int. The size of the tuple to be returned.
        name: string. The name of the argument being validated, e.g. "strides"
            or "kernel_size". This is only used to format error messages.
        allow_zero: bool, defaults to False. A ValueError will be raised if
            zero is received and this param is False.

    Returns:
        A tuple of n elements. A single int is repeated n times; an iterable
        is converted with `tuple()` and its elements are returned as given
        (they are validated, not cast).

    Raises:
        ValueError: If `value` is neither an int nor an iterable, if it does
            not hold exactly `n` elements, if an element is rejected by
            `int()` or cannot be compared with an integer, or if an element
            is negative (or zero while `allow_zero` is False).
    """
    error_msg = (
        f"The `{name}` argument must be a tuple of {n} integers. "
        f"Received {name}={value}"
    )

    if isinstance(value, int):
        value_tuple = (value,) * n
    else:
        try:
            value_tuple = tuple(value)
        except TypeError as err:
            raise ValueError(error_msg) from err
        if len(value_tuple) != n:
            raise ValueError(error_msg)
        for single_value in value_tuple:
            # Only checks that the element is int-convertible; the element
            # itself is kept unchanged in the returned tuple.
            try:
                int(single_value)
            except (ValueError, TypeError) as err:
                raise ValueError(
                    f"{error_msg}, including element {single_value} of "
                    f"type {type(single_value)}"
                ) from err

    try:
        if allow_zero:
            unqualified_values = {v for v in value_tuple if v < 0}
            req_msg = ">= 0"
        else:
            unqualified_values = {v for v in value_tuple if v <= 0}
            req_msg = "> 0"
    except TypeError as err:
        # Elements such as "3" pass the `int()` probe above but cannot be
        # ordered against 0; report them like every other invalid input
        # instead of leaking a TypeError.
        raise ValueError(
            f"{error_msg}, including elements that cannot be compared "
            "with an integer"
        ) from err

    if unqualified_values:
        error_msg += (
            f", including values {unqualified_values}"
            f" that do not satisfy `value {req_msg}`"
        )
        raise ValueError(error_msg)

    return value_tuple
|
|
|
|
|
|
def standardize_padding(value, allow_causal=False):
    """Validates a `padding` argument and lowercases string values.

    Args:
        value: list, tuple or string. Lists and tuples are returned as-is
            without further validation. Strings are matched
            case-insensitively against the allowed padding modes.
        allow_causal: bool, defaults to False. Whether `"causal"` is
            accepted in addition to `"valid"` and `"same"`.

    Returns:
        `value` itself when it is a list or tuple, otherwise the lowercased
        padding string.

    Raises:
        ValueError: If `value` is not a list/tuple and is not one of the
            allowed padding strings (including values that are not strings
            at all, such as `None`).
    """
    if isinstance(value, (list, tuple)):
        return value

    if allow_causal:
        allowed_values = {"valid", "same", "causal"}
    else:
        allowed_values = {"valid", "same"}
    error_prefix = (
        "The `padding` argument must be a list/tuple or one of "
        f"{allowed_values}. "
    )

    try:
        padding = value.lower()
    except AttributeError as err:
        # Non-string input (e.g. None or an int) has no `.lower()`; report
        # it as an invalid argument rather than an attribute error.
        raise ValueError(f"{error_prefix}Received: {value}") from err

    if padding not in allowed_values:
        raise ValueError(f"{error_prefix}Received: {padding}")
    return padding
|