Add cast of inputs to compute_dtype in RescalingLayer. (#272)

Note that this cast exists in the `tf.keras` implementation.
This commit is contained in:
hertschuh 2023-06-05 18:40:27 -07:00 committed by Francois Chollet
parent c2d7c51aba
commit b631281011
2 changed files with 40 additions and 1 deletions

@ -38,7 +38,7 @@ class Rescaling(Layer):
dtype = self.compute_dtype
scale = ops.cast(self.scale, dtype)
offset = ops.cast(self.offset, dtype)
return inputs * scale + offset
return ops.cast(inputs, dtype) * scale + offset
def compute_output_shape(self, input_shape):
return input_shape

@ -18,6 +18,45 @@ class RescalingTest(testing.TestCase):
supports_masking=True,
)
def test_rescaling_dtypes(self):
# int scale
self.run_layer_test(
layers.Rescaling,
init_kwargs={"scale": 2, "offset": 0.5},
input_shape=(2, 3),
expected_output_shape=(2, 3),
expected_num_trainable_weights=0,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
)
# int offset
self.run_layer_test(
layers.Rescaling,
init_kwargs={"scale": 1.0, "offset": 2},
input_shape=(2, 3),
expected_output_shape=(2, 3),
expected_num_trainable_weights=0,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
)
# int inputs
self.run_layer_test(
layers.Rescaling,
init_kwargs={"scale": 1.0 / 255, "offset": 0.5},
input_shape=(2, 3),
input_dtype="int16",
expected_output_shape=(2, 3),
expected_num_trainable_weights=0,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_losses=0,
supports_masking=True,
)
def test_rescaling_correctness(self):
layer = layers.Rescaling(scale=1.0 / 255, offset=0.5)
x = np.random.random((3, 10, 10, 3)) * 255