Standardize the validation and defaulting logic for the data_format
argument. (#118)
This commit is contained in:
parent
19e1eabde4
commit
870f6b198f
@ -19,6 +19,7 @@ from keras_core.backend.config import image_data_format
|
||||
from keras_core.backend.config import set_epsilon
|
||||
from keras_core.backend.config import set_floatx
|
||||
from keras_core.backend.config import set_image_data_format
|
||||
from keras_core.backend.config import standardize_data_format
|
||||
from keras_core.utils.io_utils import print_msg
|
||||
|
||||
# Import backend functions.
|
||||
|
@ -152,6 +152,18 @@ def set_image_data_format(data_format):
|
||||
_IMAGE_DATA_FORMAT = str(data_format)
|
||||
|
||||
|
||||
def standardize_data_format(data_format):
|
||||
if data_format is None:
|
||||
return image_data_format()
|
||||
data_format = data_format.lower()
|
||||
if data_format not in {"channels_first", "channels_last"}:
|
||||
raise ValueError(
|
||||
"The `data_format` argument must be one of "
|
||||
f"'channels_first', 'channels_last'. Received: {data_format}"
|
||||
)
|
||||
return data_format
|
||||
|
||||
|
||||
# Set Keras base dir path given KERAS_HOME env variable, if applicable.
|
||||
# Otherwise either ~/.keras or /tmp.
|
||||
if "KERAS_HOME" in os.environ:
|
||||
|
@ -5,7 +5,7 @@ from keras_core import constraints
|
||||
from keras_core import initializers
|
||||
from keras_core import operations as ops
|
||||
from keras_core import regularizers
|
||||
from keras_core.backend import image_data_format
|
||||
from keras_core.backend import standardize_data_format
|
||||
from keras_core.layers.input_spec import InputSpec
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.operations.operation_utils import compute_conv_output_shape
|
||||
@ -117,9 +117,7 @@ class BaseConv(Layer):
|
||||
self.dilation_rate = dilation_rate
|
||||
|
||||
self.padding = padding
|
||||
self.data_format = (
|
||||
image_data_format() if data_format is None else data_format
|
||||
)
|
||||
self.data_format = standardize_data_format(data_format)
|
||||
self.activation = activations.get(activation)
|
||||
self.use_bias = use_bias
|
||||
self.kernel_initializer = initializers.get(kernel_initializer)
|
||||
|
@ -46,7 +46,7 @@ class CenterCrop(Layer):
|
||||
super().__init__(**kwargs)
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.data_format = data_format or backend.image_data_format()
|
||||
self.data_format = backend.standardize_data_format(data_format)
|
||||
|
||||
def call(self, inputs):
|
||||
if self.data_format == "channels_first":
|
||||
|
@ -62,7 +62,7 @@ class Resizing(Layer):
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.interpolation = interpolation
|
||||
self.data_format = data_format or backend.image_data_format()
|
||||
self.data_format = backend.standardize_data_format(data_format)
|
||||
self.crop_to_aspect_ratio = crop_to_aspect_ratio
|
||||
|
||||
def call(self, inputs):
|
||||
|
@ -112,13 +112,7 @@ class SpatialDropout2D(BaseSpatialDropout):
|
||||
self, rate, data_format=None, seed=None, name=None, dtype=None
|
||||
):
|
||||
super().__init__(rate, seed=seed, name=name, dtype=dtype)
|
||||
data_format = data_format or backend.image_data_format()
|
||||
if data_format not in {"channels_last", "channels_first"}:
|
||||
raise ValueError(
|
||||
'`data_format` must be "channels_last" or "channels_first". '
|
||||
f"Received: data_format={data_format}."
|
||||
)
|
||||
self.data_format = data_format
|
||||
self.data_format = backend.standardize_data_format(data_format)
|
||||
self.input_spec = InputSpec(ndim=4)
|
||||
|
||||
def _get_noise_shape(self, inputs):
|
||||
@ -180,13 +174,7 @@ class SpatialDropout3D(BaseSpatialDropout):
|
||||
self, rate, data_format=None, seed=None, name=None, dtype=None
|
||||
):
|
||||
super().__init__(rate, seed=seed, name=name, dtype=dtype)
|
||||
data_format = data_format or backend.image_data_format()
|
||||
if data_format not in {"channels_last", "channels_first"}:
|
||||
raise ValueError(
|
||||
'`data_format` must be "channels_last" or "channels_first". '
|
||||
f"Received: data_format={data_format}."
|
||||
)
|
||||
self.data_format = data_format
|
||||
self.data_format = backend.standardize_data_format(data_format)
|
||||
self.input_spec = InputSpec(ndim=5)
|
||||
|
||||
def _get_noise_shape(self, inputs):
|
||||
|
@ -32,9 +32,7 @@ class Flatten(Layer):
|
||||
|
||||
def __init__(self, data_format=None, name=None, dtype=None):
|
||||
super().__init__(name=name, dtype=dtype)
|
||||
self.data_format = (
|
||||
backend.image_data_format() if data_format is None else data_format
|
||||
)
|
||||
self.data_format = backend.standardize_data_format(data_format)
|
||||
self.input_spec = InputSpec(min_ndim=1)
|
||||
self._channels_first = self.data_format == "channels_first"
|
||||
|
||||
|
@ -66,8 +66,7 @@ def array_to_img(x, data_format=None, scale=True, dtype=None):
|
||||
A PIL Image instance.
|
||||
"""
|
||||
|
||||
if data_format is None:
|
||||
data_format = backend.image_data_format()
|
||||
data_format = backend.standardize_data_format(data_format)
|
||||
if dtype is None:
|
||||
dtype = backend.floatx()
|
||||
if pil_image is None:
|
||||
@ -82,9 +81,6 @@ def array_to_img(x, data_format=None, scale=True, dtype=None):
|
||||
f"Got array with shape: {x.shape}"
|
||||
)
|
||||
|
||||
if data_format not in {"channels_first", "channels_last"}:
|
||||
raise ValueError(f"Invalid data_format: {data_format}")
|
||||
|
||||
# Original NumPy array x has format (height, width, channel)
|
||||
# or (channel, height, width)
|
||||
# but target PIL image has format (width, height, channel)
|
||||
@ -144,12 +140,9 @@ def img_to_array(img, data_format=None, dtype=None):
|
||||
A 3D NumPy array.
|
||||
"""
|
||||
|
||||
if data_format is None:
|
||||
data_format = backend.image_data_format()
|
||||
data_format = backend.standardize_data_format(data_format)
|
||||
if dtype is None:
|
||||
dtype = backend.floatx()
|
||||
if data_format not in {"channels_first", "channels_last"}:
|
||||
raise ValueError(f"Unknown data_format: {data_format}")
|
||||
# NumPy array x has format (height, width, channel)
|
||||
# or (channel, height, width)
|
||||
# but original PIL image has format (width, height, channel)
|
||||
@ -184,8 +177,7 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
|
||||
scale: Whether to rescale image values to be within `[0, 255]`.
|
||||
**kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
|
||||
"""
|
||||
if data_format is None:
|
||||
data_format = backend.image_data_format()
|
||||
data_format = backend.standardize_data_format(data_format)
|
||||
img = array_to_img(x, data_format=data_format, scale=scale)
|
||||
if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"):
|
||||
warnings.warn(
|
||||
|
Loading…
Reference in New Issue
Block a user