Add cast of inputs to compute_dtype
in RescalingLayer
. (#272)
Note that this cast exists in the `tf.keras` implementation.
This commit is contained in:
parent
c2d7c51aba
commit
b631281011
@ -38,7 +38,7 @@ class Rescaling(Layer):
|
||||
dtype = self.compute_dtype
|
||||
scale = ops.cast(self.scale, dtype)
|
||||
offset = ops.cast(self.offset, dtype)
|
||||
return inputs * scale + offset
|
||||
return ops.cast(inputs, dtype) * scale + offset
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape
|
||||
|
@ -18,6 +18,45 @@ class RescalingTest(testing.TestCase):
|
||||
supports_masking=True,
|
||||
)
|
||||
|
||||
def test_rescaling_dtypes(self):
|
||||
# int scale
|
||||
self.run_layer_test(
|
||||
layers.Rescaling,
|
||||
init_kwargs={"scale": 2, "offset": 0.5},
|
||||
input_shape=(2, 3),
|
||||
expected_output_shape=(2, 3),
|
||||
expected_num_trainable_weights=0,
|
||||
expected_num_non_trainable_weights=0,
|
||||
expected_num_seed_generators=0,
|
||||
expected_num_losses=0,
|
||||
supports_masking=True,
|
||||
)
|
||||
# int offset
|
||||
self.run_layer_test(
|
||||
layers.Rescaling,
|
||||
init_kwargs={"scale": 1.0, "offset": 2},
|
||||
input_shape=(2, 3),
|
||||
expected_output_shape=(2, 3),
|
||||
expected_num_trainable_weights=0,
|
||||
expected_num_non_trainable_weights=0,
|
||||
expected_num_seed_generators=0,
|
||||
expected_num_losses=0,
|
||||
supports_masking=True,
|
||||
)
|
||||
# int inputs
|
||||
self.run_layer_test(
|
||||
layers.Rescaling,
|
||||
init_kwargs={"scale": 1.0 / 255, "offset": 0.5},
|
||||
input_shape=(2, 3),
|
||||
input_dtype="int16",
|
||||
expected_output_shape=(2, 3),
|
||||
expected_num_trainable_weights=0,
|
||||
expected_num_non_trainable_weights=0,
|
||||
expected_num_seed_generators=0,
|
||||
expected_num_losses=0,
|
||||
supports_masking=True,
|
||||
)
|
||||
|
||||
def test_rescaling_correctness(self):
|
||||
layer = layers.Rescaling(scale=1.0 / 255, offset=0.5)
|
||||
x = np.random.random((3, 10, 10, 3)) * 255
|
||||
|
Loading…
Reference in New Issue
Block a user