Finish adding tf.data support in all KPLs except Normalization (which is stateful in some cases).

Francois Chollet 2023-06-06 13:58:26 -07:00
parent af3c9a52ae
commit 7e2df47838
28 changed files with 379 additions and 92 deletions
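For context, a minimal sketch of the usage pattern this commit enables, mirroring the new `test_tf_data_compatibility` tests below: a stateless preprocessing layer mapped over a `tf.data.Dataset`, independently of the active Keras Core backend.

```python
import numpy as np
import tensorflow as tf

from keras_core import layers

# Stateless preprocessing layers can now run inside a tf.data pipeline,
# even when the active Keras Core backend is jax or torch.
layer = layers.Rescaling(scale=1.0 / 255)
images = np.random.random((8, 10, 10, 3)).astype("float32") * 255

ds = tf.data.Dataset.from_tensor_slices(images).batch(4).map(layer)
for batch in ds.take(1):
    print(batch.shape)  # (4, 10, 10, 3)
```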

@ -1,10 +1,10 @@
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
@keras_core_export("keras_core.layers.CategoryEncoding")
class CategoryEncoding(Layer):
class CategoryEncoding(TFDataLayer):
"""A preprocessing layer which encodes integer features.
This layer provides options for condensing data into a categorical encoding
@ -13,6 +13,9 @@ class CategoryEncoding(Layer):
inputs. For integer inputs where the total number of tokens is not known,
use `keras_core.layers.IntegerLookup` instead.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Examples:
**One-hot encoding data**
@ -130,23 +133,32 @@ class CategoryEncoding(Layer):
if self.output_mode != "count":
raise ValueError(
"`count_weights` is not used when `output_mode` is not "
"`'count'`. Received `count_weights={count_weights}`."
f"`'count'`. Received `count_weights={count_weights}`."
)
count_weights = ops.cast(count_weights, self.compute_dtype)
count_weights = self.backend.cast(count_weights, self.compute_dtype)
depth = self.num_tokens
max_value = ops.amax(inputs)
min_value = ops.amin(inputs)
condition = ops.logical_and(
ops.greater(ops.cast(depth, max_value.dtype), max_value),
ops.greater_equal(min_value, ops.cast(0, min_value.dtype)),
max_value = self.backend.numpy.amax(inputs)
min_value = self.backend.numpy.amin(inputs)
condition = self.backend.numpy.logical_and(
self.backend.numpy.greater(
self.backend.cast(depth, max_value.dtype), max_value
),
self.backend.numpy.greater_equal(
min_value, self.backend.cast(0, min_value.dtype)
),
)
if not condition:
raise ValueError(
"Input values must be in the range 0 <= values < num_tokens"
f" with num_tokens={depth}"
)
try:
# Check value range in eager mode only.
condition = bool(condition.__array__())
if not condition:
raise ValueError(
"Input values must be in the range 0 <= values < num_tokens"
f" with num_tokens={depth}"
)
except:
pass
return self._encode_categorical_inputs(
inputs,
@ -164,12 +176,12 @@ class CategoryEncoding(Layer):
):
# In all cases, we should uprank scalar input to a single sample.
if len(inputs.shape) == 0:
inputs = ops.expand_dims(inputs, -1)
inputs = self.backend.numpy.expand_dims(inputs, -1)
# One hot will uprank only if the final output dimension
# is not already 1.
if output_mode == "one_hot":
if len(inputs.shape) > 1 and inputs.shape[-1] != 1:
inputs = ops.expand_dims(inputs, -1)
inputs = self.backend.numpy.expand_dims(inputs, -1)
# TODO(b/190445202): remove output rank restriction.
if len(inputs.shape) > 2:
@ -181,15 +193,14 @@ class CategoryEncoding(Layer):
)
binary_output = output_mode in ("multi_hot", "one_hot")
inputs = ops.cast(inputs, "int32")
inputs = self.backend.cast(inputs, "int32")
if binary_output:
bincounts = ops.one_hot(inputs, num_classes=depth)
bincounts = self.backend.nn.one_hot(inputs, num_classes=depth)
if output_mode == "multi_hot":
bincounts = ops.sum(bincounts, axis=0)
bincounts = self.backend.numpy.sum(bincounts, axis=0)
else:
bincounts = ops.bincount(
bincounts = self.backend.numpy.bincount(
inputs, minlength=depth, weights=count_weights
)
return bincounts

@ -1,4 +1,5 @@
import numpy as np
import tensorflow as tf
from keras_core import layers
from keras_core import testing
@ -51,3 +52,19 @@ class CategoryEncodingTest(testing.TestCase):
output_data = layer(input_data)
self.assertAllClose(expected_output, output_data)
self.assertEqual(expected_output_shape, output_data.shape)
def test_tf_data_compatibility(self):
layer = layers.CategoryEncoding(num_tokens=4, output_mode="one_hot")
input_data = np.array([3, 2, 0, 1])
expected_output = np.array(
[
[0, 0, 0, 1],
[0, 0, 1, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
]
)
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(4).map(layer)
for output in ds.take(1):
output = output.numpy()
self.assertAllClose(output, expected_output)

@ -1,11 +1,11 @@
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
from keras_core.utils import image_utils
@keras_core_export("keras_core.layers.CenterCrop")
class CenterCrop(Layer):
class CenterCrop(TFDataLayer):
"""A preprocessing layer which crops images.
This layer crops the central portion of the images to a target size. If an
@ -29,6 +29,9 @@ class CenterCrop(Layer):
If the input height/width is even and the target height/width is odd (or
inversely), the input image is left-padded by 1 pixel.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args:
height: Integer, the height of the output shape.
width: Integer, the width of the output shape.
@ -99,7 +102,10 @@ class CenterCrop(Layer):
]
return image_utils.smart_resize(
inputs, [self.height, self.width], data_format=self.data_format
inputs,
[self.height, self.width],
data_format=self.data_format,
backend_module=self.backend,
)
def compute_output_shape(self, input_shape):

@ -94,3 +94,11 @@ class CenterCropTest(testing.TestCase, parameterized.TestCase):
size[1],
)(img)
self.assertAllClose(ref_out, out)
def test_tf_data_compatibility(self):
layer = layers.CenterCrop(8, 9)
input_data = np.random.random((2, 10, 12, 3))
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output = output.numpy()
self.assertEqual(list(output.shape), [2, 8, 9, 3])

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.Discretization")
@ -22,6 +23,9 @@ class Discretization(Layer):
with any backend (outside the model itself), which is how we recommend
to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Input shape:
Any array of dimension 2 or higher.
@ -192,7 +196,10 @@ class Discretization(Layer):
def call(self, inputs):
outputs = self.layer.call(inputs)
if backend.backend() != "tensorflow" and tf.executing_eagerly():
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
return outputs

@ -3,6 +3,7 @@ import tensorflow as tf
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.HashedCrossing")
@ -25,6 +26,9 @@ class HashedCrossing(Layer):
with any backend (outside the model itself), which is how we recommend
to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args:
num_bins: Number of hash bins.
output_mode: Specification for the output of the layer. Values can be
@ -100,7 +104,10 @@ class HashedCrossing(Layer):
def call(self, inputs):
outputs = self.layer.call(inputs)
if backend.backend() != "tensorflow" and tf.executing_eagerly():
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
return outputs

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.Hashing")
@ -33,6 +34,9 @@ class Hashing(Layer):
with any backend (outside the model itself), which is how we recommend
to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
**Example (FarmHash64)**
>>> layer = keras_core.layers.Hashing(num_bins=3)
@ -175,7 +179,10 @@ class Hashing(Layer):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs)
if backend.backend() != "tensorflow" and tf.executing_eagerly():
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
return outputs

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.IntegerLookup")
@ -47,6 +48,9 @@ class IntegerLookup(Layer):
with any backend (outside the model itself), which is how we recommend
to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args:
max_tokens: Maximum size of the vocabulary for this layer. This should
only be specified when adapting the vocabulary or when setting
@ -441,7 +445,10 @@ class IntegerLookup(Layer):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs)
if backend.backend() != "tensorflow" and tf.executing_eagerly():
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
return outputs

@ -1,17 +1,19 @@
from keras_core import operations as ops
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.backend import random
from keras_core.layers.layer import Layer
from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
@keras_core_export("keras_core.layers.RandomBrightness")
class RandomBrightness(Layer):
class RandomBrightness(TFDataLayer):
"""A preprocessing layer which randomly adjusts brightness during training.
This layer will randomly increase/reduce the brightness for the input RGB
images. At inference time, the output will be identical to the input.
Call the layer with `training=True` to adjust the brightness of the input.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args:
factor: Float or a list/tuple of 2 floats between -1.0 and 1.0. The
factor is used to determine the lower bound and upper bound of the
@ -72,21 +74,21 @@ class RandomBrightness(Layer):
super().__init__(**kwargs)
self._set_factor(factor)
self._set_value_range(value_range)
self._seed = seed
self._generator = random.SeedGenerator(seed)
self.seed = seed
self.generator = backend.random.SeedGenerator(seed)
def _set_value_range(self, value_range):
if not isinstance(value_range, (tuple, list)):
raise ValueError(
self._VALUE_RANGE_VALIDATION_ERROR
self.value_range_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
if len(value_range) != 2:
raise ValueError(
self._VALUE_RANGE_VALIDATION_ERROR
self.value_range_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
self._value_range = sorted(value_range)
self.value_range = sorted(value_range)
def _set_factor(self, factor):
if isinstance(factor, (tuple, list)):
@ -114,7 +116,7 @@ class RandomBrightness(Layer):
)
def call(self, inputs, training=True):
inputs = ops.cast(inputs, self.compute_dtype)
inputs = self.backend.cast(inputs, self.compute_dtype)
if training:
return self._brightness_adjust(inputs)
else:
@ -127,23 +129,30 @@ class RandomBrightness(Layer):
elif rank == 4:
# Keep only the batch dim. This will ensure to have same adjustment
# with in one image, but different across the images.
rgb_delta_shape = [images.shape[0], 1, 1, 1]
rgb_delta_shape = [self.backend.shape(images)[0], 1, 1, 1]
else:
raise ValueError(
"Expected the input image to be rank 3 or 4. Received "
f"inputs.shape = {images.shape}"
f"inputs.shape={images.shape}"
)
rgb_delta = ops.random.uniform(
if backend.backend() != self.backend._backend:
seed_generator = self.backend.random.SeedGenerator(self.seed)
else:
seed_generator = self.generator
rgb_delta = self.backend.random.uniform(
minval=self._factor[0],
maxval=self._factor[1],
shape=rgb_delta_shape,
seed=self._generator,
seed=seed_generator,
)
rgb_delta = rgb_delta * (self._value_range[1] - self._value_range[0])
rgb_delta = ops.cast(rgb_delta, images.dtype)
rgb_delta = rgb_delta * (self.value_range[1] - self.value_range[0])
rgb_delta = self.backend.cast(rgb_delta, images.dtype)
images += rgb_delta
return ops.clip(images, self._value_range[0], self._value_range[1])
return self.backend.numpy.clip(
images, self.value_range[0], self.value_range[1]
)
def compute_output_shape(self, input_shape):
return input_shape
@ -151,8 +160,8 @@ class RandomBrightness(Layer):
def get_config(self):
config = {
"factor": self._factor,
"value_range": self._value_range,
"seed": self._seed,
"value_range": self.value_range,
"seed": self.seed,
}
base_config = super().get_config()
return {**base_config, **config}

@ -1,4 +1,5 @@
import numpy as np
import tensorflow as tf
from keras_core import layers
from keras_core import testing
@ -46,3 +47,10 @@ class RandomBrightnessTest(testing.TestCase):
diff = output - inputs
self.assertTrue(np.amax(diff) <= 0)
self.assertTrue(np.mean(diff) < 0)
def test_tf_data_compatibility(self):
layer = layers.RandomBrightness(factor=0.5, seed=1337)
input_data = np.random.random((2, 8, 8, 3))
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()

@ -1,11 +1,10 @@
from keras_core import operations as ops
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.backend import random
from keras_core.layers.layer import Layer
from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
@keras_core_export("keras_core.layers.RandomContrast")
class RandomContrast(Layer):
class RandomContrast(TFDataLayer):
"""A preprocessing layer which randomly adjusts contrast during training.
This layer will randomly adjust the contrast of an image or images
@ -20,6 +19,9 @@ class RandomContrast(Layer):
in integer or floating point dtype.
By default, the layer will output floats.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Input shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format.
@ -54,30 +56,34 @@ class RandomContrast(Layer):
f"Received: factor={factor}"
)
self.seed = seed
self.generator = random.SeedGenerator(seed)
self.generator = backend.random.SeedGenerator(seed)
def call(self, inputs, training=True):
inputs = ops.cast(inputs, self.compute_dtype)
inputs = self.backend.cast(inputs, self.compute_dtype)
if training:
factor = ops.random.uniform(
if backend.backend() != self.backend._backend:
seed_generator = self.backend.random.SeedGenerator(self.seed)
else:
seed_generator = self.generator
factor = self.backend.random.uniform(
shape=(),
minval=1.0 - self.lower,
maxval=1.0 + self.upper,
seed=self.generator,
seed=seed_generator,
)
outputs = self._adjust_constrast(inputs, factor)
outputs = ops.clip(outputs, 0, 255)
ops.reshape(outputs, inputs.shape)
outputs = self.backend.numpy.clip(outputs, 0, 255)
self.backend.numpy.reshape(outputs, self.backend.shape(inputs))
return outputs
else:
return inputs
def _adjust_constrast(self, inputs, contrast_factor):
# reduce mean on height
inp_mean = ops.mean(inputs, axis=-3, keepdims=True)
inp_mean = self.backend.numpy.mean(inputs, axis=-3, keepdims=True)
# reduce mean on width
inp_mean = ops.mean(inp_mean, axis=-2, keepdims=True)
inp_mean = self.backend.numpy.mean(inp_mean, axis=-2, keepdims=True)
outputs = (inputs - inp_mean) * contrast_factor + inp_mean
return outputs

@ -1,4 +1,5 @@
import numpy as np
import tensorflow as tf
from keras_core import layers
from keras_core import testing
@ -33,3 +34,10 @@ class RandomContrastTest(testing.TestCase):
actual_outputs = np.clip(outputs, 0, 255)
self.assertAllClose(outputs, actual_outputs)
def test_tf_data_compatibility(self):
layer = layers.RandomContrast(factor=0.5, seed=1337)
input_data = np.random.random((2, 8, 8, 3))
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.RandomCrop")
@ -32,6 +33,9 @@ class RandomCrop(Layer):
with any backend (outside the model itself), which is how we recommend
to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Input shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format.
@ -59,12 +63,17 @@ class RandomCrop(Layer):
)
self.supports_masking = False
self.supports_jit = False
self._convert_input_args = False
self._allow_non_tensor_positional_args = True
def call(self, inputs, training=True):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs)
if backend.backend() != "tensorflow":
outputs = self.layer.call(inputs, training=training)
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
return outputs

@ -1,4 +1,5 @@
import numpy as np
import tensorflow as tf
from keras_core import layers
from keras_core import testing
@ -63,3 +64,11 @@ class RandomCropTest(testing.TestCase):
supports_masking=False,
run_training_check=False,
)
def test_tf_data_compatibility(self):
layer = layers.RandomCrop(8, 9)
input_data = np.random.random((2, 10, 12, 3))
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output = output.numpy()
self.assertEqual(list(output.shape), [2, 8, 9, 3])

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
HORIZONTAL = "horizontal"
VERTICAL = "vertical"
@ -21,6 +22,9 @@ class RandomFlip(Layer):
of integer or floating point dtype.
By default, the layer will output floats.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Input shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format.
@ -51,12 +55,17 @@ class RandomFlip(Layer):
**kwargs,
)
self.supports_jit = False
self._convert_input_args = False
self._allow_non_tensor_positional_args = True
def call(self, inputs, training=True):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs)
if backend.backend() != "tensorflow":
outputs = self.layer.call(inputs, training=training)
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
return outputs

@ -1,4 +1,5 @@
import numpy as np
import tensorflow as tf
from absl.testing import parameterized
from keras_core import backend
@ -54,3 +55,12 @@ class RandomFlipTest(testing.TestCase, parameterized.TestCase):
supports_masking=False,
run_training_check=False,
)
def test_tf_data_compatibility(self):
layer = layers.RandomFlip("vertical", seed=42)
input_data = np.array([[[2, 3, 4]], [[5, 6, 7]]])
expected_output = np.array([[[5, 6, 7]], [[2, 3, 4]]])
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output = output.numpy()
self.assertAllClose(output, expected_output)

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.RandomRotation")
@ -29,6 +30,9 @@ class RandomRotation(Layer):
with any backend (outside the model itself), which is how we recommend
to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Input shape:
3D (unbatched) or 4D (batched) tensor with shape:
`(..., height, width, channels)`, in `"channels_last"` format
@ -101,7 +105,10 @@ class RandomRotation(Layer):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs, training=training)
if backend.backend() != "tensorflow" and tf.executing_eagerly():
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
return outputs

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.RandomTranslation")
@ -17,6 +18,9 @@ class RandomTranslation(Layer):
of integer or floating point dtype. By default, the layer will output
floats.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args:
height_factor: a float represented as fraction of value, or a tuple of
size 2 representing lower and upper bound for shifting vertically. A
@ -91,7 +95,10 @@ class RandomTranslation(Layer):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs, training=training)
if backend.backend() != "tensorflow" and tf.executing_eagerly():
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
return outputs

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.RandomZoom")
@ -25,6 +26,9 @@ class RandomZoom(Layer):
with any backend (outside the model itself), which is how we recommend
to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args:
height_factor: a float represented as fraction of value,
or a tuple of size 2 representing lower and upper bound
@ -114,7 +118,10 @@ class RandomZoom(Layer):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs, training=training)
if backend.backend() != "tensorflow" and tf.executing_eagerly():
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
return outputs

@ -1,10 +1,9 @@
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
@keras_core_export("keras_core.layers.Rescaling")
class Rescaling(Layer):
class Rescaling(TFDataLayer):
"""A preprocessing layer which rescales input values to a new range.
This layer rescales every value of an input (often an image) by multiplying
@ -22,6 +21,9 @@ class Rescaling(Layer):
of integer or floating point dtype, and by default the layer will output
floats.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args:
scale: Float, the scale to apply to the inputs.
offset: Float, the offset to apply to the inputs.
@ -36,9 +38,9 @@ class Rescaling(Layer):
def call(self, inputs):
dtype = self.compute_dtype
scale = ops.cast(self.scale, dtype)
offset = ops.cast(self.offset, dtype)
return ops.cast(inputs, dtype) * scale + offset
scale = self.backend.cast(self.scale, dtype)
offset = self.backend.cast(self.offset, dtype)
return self.backend.cast(inputs, dtype) * scale + offset
def compute_output_shape(self, input_shape):
return input_shape

@ -1,4 +1,5 @@
import numpy as np
import tensorflow as tf
from keras_core import layers
from keras_core import testing
@ -62,3 +63,10 @@ class RescalingTest(testing.TestCase):
x = np.random.random((3, 10, 10, 3)) * 255
out = layer(x)
self.assertAllClose(out, x / 255 + 0.5)
def test_tf_data_compatibility(self):
layer = layers.Rescaling(scale=1.0 / 255, offset=0.5)
x = np.random.random((3, 10, 10, 3)) * 255
ds = tf.data.Dataset.from_tensor_slices(x).batch(3).map(layer)
for output in ds.take(1):
output.numpy()

@ -1,12 +1,11 @@
from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
from keras_core.utils import image_utils
@keras_core_export("keras_core.layers.Resizing")
class Resizing(Layer):
class Resizing(TFDataLayer):
"""A preprocessing layer which resizes images.
This layer resizes an image input to a target height and width. The input
@ -25,6 +24,9 @@ class Resizing(Layer):
or `(..., channels, target_height, target_width)`,
in `"channels_first"` format.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args:
height: Integer, the height of the output shape.
width: Integer, the width of the output shape.
@ -73,9 +75,10 @@ class Resizing(Layer):
size=size,
interpolation=self.interpolation,
data_format=self.data_format,
backend_module=self.backend,
)
else:
outputs = ops.image.resize(
outputs = self.backend.image.resize(
inputs,
size=size,
method=self.interpolation,

@ -142,3 +142,11 @@ class ResizingTest(testing.TestCase, parameterized.TestCase):
size[0], size[1], crop_to_aspect_ratio=crop_to_aspect_ratio
)(img)
self.assertAllClose(ref_out, out)
def test_tf_data_compatibility(self):
layer = layers.Resizing(8, 9)
input_data = np.random.random((2, 10, 12, 3))
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output = output.numpy()
self.assertEqual(list(output.shape), [2, 8, 9, 3])

@ -4,6 +4,7 @@ import tensorflow as tf
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.StringLookup")
@ -46,6 +47,9 @@ class StringLookup(Layer):
with any backend (outside the model itself), which is how we recommend
to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args:
max_tokens: Maximum size of the vocabulary for this layer. This should
only be specified when adapting the vocabulary or when setting
@ -435,7 +439,10 @@ class StringLookup(Layer):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs)
if backend.backend() != "tensorflow" and tf.executing_eagerly():
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
return outputs

@ -5,6 +5,7 @@ from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.saving import serialization_lib
from keras_core.utils import backend_utils
@keras_core_export("keras_core.layers.TextVectorization")
@ -63,6 +64,9 @@ class TextVectorization(Layer):
with any backend (outside the model itself), which is how we recommend
to use this layer.
**Note:** This layer is safe to use inside a `tf.data` pipeline
(independently of which backend you're using).
Args:
max_tokens: Maximum size of the vocabulary for this layer. This should
only be specified when adapting a vocabulary or when setting
@ -358,7 +362,10 @@ class TextVectorization(Layer):
if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
inputs = tf.convert_to_tensor(np.array(inputs))
outputs = self.layer.call(inputs)
if backend.backend() != "tensorflow" and tf.executing_eagerly():
if (
backend.backend() != "tensorflow"
and not backend_utils.in_tf_graph()
):
outputs = backend.convert_to_tensor(outputs)
return outputs

@ -0,0 +1,36 @@
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
class TFDataLayer(Layer):
"""Layer that can safely used in a tf.data pipeline.
The `call()` method must solely rely on `self.backend` ops.
Only supports a single input tensor argument.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.backend = backend_utils.DynamicBackend()
self._allow_non_tensor_positional_args = True
def __call__(self, inputs, **kwargs):
if backend_utils.in_tf_graph():
# We're in a TF graph, e.g. a tf.data pipeline.
self.backend.set_backend("tensorflow")
inputs = self.backend.convert_to_tensor(
inputs, dtype=self.compute_dtype
)
switch_convert_input_args = False
if self._convert_input_args:
self._convert_input_args = False
switch_convert_input_args = True
try:
outputs = super().__call__(inputs, **kwargs)
finally:
self.backend.reset()
if switch_convert_input_args:
self._convert_input_args = True
return outputs
return super().__call__(inputs, **kwargs)
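As an illustration, a hypothetical subclass sketch (the `ScaleAndClip` name and its logic are invented for this example): the only requirement stated in the docstring is that `call()` goes through `self.backend`, so the same code runs eagerly under any backend and inside a `tf.data` graph.

```python
from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer


class ScaleAndClip(TFDataLayer):
    """Hypothetical example layer relying solely on `self.backend` ops."""

    def __init__(self, scale, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, inputs):
        # `self.backend` resolves to the TF backend inside a tf.data graph,
        # and to the active Keras Core backend otherwise.
        x = self.backend.cast(inputs, self.compute_dtype)
        x = x * self.backend.cast(self.scale, self.compute_dtype)
        return self.backend.numpy.clip(x, 0.0, 1.0)
```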

@ -0,0 +1,48 @@
import sys
from keras_core import backend as backend_module
from keras_core.backend import jax as jax_backend
from keras_core.backend import tensorflow as tf_backend
from keras_core.backend import torch as torch_backend
def in_tf_graph():
if "tensorflow" in sys.modules:
import tensorflow as tf
return not tf.executing_eagerly()
return False
class DynamicBackend:
"""A class that can be used to switch from one backend to another.
Usage:
```python
backend = DynamicBackend("tensorflow")
y = backend.square(tf.constant(...))
backend.set_backend("jax")
y = backend.square(jax.numpy.array(...))
```
Args:
backend: Initial backend to use (string).
"""
def __init__(self, backend=None):
self._backend = backend or backend_module.backend()
def set_backend(self, backend):
self._backend = backend
def reset(self):
self._backend = backend_module.backend()
def __getattr__(self, name):
if self._backend == "tensorflow":
return getattr(tf_backend, name)
if self._backend == "jax":
return getattr(jax_backend, name)
if self._backend == "torch":
return getattr(torch_backend, name)
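For reference, a hedged sketch of how the two helpers above combine (the same dispatch that `TFDataLayer.__call__` performs); only names that appear in this diff are used.

```python
from keras_core.utils import backend_utils

# Dispatch pattern used by TFDataLayer: switch to the TF backend while
# tracing a tf.data graph, then restore the default backend afterwards.
dyn = backend_utils.DynamicBackend()

if backend_utils.in_tf_graph():
    # Inside tf.data graph tracing: route all ops through the TF backend.
    dyn.set_backend("tensorflow")

x = dyn.convert_to_tensor([[1.0, 2.0], [3.0, 4.0]])
x = dyn.cast(x, "float32")
mean = dyn.numpy.mean(x, axis=-1, keepdims=True)

# Restore the currently selected Keras Core backend.
dyn.reset()
```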

@ -299,7 +299,11 @@ def load_img(
def smart_resize(
x, size, interpolation="bilinear", data_format="channels_last"
x,
size,
interpolation="bilinear",
data_format="channels_last",
backend_module=None,
):
"""Resize images to a target size without aspect ratio distortion.
@ -354,6 +358,8 @@ def smart_resize(
Supports `bilinear`, `nearest`, `bicubic`,
`lanczos3`, `lanczos5`.
data_format: `"channels_last"` or `"channels_first"`.
backend_module: Backend module to use (if different from the default
backend).
Returns:
Array with shape `(size[0], size[1], channels)`.
@ -361,11 +367,12 @@ def smart_resize(
and if it was a backend-native tensor,
the output is a backend-native tensor.
"""
backend_module = backend_module or backend
if len(size) != 2:
raise ValueError(
f"Expected `size` to be a tuple of 2 integers, but got: {size}."
)
img = backend.convert_to_tensor(x)
img = backend_module.convert_to_tensor(x)
if len(img.shape) is not None:
if len(img.shape) < 3 or len(img.shape) > 4:
raise ValueError(
@ -373,30 +380,32 @@ def smart_resize(
"channels)`, or `(batch_size, height, width, channels)`, but "
f"got input with incorrect rank, of shape {img.shape}."
)
shape = ops.shape(img)
shape = backend_module.shape(img)
if data_format == "channels_last":
height, width = shape[-3], shape[-2]
else:
height, width = shape[-2], shape[-1]
target_height, target_width = size
crop_height = ops.cast(
ops.cast(width * target_height, "float32") / target_width, "int32"
crop_height = backend_module.cast(
backend_module.cast(width * target_height, "float32") / target_width,
"int32",
)
crop_width = ops.cast(
ops.cast(height * target_width, "float32") / target_height, "int32"
crop_width = backend_module.cast(
backend_module.cast(height * target_width, "float32") / target_height,
"int32",
)
# Set back to input height / width if crop_height / crop_width is not
# smaller.
crop_height = ops.minimum(height, crop_height)
crop_width = ops.minimum(width, crop_width)
crop_height = backend_module.numpy.minimum(height, crop_height)
crop_width = backend_module.numpy.minimum(width, crop_width)
crop_box_hstart = ops.cast(
ops.cast(height - crop_height, "float32") / 2, "int32"
crop_box_hstart = backend_module.cast(
backend_module.cast(height - crop_height, "float32") / 2, "int32"
)
crop_box_wstart = ops.cast(
ops.cast(width - crop_width, "float32") / 2, "int32"
crop_box_wstart = backend_module.cast(
backend_module.cast(width - crop_width, "float32") / 2, "int32"
)
if data_format == "channels_last":
@ -428,7 +437,7 @@ def smart_resize(
crop_box_wstart : crop_box_wstart + crop_width,
]
img = ops.image.resize(
img = backend_module.image.resize(
img, size=size, method=interpolation, data_format=data_format
)
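Finally, a hedged sketch of calling `smart_resize` with the new `backend_module` argument, the way `CenterCrop` and `Resizing` now pass `self.backend`; the output shape follows the docstring above.

```python
import numpy as np

from keras_core.backend import tensorflow as tf_backend
from keras_core.utils import image_utils

# Resize through an explicitly chosen backend module rather than the default.
img = np.random.random((10, 12, 3)).astype("float32")
out = image_utils.smart_resize(img, size=(8, 9), backend_module=tf_backend)
print(out.shape)  # (8, 9, 3)
```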