Finish adding tf.data support in all KPLs (Keras preprocessing layers) except Normalization (which is stateful in some cases).

Francois Chollet 2023-06-06 13:58:26 -07:00
parent af3c9a52ae
commit 7e2df47838
28 changed files with 379 additions and 92 deletions
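What the change enables, in practice: any of the converted preprocessing layers can now be mapped over a `tf.data.Dataset` even when the selected Keras backend is JAX or PyTorch. A minimal sketch of that usage pattern, mirroring the tests added in this commit (layer choice and shapes are illustrative):

```python
# Assumes keras_core is installed; KERAS_BACKEND may be "jax", "torch", or
# "tensorflow"; the pipeline below behaves the same either way.
import numpy as np
import tensorflow as tf

from keras_core import layers

layer = layers.Rescaling(scale=1.0 / 255)
images = (np.random.random((8, 32, 32, 3)) * 255).astype("float32")

ds = (
    tf.data.Dataset.from_tensor_slices(images)
    .batch(4)
    .map(layer)  # the layer runs on TF ops inside the tf.data graph
)
for batch in ds.take(1):
    print(batch.shape, float(batch.numpy().max()))  # values now in [0, 1]
```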

@ -1,10 +1,10 @@
from keras_core import operations as ops from keras_core import operations as ops
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
@keras_core_export("keras_core.layers.CategoryEncoding") @keras_core_export("keras_core.layers.CategoryEncoding")
class CategoryEncoding(Layer): class CategoryEncoding(TFDataLayer):
"""A preprocessing layer which encodes integer features. """A preprocessing layer which encodes integer features.
This layer provides options for condensing data into a categorical encoding This layer provides options for condensing data into a categorical encoding
@ -13,6 +13,9 @@ class CategoryEncoding(Layer):
inputs. For integer inputs where the total number of tokens is not known, inputs. For integer inputs where the total number of tokens is not known,
use `keras_core.layers.IntegerLookup` instead. use `keras_core.layers.IntegerLookup` instead.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Examples: Examples:
**One-hot encoding data** **One-hot encoding data**
@ -130,23 +133,32 @@ class CategoryEncoding(Layer):
if self.output_mode != "count": if self.output_mode != "count":
raise ValueError( raise ValueError(
"`count_weights` is not used when `output_mode` is not " "`count_weights` is not used when `output_mode` is not "
"`'count'`. Received `count_weights={count_weights}`." f"`'count'`. Received `count_weights={count_weights}`."
) )
count_weights = ops.cast(count_weights, self.compute_dtype) count_weights = self.backend.cast(count_weights, self.compute_dtype)
depth = self.num_tokens depth = self.num_tokens
max_value = ops.amax(inputs) max_value = self.backend.numpy.amax(inputs)
min_value = ops.amin(inputs) min_value = self.backend.numpy.amin(inputs)
-        condition = ops.logical_and(
-            ops.greater(ops.cast(depth, max_value.dtype), max_value),
-            ops.greater_equal(min_value, ops.cast(0, min_value.dtype)),
-        )
-        if not condition:
-            raise ValueError(
-                "Input values must be in the range 0 <= values < num_tokens"
-                f" with num_tokens={depth}"
-            )
+        condition = self.backend.numpy.logical_and(
+            self.backend.numpy.greater(
+                self.backend.cast(depth, max_value.dtype), max_value
+            ),
+            self.backend.numpy.greater_equal(
+                min_value, self.backend.cast(0, min_value.dtype)
+            ),
+        )
+        try:
+            # Check value range in eager mode only.
+            condition = bool(condition.__array__())
+            if not condition:
+                raise ValueError(
+                    "Input values must be in the range 0 <= values < num_tokens"
+                    f" with num_tokens={depth}"
+                )
+        except:
+            pass
return self._encode_categorical_inputs( return self._encode_categorical_inputs(
inputs, inputs,
@ -164,12 +176,12 @@ class CategoryEncoding(Layer):
): ):
# In all cases, we should uprank scalar input to a single sample. # In all cases, we should uprank scalar input to a single sample.
if len(inputs.shape) == 0: if len(inputs.shape) == 0:
inputs = ops.expand_dims(inputs, -1) inputs = self.backend.numpy.expand_dims(inputs, -1)
# One hot will uprank only if the final output dimension # One hot will uprank only if the final output dimension
# is not already 1. # is not already 1.
if output_mode == "one_hot": if output_mode == "one_hot":
if len(inputs.shape) > 1 and inputs.shape[-1] != 1: if len(inputs.shape) > 1 and inputs.shape[-1] != 1:
inputs = ops.expand_dims(inputs, -1) inputs = self.backend.numpy.expand_dims(inputs, -1)
# TODO(b/190445202): remove output rank restriction. # TODO(b/190445202): remove output rank restriction.
if len(inputs.shape) > 2: if len(inputs.shape) > 2:
@ -181,15 +193,14 @@ class CategoryEncoding(Layer):
) )
binary_output = output_mode in ("multi_hot", "one_hot") binary_output = output_mode in ("multi_hot", "one_hot")
inputs = ops.cast(inputs, "int32") inputs = self.backend.cast(inputs, "int32")
if binary_output: if binary_output:
bincounts = ops.one_hot(inputs, num_classes=depth) bincounts = self.backend.nn.one_hot(inputs, num_classes=depth)
if output_mode == "multi_hot": if output_mode == "multi_hot":
bincounts = ops.sum(bincounts, axis=0) bincounts = self.backend.numpy.sum(bincounts, axis=0)
else: else:
bincounts = ops.bincount( bincounts = self.backend.numpy.bincount(
inputs, minlength=depth, weights=count_weights inputs, minlength=depth, weights=count_weights
) )
return bincounts return bincounts
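For reference, a plain-NumPy sketch of what `_encode_categorical_inputs` computes in the three output modes (an illustrative standalone helper, not the layer's actual code path):

```python
import numpy as np


def encode_categorical(inputs, depth, output_mode, count_weights=None):
    """Illustrative NumPy equivalent of the three output modes above."""
    inputs = np.asarray(inputs, dtype="int32")
    if output_mode in ("one_hot", "multi_hot"):
        bincounts = np.eye(depth, dtype="float32")[inputs]  # one row per token
        if output_mode == "multi_hot":
            bincounts = bincounts.sum(axis=0)  # collapse rows into one vector
        return bincounts
    # "count": an (optionally weighted) histogram over token ids
    return np.bincount(inputs, weights=count_weights, minlength=depth)


print(encode_categorical([3, 2, 0], depth=4, output_mode="multi_hot"))
# [1. 0. 1. 1.]
```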

@ -1,4 +1,5 @@
import numpy as np import numpy as np
import tensorflow as tf
from keras_core import layers from keras_core import layers
from keras_core import testing from keras_core import testing
@ -51,3 +52,19 @@ class CategoryEncodingTest(testing.TestCase):
output_data = layer(input_data) output_data = layer(input_data)
self.assertAllClose(expected_output, output_data) self.assertAllClose(expected_output, output_data)
self.assertEqual(expected_output_shape, output_data.shape) self.assertEqual(expected_output_shape, output_data.shape)
def test_tf_data_compatibility(self):
layer = layers.CategoryEncoding(num_tokens=4, output_mode="one_hot")
input_data = np.array([3, 2, 0, 1])
expected_output = np.array(
[
[0, 0, 0, 1],
[0, 0, 1, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
]
)
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(4).map(layer)
for output in ds.take(1):
output = output.numpy()
self.assertAllClose(output, expected_output)

@ -1,11 +1,11 @@
from keras_core import backend from keras_core import backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
from keras_core.utils import image_utils from keras_core.utils import image_utils
@keras_core_export("keras_core.layers.CenterCrop") @keras_core_export("keras_core.layers.CenterCrop")
class CenterCrop(Layer): class CenterCrop(TFDataLayer):
"""A preprocessing layer which crops images. """A preprocessing layer which crops images.
This layers crops the central portion of the images to a target size. If an This layers crops the central portion of the images to a target size. If an
@ -29,6 +29,9 @@ class CenterCrop(Layer):
If the input height/width is even and the target height/width is odd (or If the input height/width is even and the target height/width is odd (or
inversely), the input image is left-padded by 1 pixel. inversely), the input image is left-padded by 1 pixel.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args: Args:
height: Integer, the height of the output shape. height: Integer, the height of the output shape.
width: Integer, the width of the output shape. width: Integer, the width of the output shape.
@ -99,7 +102,10 @@ class CenterCrop(Layer):
] ]
         return image_utils.smart_resize(
-            inputs, [self.height, self.width], data_format=self.data_format
+            inputs,
+            [self.height, self.width],
+            data_format=self.data_format,
+            backend_module=self.backend,
         )
def compute_output_shape(self, input_shape): def compute_output_shape(self, input_shape):

@ -94,3 +94,11 @@ class CenterCropTest(testing.TestCase, parameterized.TestCase):
size[1], size[1],
)(img) )(img)
self.assertAllClose(ref_out, out) self.assertAllClose(ref_out, out)
def test_tf_data_compatibility(self):
layer = layers.CenterCrop(8, 9)
input_data = np.random.random((2, 10, 12, 3))
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output = output.numpy()
self.assertEqual(list(output.shape), [2, 8, 9, 3])

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend from keras_core import backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.Discretization") @keras_core_export("keras_core.layers.Discretization")
@ -22,6 +23,9 @@ class Discretization(Layer):
with any backend (outside the model itself), which is how we recommend with any backend (outside the model itself), which is how we recommend
to use this layer. to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Input shape: Input shape:
Any array of dimension 2 or higher. Any array of dimension 2 or higher.
@ -192,7 +196,10 @@ class Discretization(Layer):
def call(self, inputs): def call(self, inputs):
outputs = self.layer.call(inputs) outputs = self.layer.call(inputs)
-        if backend.backend() != "tensorflow" and tf.executing_eagerly():
+        if (
+            backend.backend() != "tensorflow"
+            and not backend_utils.in_tf_graph()
+        ):
outputs = backend.convert_to_tensor(outputs) outputs = backend.convert_to_tensor(outputs)
return outputs return outputs
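The same guard is repeated in every TF-backed wrapper below (HashedCrossing, Hashing, IntegerLookup, StringLookup, TextVectorization, and the Random* image layers): call the wrapped TF layer, then convert the result back to a backend-native tensor only when running eagerly under a non-TensorFlow backend. A condensed sketch of that pattern (the helper name is illustrative; the real layers inline this in `call()`):

```python
from keras_core import backend
from keras_core.utils import backend_utils


def call_tf_layer(tf_layer, inputs):
    """Illustrative helper showing the shared conversion guard."""
    outputs = tf_layer(inputs)
    if backend.backend() != "tensorflow" and not backend_utils.in_tf_graph():
        # Eager call under JAX/torch: hand back a backend-native tensor.
        outputs = backend.convert_to_tensor(outputs)
    # Inside a tf.data graph, keep the tf.Tensor so the pipeline can consume it.
    return outputs
```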

@ -3,6 +3,7 @@ import tensorflow as tf
from keras_core import backend from keras_core import backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.HashedCrossing") @keras_core_export("keras_core.layers.HashedCrossing")
@ -25,6 +26,9 @@ class HashedCrossing(Layer):
with any backend (outside the model itself), which is how we recommend with any backend (outside the model itself), which is how we recommend
to use this layer. to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args: Args:
num_bins: Number of hash bins. num_bins: Number of hash bins.
output_mode: Specification for the output of the layer. Values can be output_mode: Specification for the output of the layer. Values can be
@ -100,7 +104,10 @@ class HashedCrossing(Layer):
def call(self, inputs): def call(self, inputs):
outputs = self.layer.call(inputs) outputs = self.layer.call(inputs)
-        if backend.backend() != "tensorflow" and tf.executing_eagerly():
+        if (
+            backend.backend() != "tensorflow"
+            and not backend_utils.in_tf_graph()
+        ):
outputs = backend.convert_to_tensor(outputs) outputs = backend.convert_to_tensor(outputs)
return outputs return outputs

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend from keras_core import backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.Hashing") @keras_core_export("keras_core.layers.Hashing")
@ -33,6 +34,9 @@ class Hashing(Layer):
with any backend (outside the model itself), which is how we recommend with any backend (outside the model itself), which is how we recommend
to use this layer. to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
**Example (FarmHash64)** **Example (FarmHash64)**
>>> layer = keras_core.layers.Hashing(num_bins=3) >>> layer = keras_core.layers.Hashing(num_bins=3)
@ -175,7 +179,10 @@ class Hashing(Layer):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)): if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs)) inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs) outputs = self.layer.call(inputs)
-        if backend.backend() != "tensorflow" and tf.executing_eagerly():
+        if (
+            backend.backend() != "tensorflow"
+            and not backend_utils.in_tf_graph()
+        ):
outputs = backend.convert_to_tensor(outputs) outputs = backend.convert_to_tensor(outputs)
return outputs return outputs

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend from keras_core import backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.IntegerLookup") @keras_core_export("keras_core.layers.IntegerLookup")
@ -47,6 +48,9 @@ class IntegerLookup(Layer):
with any backend (outside the model itself), which is how we recommend with any backend (outside the model itself), which is how we recommend
to use this layer. to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args: Args:
max_tokens: Maximum size of the vocabulary for this layer. This should max_tokens: Maximum size of the vocabulary for this layer. This should
only be specified when adapting the vocabulary or when setting only be specified when adapting the vocabulary or when setting
@ -441,7 +445,10 @@ class IntegerLookup(Layer):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)): if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs)) inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs) outputs = self.layer.call(inputs)
-        if backend.backend() != "tensorflow" and tf.executing_eagerly():
+        if (
+            backend.backend() != "tensorflow"
+            and not backend_utils.in_tf_graph()
+        ):
outputs = backend.convert_to_tensor(outputs) outputs = backend.convert_to_tensor(outputs)
return outputs return outputs

@ -1,17 +1,19 @@
-from keras_core import operations as ops
+from keras_core import backend
 from keras_core.api_export import keras_core_export
-from keras_core.backend import random
-from keras_core.layers.layer import Layer
+from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
@keras_core_export("keras_core.layers.RandomBrightness") @keras_core_export("keras_core.layers.RandomBrightness")
class RandomBrightness(Layer): class RandomBrightness(TFDataLayer):
"""A preprocessing layer which randomly adjusts brightness during training. """A preprocessing layer which randomly adjusts brightness during training.
This layer will randomly increase/reduce the brightness for the input RGB This layer will randomly increase/reduce the brightness for the input RGB
images. At inference time, the output will be identical to the input. images. At inference time, the output will be identical to the input.
Call the layer with `training=True` to adjust the brightness of the input. Call the layer with `training=True` to adjust the brightness of the input.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args: Args:
factor: Float or a list/tuple of 2 floats between -1.0 and 1.0. The factor: Float or a list/tuple of 2 floats between -1.0 and 1.0. The
factor is used to determine the lower bound and upper bound of the factor is used to determine the lower bound and upper bound of the
@ -72,21 +74,21 @@ class RandomBrightness(Layer):
super().__init__(**kwargs) super().__init__(**kwargs)
self._set_factor(factor) self._set_factor(factor)
self._set_value_range(value_range) self._set_value_range(value_range)
self._seed = seed self.seed = seed
self._generator = random.SeedGenerator(seed) self.generator = backend.random.SeedGenerator(seed)
def _set_value_range(self, value_range): def _set_value_range(self, value_range):
if not isinstance(value_range, (tuple, list)): if not isinstance(value_range, (tuple, list)):
raise ValueError( raise ValueError(
self._VALUE_RANGE_VALIDATION_ERROR self.value_range_VALIDATION_ERROR
+ f"Received: value_range={value_range}" + f"Received: value_range={value_range}"
) )
if len(value_range) != 2: if len(value_range) != 2:
raise ValueError( raise ValueError(
self._VALUE_RANGE_VALIDATION_ERROR self.value_range_VALIDATION_ERROR
+ f"Received: value_range={value_range}" + f"Received: value_range={value_range}"
) )
self._value_range = sorted(value_range) self.value_range = sorted(value_range)
def _set_factor(self, factor): def _set_factor(self, factor):
if isinstance(factor, (tuple, list)): if isinstance(factor, (tuple, list)):
@ -114,7 +116,7 @@ class RandomBrightness(Layer):
) )
def call(self, inputs, training=True): def call(self, inputs, training=True):
inputs = ops.cast(inputs, self.compute_dtype) inputs = self.backend.cast(inputs, self.compute_dtype)
if training: if training:
return self._brightness_adjust(inputs) return self._brightness_adjust(inputs)
else: else:
@ -127,23 +129,30 @@ class RandomBrightness(Layer):
elif rank == 4: elif rank == 4:
# Keep only the batch dim. This will ensure to have same adjustment # Keep only the batch dim. This will ensure to have same adjustment
# with in one image, but different across the images. # with in one image, but different across the images.
rgb_delta_shape = [images.shape[0], 1, 1, 1] rgb_delta_shape = [self.backend.shape(images)[0], 1, 1, 1]
else: else:
raise ValueError( raise ValueError(
"Expected the input image to be rank 3 or 4. Received " "Expected the input image to be rank 3 or 4. Received "
f"inputs.shape = {images.shape}" f"inputs.shape={images.shape}"
) )
-        rgb_delta = ops.random.uniform(
+        if backend.backend() != self.backend._backend:
+            seed_generator = self.backend.random.SeedGenerator(self.seed)
+        else:
+            seed_generator = self.generator
+        rgb_delta = self.backend.random.uniform(
             minval=self._factor[0],
             maxval=self._factor[1],
             shape=rgb_delta_shape,
-            seed=self._generator,
+            seed=seed_generator,
         )
-        rgb_delta = rgb_delta * (self._value_range[1] - self._value_range[0])
-        rgb_delta = ops.cast(rgb_delta, images.dtype)
+        rgb_delta = rgb_delta * (self.value_range[1] - self.value_range[0])
+        rgb_delta = self.backend.cast(rgb_delta, images.dtype)
         images += rgb_delta
-        return ops.clip(images, self._value_range[0], self._value_range[1])
+        return self.backend.numpy.clip(
+            images, self.value_range[0], self.value_range[1]
+        )
def compute_output_shape(self, input_shape): def compute_output_shape(self, input_shape):
return input_shape return input_shape
@ -151,8 +160,8 @@ class RandomBrightness(Layer):
def get_config(self): def get_config(self):
config = { config = {
"factor": self._factor, "factor": self._factor,
"value_range": self._value_range, "value_range": self.value_range,
"seed": self._seed, "seed": self.seed,
} }
base_config = super().get_config() base_config = super().get_config()
return {**base_config, **config} return {**base_config, **config}
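The seed handling above means the layer keeps its usual `SeedGenerator` for eager calls on the selected backend, but builds a fresh TensorFlow seed generator from the stored integer `seed` when it finds itself running under a different backend (e.g. traced inside `tf.data`). A usage sketch, with illustrative shapes:

```python
import numpy as np
import tensorflow as tf

from keras_core import layers

layer = layers.RandomBrightness(factor=0.3, value_range=(0, 1), seed=1337)
images = np.random.random((2, 8, 8, 3)).astype("float32")

# Eager path: random ops from the selected Keras backend, using self.generator.
eager_out = layer(images, training=True)
print(eager_out.shape)  # shape (2, 8, 8, 3)

# tf.data path: the layer detects the TF graph and draws the brightness delta
# from a TensorFlow seed generator rebuilt from `seed`.
ds = tf.data.Dataset.from_tensor_slices(images).batch(2).map(layer)
for batch in ds.take(1):
    print(batch.shape)  # (2, 8, 8, 3)
```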

@ -1,4 +1,5 @@
import numpy as np import numpy as np
import tensorflow as tf
from keras_core import layers from keras_core import layers
from keras_core import testing from keras_core import testing
@ -46,3 +47,10 @@ class RandomBrightnessTest(testing.TestCase):
diff = output - inputs diff = output - inputs
self.assertTrue(np.amax(diff) <= 0) self.assertTrue(np.amax(diff) <= 0)
self.assertTrue(np.mean(diff) < 0) self.assertTrue(np.mean(diff) < 0)
def test_tf_data_compatibility(self):
layer = layers.RandomBrightness(factor=0.5, seed=1337)
input_data = np.random.random((2, 8, 8, 3))
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()

@ -1,11 +1,10 @@
-from keras_core import operations as ops
+from keras_core import backend
 from keras_core.api_export import keras_core_export
-from keras_core.backend import random
-from keras_core.layers.layer import Layer
+from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
@keras_core_export("keras_core.layers.RandomContrast") @keras_core_export("keras_core.layers.RandomContrast")
class RandomContrast(Layer): class RandomContrast(TFDataLayer):
"""A preprocessing layer which randomly adjusts contrast during training. """A preprocessing layer which randomly adjusts contrast during training.
This layer will randomly adjust the contrast of an image or images This layer will randomly adjust the contrast of an image or images
@ -20,6 +19,9 @@ class RandomContrast(Layer):
in integer or floating point dtype. in integer or floating point dtype.
By default, the layer will output floats. By default, the layer will output floats.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Input shape: Input shape:
3D (unbatched) or 4D (batched) tensor with shape: 3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format. `(..., height, width, channels)`, in `"channels_last"` format.
@ -54,30 +56,34 @@ class RandomContrast(Layer):
f"Received: factor={factor}" f"Received: factor={factor}"
) )
self.seed = seed self.seed = seed
self.generator = random.SeedGenerator(seed) self.generator = backend.random.SeedGenerator(seed)
def call(self, inputs, training=True): def call(self, inputs, training=True):
inputs = ops.cast(inputs, self.compute_dtype) inputs = self.backend.cast(inputs, self.compute_dtype)
if training: if training:
-            factor = ops.random.uniform(
+            if backend.backend() != self.backend._backend:
+                seed_generator = self.backend.random.SeedGenerator(self.seed)
+            else:
+                seed_generator = self.generator
+            factor = self.backend.random.uniform(
                 shape=(),
                 minval=1.0 - self.lower,
                 maxval=1.0 + self.upper,
-                seed=self.generator,
+                seed=seed_generator,
             )
outputs = self._adjust_constrast(inputs, factor) outputs = self._adjust_constrast(inputs, factor)
outputs = ops.clip(outputs, 0, 255) outputs = self.backend.numpy.clip(outputs, 0, 255)
ops.reshape(outputs, inputs.shape) self.backend.numpy.reshape(outputs, self.backend.shape(inputs))
return outputs return outputs
else: else:
return inputs return inputs
def _adjust_constrast(self, inputs, contrast_factor): def _adjust_constrast(self, inputs, contrast_factor):
# reduce mean on height # reduce mean on height
inp_mean = ops.mean(inputs, axis=-3, keepdims=True) inp_mean = self.backend.numpy.mean(inputs, axis=-3, keepdims=True)
# reduce mean on width # reduce mean on width
inp_mean = ops.mean(inp_mean, axis=-2, keepdims=True) inp_mean = self.backend.numpy.mean(inp_mean, axis=-2, keepdims=True)
outputs = (inputs - inp_mean) * contrast_factor + inp_mean outputs = (inputs - inp_mean) * contrast_factor + inp_mean
return outputs return outputs

@ -1,4 +1,5 @@
import numpy as np import numpy as np
import tensorflow as tf
from keras_core import layers from keras_core import layers
from keras_core import testing from keras_core import testing
@ -33,3 +34,10 @@ class RandomContrastTest(testing.TestCase):
actual_outputs = np.clip(outputs, 0, 255) actual_outputs = np.clip(outputs, 0, 255)
self.assertAllClose(outputs, actual_outputs) self.assertAllClose(outputs, actual_outputs)
def test_tf_data_compatibility(self):
layer = layers.RandomContrast(factor=0.5, seed=1337)
input_data = np.random.random((2, 8, 8, 3))
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend from keras_core import backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.RandomCrop") @keras_core_export("keras_core.layers.RandomCrop")
@ -32,6 +33,9 @@ class RandomCrop(Layer):
with any backend (outside the model itself), which is how we recommend with any backend (outside the model itself), which is how we recommend
to use this layer. to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Input shape: Input shape:
3D (unbatched) or 4D (batched) tensor with shape: 3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format. `(..., height, width, channels)`, in `"channels_last"` format.
@ -59,12 +63,17 @@ class RandomCrop(Layer):
) )
self.supports_masking = False self.supports_masking = False
self.supports_jit = False self.supports_jit = False
self._convert_input_args = False
self._allow_non_tensor_positional_args = True
def call(self, inputs, training=True): def call(self, inputs, training=True):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)): if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs)) inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs) outputs = self.layer.call(inputs, training=training)
-        if backend.backend() != "tensorflow":
+        if (
+            backend.backend() != "tensorflow"
+            and not backend_utils.in_tf_graph()
+        ):
outputs = backend.convert_to_tensor(outputs) outputs = backend.convert_to_tensor(outputs)
return outputs return outputs

@ -1,4 +1,5 @@
import numpy as np import numpy as np
import tensorflow as tf
from keras_core import layers from keras_core import layers
from keras_core import testing from keras_core import testing
@ -63,3 +64,11 @@ class RandomCropTest(testing.TestCase):
supports_masking=False, supports_masking=False,
run_training_check=False, run_training_check=False,
) )
def test_tf_data_compatibility(self):
layer = layers.RandomCrop(8, 9)
input_data = np.random.random((2, 10, 12, 3))
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output = output.numpy()
self.assertEqual(list(output.shape), [2, 8, 9, 3])

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend from keras_core import backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
HORIZONTAL = "horizontal" HORIZONTAL = "horizontal"
VERTICAL = "vertical" VERTICAL = "vertical"
@ -21,6 +22,9 @@ class RandomFlip(Layer):
of integer or floating point dtype. of integer or floating point dtype.
By default, the layer will output floats. By default, the layer will output floats.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Input shape: Input shape:
3D (unbatched) or 4D (batched) tensor with shape: 3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format. `(..., height, width, channels)`, in `"channels_last"` format.
@ -51,12 +55,17 @@ class RandomFlip(Layer):
**kwargs, **kwargs,
) )
self.supports_jit = False self.supports_jit = False
self._convert_input_args = False
self._allow_non_tensor_positional_args = True
def call(self, inputs, training=True): def call(self, inputs, training=True):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)): if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs)) inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs) outputs = self.layer.call(inputs, training=training)
-        if backend.backend() != "tensorflow":
+        if (
+            backend.backend() != "tensorflow"
+            and not backend_utils.in_tf_graph()
+        ):
outputs = backend.convert_to_tensor(outputs) outputs = backend.convert_to_tensor(outputs)
return outputs return outputs

@ -1,4 +1,5 @@
import numpy as np import numpy as np
import tensorflow as tf
from absl.testing import parameterized from absl.testing import parameterized
from keras_core import backend from keras_core import backend
@ -54,3 +55,12 @@ class RandomFlipTest(testing.TestCase, parameterized.TestCase):
supports_masking=False, supports_masking=False,
run_training_check=False, run_training_check=False,
) )
def test_tf_data_compatibility(self):
layer = layers.RandomFlip("vertical", seed=42)
input_data = np.array([[[2, 3, 4]], [[5, 6, 7]]])
expected_output = np.array([[[5, 6, 7]], [[2, 3, 4]]])
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output = output.numpy()
self.assertAllClose(output, expected_output)

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend from keras_core import backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.RandomRotation") @keras_core_export("keras_core.layers.RandomRotation")
@ -29,6 +30,9 @@ class RandomRotation(Layer):
with any backend (outside the model itself), which is how we recommend with any backend (outside the model itself), which is how we recommend
to use this layer. to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Input shape: Input shape:
3D (unbatched) or 4D (batched) tensor with shape: 3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format `(..., height, width, channels)`, in `"channels_last"` format
@ -101,7 +105,10 @@ class RandomRotation(Layer):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)): if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs)) inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs, training=training) outputs = self.layer.call(inputs, training=training)
-        if backend.backend() != "tensorflow" and tf.executing_eagerly():
+        if (
+            backend.backend() != "tensorflow"
+            and not backend_utils.in_tf_graph()
+        ):
outputs = backend.convert_to_tensor(outputs) outputs = backend.convert_to_tensor(outputs)
return outputs return outputs

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend from keras_core import backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.RandomTranslation") @keras_core_export("keras_core.layers.RandomTranslation")
@ -17,6 +18,9 @@ class RandomTranslation(Layer):
of integer or floating point dtype. By default, the layer will output of integer or floating point dtype. By default, the layer will output
floats. floats.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args: Args:
height_factor: a float represented as fraction of value, or a tuple of height_factor: a float represented as fraction of value, or a tuple of
size 2 representing lower and upper bound for shifting vertically. A size 2 representing lower and upper bound for shifting vertically. A
@ -91,7 +95,10 @@ class RandomTranslation(Layer):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)): if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs)) inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs, training=training) outputs = self.layer.call(inputs, training=training)
-        if backend.backend() != "tensorflow" and tf.executing_eagerly():
+        if (
+            backend.backend() != "tensorflow"
+            and not backend_utils.in_tf_graph()
+        ):
outputs = backend.convert_to_tensor(outputs) outputs = backend.convert_to_tensor(outputs)
return outputs return outputs

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend from keras_core import backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.RandomZoom") @keras_core_export("keras_core.layers.RandomZoom")
@ -25,6 +26,9 @@ class RandomZoom(Layer):
with any backend (outside the model itself), which is how we recommend with any backend (outside the model itself), which is how we recommend
to use this layer. to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args: Args:
height_factor: a float represented as fraction of value, height_factor: a float represented as fraction of value,
or a tuple of size 2 representing lower and upper bound or a tuple of size 2 representing lower and upper bound
@ -114,7 +118,10 @@ class RandomZoom(Layer):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)): if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs)) inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs, training=training) outputs = self.layer.call(inputs, training=training)
-        if backend.backend() != "tensorflow" and tf.executing_eagerly():
+        if (
+            backend.backend() != "tensorflow"
+            and not backend_utils.in_tf_graph()
+        ):
outputs = backend.convert_to_tensor(outputs) outputs = backend.convert_to_tensor(outputs)
return outputs return outputs

@ -1,10 +1,9 @@
-from keras_core import operations as ops
 from keras_core.api_export import keras_core_export
-from keras_core.layers.layer import Layer
+from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
@keras_core_export("keras_core.layers.Rescaling") @keras_core_export("keras_core.layers.Rescaling")
class Rescaling(Layer): class Rescaling(TFDataLayer):
"""A preprocessing layer which rescales input values to a new range. """A preprocessing layer which rescales input values to a new range.
This layer rescales every value of an input (often an image) by multiplying This layer rescales every value of an input (often an image) by multiplying
@ -22,6 +21,9 @@ class Rescaling(Layer):
of integer or floating point dtype, and by default the layer will output of integer or floating point dtype, and by default the layer will output
floats. floats.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args: Args:
scale: Float, the scale to apply to the inputs. scale: Float, the scale to apply to the inputs.
offset: Float, the offset to apply to the inputs. offset: Float, the offset to apply to the inputs.
@ -36,9 +38,9 @@ class Rescaling(Layer):
def call(self, inputs): def call(self, inputs):
dtype = self.compute_dtype dtype = self.compute_dtype
scale = ops.cast(self.scale, dtype) scale = self.backend.cast(self.scale, dtype)
offset = ops.cast(self.offset, dtype) offset = self.backend.cast(self.offset, dtype)
return ops.cast(inputs, dtype) * scale + offset return self.backend.cast(inputs, dtype) * scale + offset
def compute_output_shape(self, input_shape): def compute_output_shape(self, input_shape):
return input_shape return input_shape

@ -1,4 +1,5 @@
import numpy as np import numpy as np
import tensorflow as tf
from keras_core import layers from keras_core import layers
from keras_core import testing from keras_core import testing
@ -62,3 +63,10 @@ class RescalingTest(testing.TestCase):
x = np.random.random((3, 10, 10, 3)) * 255 x = np.random.random((3, 10, 10, 3)) * 255
out = layer(x) out = layer(x)
self.assertAllClose(out, x / 255 + 0.5) self.assertAllClose(out, x / 255 + 0.5)
def test_tf_data_compatibility(self):
layer = layers.Rescaling(scale=1.0 / 255, offset=0.5)
x = np.random.random((3, 10, 10, 3)) * 255
ds = tf.data.Dataset.from_tensor_slices(x).batch(3).map(layer)
for output in ds.take(1):
output.numpy()

@ -1,12 +1,11 @@
 from keras_core import backend
-from keras_core import operations as ops
 from keras_core.api_export import keras_core_export
-from keras_core.layers.layer import Layer
+from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
from keras_core.utils import image_utils from keras_core.utils import image_utils
@keras_core_export("keras_core.layers.Resizing") @keras_core_export("keras_core.layers.Resizing")
class Resizing(Layer): class Resizing(TFDataLayer):
"""A preprocessing layer which resizes images. """A preprocessing layer which resizes images.
This layer resizes an image input to a target height and width. The input This layer resizes an image input to a target height and width. The input
@ -25,6 +24,9 @@ class Resizing(Layer):
or `(..., channels, target_height, target_width)`, or `(..., channels, target_height, target_width)`,
in `"channels_first"` format. in `"channels_first"` format.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args: Args:
height: Integer, the height of the output shape. height: Integer, the height of the output shape.
width: Integer, the width of the output shape. width: Integer, the width of the output shape.
@ -73,9 +75,10 @@ class Resizing(Layer):
size=size, size=size,
interpolation=self.interpolation, interpolation=self.interpolation,
data_format=self.data_format, data_format=self.data_format,
backend_module=self.backend,
) )
else: else:
outputs = ops.image.resize( outputs = self.backend.image.resize(
inputs, inputs,
size=size, size=size,
method=self.interpolation, method=self.interpolation,

@ -142,3 +142,11 @@ class ResizingTest(testing.TestCase, parameterized.TestCase):
size[0], size[1], crop_to_aspect_ratio=crop_to_aspect_ratio size[0], size[1], crop_to_aspect_ratio=crop_to_aspect_ratio
)(img) )(img)
self.assertAllClose(ref_out, out) self.assertAllClose(ref_out, out)
def test_tf_data_compatibility(self):
layer = layers.Resizing(8, 9)
input_data = np.random.random((2, 10, 12, 3))
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output = output.numpy()
self.assertEqual(list(output.shape), [2, 8, 9, 3])

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend from keras_core import backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.StringLookup") @keras_core_export("keras_core.layers.StringLookup")
@ -46,6 +47,9 @@ class StringLookup(Layer):
with any backend (outside the model itself), which is how we recommend with any backend (outside the model itself), which is how we recommend
to use this layer. to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args: Args:
max_tokens: Maximum size of the vocabulary for this layer. This should max_tokens: Maximum size of the vocabulary for this layer. This should
only be specified when adapting the vocabulary or when setting only be specified when adapting the vocabulary or when setting
@ -435,7 +439,10 @@ class StringLookup(Layer):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)): if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs)) inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs) outputs = self.layer.call(inputs)
-        if backend.backend() != "tensorflow" and tf.executing_eagerly():
+        if (
+            backend.backend() != "tensorflow"
+            and not backend_utils.in_tf_graph()
+        ):
outputs = backend.convert_to_tensor(outputs) outputs = backend.convert_to_tensor(outputs)
return outputs return outputs

@ -5,6 +5,7 @@ from keras_core import backend
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer from keras_core.layers.layer import Layer
from keras_core.saving import serialization_lib from keras_core.saving import serialization_lib
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.TextVectorization") @keras_core_export("keras_core.layers.TextVectorization")
@ -63,6 +64,9 @@ class TextVectorization(Layer):
with any backend (outside the model itself), which is how we recommend with any backend (outside the model itself), which is how we recommend
to use this layer. to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args: Args:
max_tokens: Maximum size of the vocabulary for this layer. This should max_tokens: Maximum size of the vocabulary for this layer. This should
only be specified when adapting a vocabulary or when setting only be specified when adapting a vocabulary or when setting
@ -358,7 +362,10 @@ class TextVectorization(Layer):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)): if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs)) inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs) outputs = self.layer.call(inputs)
-        if backend.backend() != "tensorflow" and tf.executing_eagerly():
+        if (
+            backend.backend() != "tensorflow"
+            and not backend_utils.in_tf_graph()
+        ):
outputs = backend.convert_to_tensor(outputs) outputs = backend.convert_to_tensor(outputs)
return outputs return outputs

@ -0,0 +1,36 @@
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
class TFDataLayer(Layer):
"""Layer that can safely used in a tf.data pipeline.
The `call()` method must solely rely on `self.backend` ops.
Only supports a single input tensor argument.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.backend = backend_utils.DynamicBackend()
self._allow_non_tensor_positional_args = True
def __call__(self, inputs, **kwargs):
if backend_utils.in_tf_graph():
# We're in a TF graph, e.g. a tf.data pipeline.
self.backend.set_backend("tensorflow")
inputs = self.backend.convert_to_tensor(
inputs, dtype=self.compute_dtype
)
switch_convert_input_args = False
if self._convert_input_args:
self._convert_input_args = False
switch_convert_input_args = True
try:
outputs = super().__call__(inputs, **kwargs)
finally:
self.backend.reset()
if switch_convert_input_args:
self._convert_input_args = True
return outputs
return super().__call__(inputs, **kwargs)
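The contract for `TFDataLayer` subclasses: write `call()` exclusively against `self.backend`, and `__call__` will transparently point that backend at TensorFlow whenever the layer is traced inside a `tf.data` graph. A sketch with a hypothetical subclass (not part of this commit):

```python
import numpy as np
import tensorflow as tf

from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer


class ScaleAndClip(TFDataLayer):
    """Hypothetical layer: call() only touches self.backend ops."""

    def __init__(self, scale, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, inputs):
        x = self.backend.cast(inputs, self.compute_dtype)
        x = x * self.scale
        return self.backend.numpy.clip(x, 0.0, 1.0)


layer = ScaleAndClip(1.0 / 255)
data = np.random.randint(0, 256, size=(4, 8, 8, 3)).astype("float32")

# Inside tf.data, self.backend is temporarily switched to TensorFlow.
ds = tf.data.Dataset.from_tensor_slices(data).batch(2).map(layer)
for batch in ds.take(1):
    print(batch.numpy().min(), batch.numpy().max())  # values within [0, 1]
```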

@ -0,0 +1,48 @@
import sys
from keras_core import backend as backend_module
from keras_core.backend import jax as jax_backend
from keras_core.backend import tensorflow as tf_backend
from keras_core.backend import torch as torch_backend
def in_tf_graph():
if "tensorflow" in sys.modules:
import tensorflow as tf
return not tf.executing_eagerly()
return False
class DynamicBackend:
"""A class that can be used to switch from one backend to another.
Usage:
```python
backend = DynamicBackend("tensorflow")
y = backend.square(tf.constant(...))
backend.set_backend("jax")
y = backend.square(jax.numpy.array(...))
```
Args:
backend: Initial backend to use (string).
"""
def __init__(self, backend=None):
self._backend = backend or backend_module.backend()
def set_backend(self, backend):
self._backend = backend
def reset(self):
self._backend = backend_module.backend()
def __getattr__(self, name):
if self._backend == "tensorflow":
return getattr(tf_backend, name)
if self._backend == "jax":
return getattr(jax_backend, name)
if self._backend == "torch":
return getattr(torch_backend, name)
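`in_tf_graph()` is the probe the layers above rely on: inside `tf.data.Dataset.map` the mapped function is traced as a TF graph, so `tf.executing_eagerly()` is `False` and the helper returns `True`, at which point `TFDataLayer` points its `DynamicBackend` at TensorFlow until `reset()`. A small sketch of that behavior:

```python
import tensorflow as tf

from keras_core.utils import backend_utils

print(backend_utils.in_tf_graph())  # False: we are in TF eager mode here


def fn(x):
    # Dataset.map traces this function as a TF graph, so executing_eagerly()
    # is False inside it and in_tf_graph() reports True (printed at trace time).
    print("inside map:", backend_utils.in_tf_graph())
    return tf.cast(x, "float32")


_ = tf.data.Dataset.range(4).map(fn)
```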

@ -299,7 +299,11 @@ def load_img(
def smart_resize( def smart_resize(
-    x, size, interpolation="bilinear", data_format="channels_last"
+    x,
+    size,
+    interpolation="bilinear",
+    data_format="channels_last",
+    backend_module=None,
): ):
"""Resize images to a target size without aspect ratio distortion. """Resize images to a target size without aspect ratio distortion.
@ -354,6 +358,8 @@ def smart_resize(
Supports `bilinear`, `nearest`, `bicubic`, Supports `bilinear`, `nearest`, `bicubic`,
`lanczos3`, `lanczos5`. `lanczos3`, `lanczos5`.
data_format: `"channels_last"` or `"channels_first"`. data_format: `"channels_last"` or `"channels_first"`.
backend_module: Backend module to use (if different from the default
backend).
Returns: Returns:
Array with shape `(size[0], size[1], channels)`. Array with shape `(size[0], size[1], channels)`.
@ -361,11 +367,12 @@ def smart_resize(
and if it was a backend-native tensor, and if it was a backend-native tensor,
the output is a backend-native tensor. the output is a backend-native tensor.
""" """
backend_module = backend_module or backend
if len(size) != 2: if len(size) != 2:
raise ValueError( raise ValueError(
f"Expected `size` to be a tuple of 2 integers, but got: {size}." f"Expected `size` to be a tuple of 2 integers, but got: {size}."
) )
img = backend.convert_to_tensor(x) img = backend_module.convert_to_tensor(x)
if len(img.shape) is not None: if len(img.shape) is not None:
if len(img.shape) < 3 or len(img.shape) > 4: if len(img.shape) < 3 or len(img.shape) > 4:
raise ValueError( raise ValueError(
@ -373,30 +380,32 @@ def smart_resize(
"channels)`, or `(batch_size, height, width, channels)`, but " "channels)`, or `(batch_size, height, width, channels)`, but "
f"got input with incorrect rank, of shape {img.shape}." f"got input with incorrect rank, of shape {img.shape}."
) )
shape = ops.shape(img) shape = backend_module.shape(img)
if data_format == "channels_last": if data_format == "channels_last":
height, width = shape[-3], shape[-2] height, width = shape[-3], shape[-2]
else: else:
height, width = shape[-2], shape[-1] height, width = shape[-2], shape[-1]
target_height, target_width = size target_height, target_width = size
-    crop_height = ops.cast(
-        ops.cast(width * target_height, "float32") / target_width, "int32"
-    )
-    crop_width = ops.cast(
-        ops.cast(height * target_width, "float32") / target_height, "int32"
-    )
+    crop_height = backend_module.cast(
+        backend_module.cast(width * target_height, "float32") / target_width,
+        "int32",
+    )
+    crop_width = backend_module.cast(
+        backend_module.cast(height * target_width, "float32") / target_height,
+        "int32",
+    )
# Set back to input height / width if crop_height / crop_width is not # Set back to input height / width if crop_height / crop_width is not
# smaller. # smaller.
crop_height = ops.minimum(height, crop_height) crop_height = backend_module.numpy.minimum(height, crop_height)
crop_width = ops.minimum(width, crop_width) crop_width = backend_module.numpy.minimum(width, crop_width)
crop_box_hstart = ops.cast( crop_box_hstart = backend_module.cast(
ops.cast(height - crop_height, "float32") / 2, "int32" backend_module.cast(height - crop_height, "float32") / 2, "int32"
) )
crop_box_wstart = ops.cast( crop_box_wstart = backend_module.cast(
ops.cast(width - crop_width, "float32") / 2, "int32" backend_module.cast(width - crop_width, "float32") / 2, "int32"
) )
if data_format == "channels_last": if data_format == "channels_last":
@ -428,7 +437,7 @@ def smart_resize(
crop_box_wstart : crop_box_wstart + crop_width, crop_box_wstart : crop_box_wstart + crop_width,
] ]
img = ops.image.resize( img = backend_module.image.resize(
img, size=size, method=interpolation, data_format=data_format img, size=size, method=interpolation, data_format=data_format
) )
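With the new `backend_module` argument, `smart_resize` can be driven either by the currently selected backend (the default) or by an explicitly passed backend module, which is how `CenterCrop` routes it through TensorFlow ops inside `tf.data`. A usage sketch (module names follow the layout used in this commit; passing the raw TF backend module directly, rather than a `DynamicBackend`, is an assumption, since both forward the same functions):

```python
import numpy as np

from keras_core.backend import tensorflow as tf_backend
from keras_core.utils import image_utils

img = np.random.random((10, 12, 3)).astype("float32")

# Default: resize with the currently selected Keras backend.
out_default = image_utils.smart_resize(img, size=(8, 9))

# Explicit backend module: force TF ops, the way CenterCrop does when traced
# inside a tf.data pipeline.
out_tf = image_utils.smart_resize(img, size=(8, 9), backend_module=tf_backend)

print(out_default.shape, out_tf.shape)  # both (8, 9, 3)
```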