From b6312810118e7a310fb357a24dc1cdafa56b6fff Mon Sep 17 00:00:00 2001 From: hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Mon, 5 Jun 2023 18:40:27 -0700 Subject: [PATCH] Add cast of inputs to `compute_dtype` in `RescalingLayer`. (#272) Note that this cast exists in the `tf.keras` implementation. --- keras_core/layers/preprocessing/rescaling.py | 2 +- .../layers/preprocessing/rescaling_test.py | 39 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/keras_core/layers/preprocessing/rescaling.py b/keras_core/layers/preprocessing/rescaling.py index b8fbad144..cac268e26 100644 --- a/keras_core/layers/preprocessing/rescaling.py +++ b/keras_core/layers/preprocessing/rescaling.py @@ -38,7 +38,7 @@ class Rescaling(Layer): dtype = self.compute_dtype scale = ops.cast(self.scale, dtype) offset = ops.cast(self.offset, dtype) - return inputs * scale + offset + return ops.cast(inputs, dtype) * scale + offset def compute_output_shape(self, input_shape): return input_shape diff --git a/keras_core/layers/preprocessing/rescaling_test.py b/keras_core/layers/preprocessing/rescaling_test.py index 5c46cbf2b..86873c63e 100644 --- a/keras_core/layers/preprocessing/rescaling_test.py +++ b/keras_core/layers/preprocessing/rescaling_test.py @@ -18,6 +18,45 @@ class RescalingTest(testing.TestCase): supports_masking=True, ) + def test_rescaling_dtypes(self): + # int scale + self.run_layer_test( + layers.Rescaling, + init_kwargs={"scale": 2, "offset": 0.5}, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + # int offset + self.run_layer_test( + layers.Rescaling, + init_kwargs={"scale": 1.0, "offset": 2}, + input_shape=(2, 3), + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + # int inputs + self.run_layer_test( + layers.Rescaling, + init_kwargs={"scale": 1.0 / 255, "offset": 0.5}, + input_shape=(2, 3), + input_dtype="int16", + expected_output_shape=(2, 3), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=0, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + def test_rescaling_correctness(self): layer = layers.Rescaling(scale=1.0 / 255, offset=0.5) x = np.random.random((3, 10, 10, 3)) * 255