Standardize the validation and defaulting logic for the data_format argument. (#118)

This commit is contained in:
hertschuh 2023-05-09 11:22:21 -07:00 committed by Francois Chollet
parent 19e1eabde4
commit 870f6b198f
8 changed files with 23 additions and 34 deletions

@ -19,6 +19,7 @@ from keras_core.backend.config import image_data_format
from keras_core.backend.config import set_epsilon
from keras_core.backend.config import set_floatx
from keras_core.backend.config import set_image_data_format
from keras_core.backend.config import standardize_data_format
from keras_core.utils.io_utils import print_msg
# Import backend functions.

@ -152,6 +152,18 @@ def set_image_data_format(data_format):
_IMAGE_DATA_FORMAT = str(data_format)
def standardize_data_format(data_format):
if data_format is None:
return image_data_format()
data_format = data_format.lower()
if data_format not in {"channels_first", "channels_last"}:
raise ValueError(
"The `data_format` argument must be one of "
f"'channels_first', 'channels_last'. Received: {data_format}"
)
return data_format
# Set Keras base dir path given KERAS_HOME env variable, if applicable.
# Otherwise either ~/.keras or /tmp.
if "KERAS_HOME" in os.environ:

@ -5,7 +5,7 @@ from keras_core import constraints
from keras_core import initializers
from keras_core import operations as ops
from keras_core import regularizers
from keras_core.backend import image_data_format
from keras_core.backend import standardize_data_format
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.operations.operation_utils import compute_conv_output_shape
@ -117,9 +117,7 @@ class BaseConv(Layer):
self.dilation_rate = dilation_rate
self.padding = padding
self.data_format = (
image_data_format() if data_format is None else data_format
)
self.data_format = standardize_data_format(data_format)
self.activation = activations.get(activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)

@ -46,7 +46,7 @@ class CenterCrop(Layer):
super().__init__(**kwargs)
self.height = height
self.width = width
self.data_format = data_format or backend.image_data_format()
self.data_format = backend.standardize_data_format(data_format)
def call(self, inputs):
if self.data_format == "channels_first":

@ -62,7 +62,7 @@ class Resizing(Layer):
self.height = height
self.width = width
self.interpolation = interpolation
self.data_format = data_format or backend.image_data_format()
self.data_format = backend.standardize_data_format(data_format)
self.crop_to_aspect_ratio = crop_to_aspect_ratio
def call(self, inputs):

@ -112,13 +112,7 @@ class SpatialDropout2D(BaseSpatialDropout):
self, rate, data_format=None, seed=None, name=None, dtype=None
):
super().__init__(rate, seed=seed, name=name, dtype=dtype)
data_format = data_format or backend.image_data_format()
if data_format not in {"channels_last", "channels_first"}:
raise ValueError(
'`data_format` must be "channels_last" or "channels_first". '
f"Received: data_format={data_format}."
)
self.data_format = data_format
self.data_format = backend.standardize_data_format(data_format)
self.input_spec = InputSpec(ndim=4)
def _get_noise_shape(self, inputs):
@ -180,13 +174,7 @@ class SpatialDropout3D(BaseSpatialDropout):
self, rate, data_format=None, seed=None, name=None, dtype=None
):
super().__init__(rate, seed=seed, name=name, dtype=dtype)
data_format = data_format or backend.image_data_format()
if data_format not in {"channels_last", "channels_first"}:
raise ValueError(
'`data_format` must be "channels_last" or "channels_first". '
f"Received: data_format={data_format}."
)
self.data_format = data_format
self.data_format = backend.standardize_data_format(data_format)
self.input_spec = InputSpec(ndim=5)
def _get_noise_shape(self, inputs):

@ -32,9 +32,7 @@ class Flatten(Layer):
def __init__(self, data_format=None, name=None, dtype=None):
super().__init__(name=name, dtype=dtype)
self.data_format = (
backend.image_data_format() if data_format is None else data_format
)
self.data_format = backend.standardize_data_format(data_format)
self.input_spec = InputSpec(min_ndim=1)
self._channels_first = self.data_format == "channels_first"

@ -66,8 +66,7 @@ def array_to_img(x, data_format=None, scale=True, dtype=None):
A PIL Image instance.
"""
if data_format is None:
data_format = backend.image_data_format()
data_format = backend.standardize_data_format(data_format)
if dtype is None:
dtype = backend.floatx()
if pil_image is None:
@ -82,9 +81,6 @@ def array_to_img(x, data_format=None, scale=True, dtype=None):
f"Got array with shape: {x.shape}"
)
if data_format not in {"channels_first", "channels_last"}:
raise ValueError(f"Invalid data_format: {data_format}")
# Original NumPy array x has format (height, width, channel)
# or (channel, height, width)
# but target PIL image has format (width, height, channel)
@ -144,12 +140,9 @@ def img_to_array(img, data_format=None, dtype=None):
A 3D NumPy array.
"""
if data_format is None:
data_format = backend.image_data_format()
data_format = backend.standardize_data_format(data_format)
if dtype is None:
dtype = backend.floatx()
if data_format not in {"channels_first", "channels_last"}:
raise ValueError(f"Unknown data_format: {data_format}")
# NumPy array x has format (height, width, channel)
# or (channel, height, width)
# but original PIL image has format (width, height, channel)
@ -184,8 +177,7 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
scale: Whether to rescale image values to be within `[0, 255]`.
**kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
"""
if data_format is None:
data_format = backend.image_data_format()
data_format = backend.standardize_data_format(data_format)
img = array_to_img(x, data_format=data_format, scale=scale)
if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"):
warnings.warn(