Generalize float_size

Francois Chollet 2023-06-05 16:04:39 -07:00
parent 43bded597d
commit 2338b8a6c4
5 changed files with 20 additions and 54 deletions

@@ -67,20 +67,12 @@ class InputLayer(Layer):
 @keras_core_export(["keras_core.layers.Input", "keras_core.Input"])
-def Input(
-    shape=None,
-    batch_size=None,
-    dtype=None,
-    batch_shape=None,
-    name=None,
-    tensor=None,
-):
+def Input(shape=None, batch_size=None, dtype=None, batch_shape=None, name=None):
     layer = InputLayer(
         shape=shape,
         batch_size=batch_size,
         dtype=dtype,
         batch_shape=batch_shape,
         name=name,
-        input_tensor=tensor,
     )
     return layer.output
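
This hunk drops the `tensor` argument from `Input`, so every `Input` call now builds its tensor from a shape/dtype spec. A minimal usage sketch under the new signature (the layer and names are illustrative, not from this commit):

    import keras_core
    from keras_core import layers

    # `tensor=...` is no longer accepted; the input is described symbolically.
    inputs = keras_core.Input(shape=(4,), dtype="float32", name="features")
    outputs = layers.Dense(2)(inputs)
    model = keras_core.Model(inputs, outputs)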

@@ -26,7 +26,7 @@ class RNN(Layer):
         section "Note on passing external constants" below.
     - A `state_size` attribute. This can be a single integer
         (single state) in which case it is the size of the recurrent
-        state. This can also be a list/tuple of integers
+        state. This can also be a list of integers
         (one size per state).
     - A `output_size` attribute, a single integer.
     - A `get_initial_state(batch_size=None)`
@@ -227,16 +227,16 @@ class RNN(Layer):
             raise ValueError(
                 "state_size must be specified as property on the RNN cell."
             )
-        if not isinstance(state_size, (list, tuple, int)):
+        if not isinstance(state_size, (list, int)):
             raise ValueError(
-                "state_size must be an integer, or a list/tuple of integers "
+                "state_size must be an integer, or a list of integers "
                 "(one for each state tensor)."
             )
         if isinstance(state_size, int):
             self.state_size = [state_size]
             self.single_state = True
         else:
-            self.state_size = list(state_size)
+            self.state_size = state_size
             self.single_state = False

    def compute_output_shape(self, sequences_shape, initial_state_shape=None):
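
With tuples no longer accepted, a cell's `state_size` must be a single int (one state) or a list of ints (one per state). A hypothetical minimal cell conforming to the new contract, mirroring the test cells below (not part of this commit):

    import numpy as np
    from keras_core import layers
    from keras_core import operations as ops

    class MinimalCell(layers.Layer):
        def __init__(self, units, **kwargs):
            super().__init__(**kwargs)
            self.units = units
            self.state_size = units  # int or list of ints; (units,) now raises
            self.output_size = units

        def build(self, input_shape):
            self.kernel = self.add_weight(
                shape=(input_shape[-1], self.units), initializer="ones"
            )

        def call(self, inputs, states):
            # One step: project the input and add the previous state.
            output = ops.matmul(inputs, self.kernel) + states[0]
            return output, [output]

    layer = layers.RNN(MinimalCell(5))
    out = layer(np.ones((3, 2, 4)))  # (batch, time, features) -> (3, 5)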

@@ -8,10 +8,10 @@ from keras_core import testing


 class OneStateRNNCell(layers.Layer):
-    def __init__(self, units, state_size=None, **kwargs):
+    def __init__(self, units, **kwargs):
         super().__init__(**kwargs)
         self.units = units
-        self.state_size = state_size if state_size else units
+        self.state_size = units

     def build(self, input_shape):
         self.kernel = self.add_weight(
@@ -34,10 +34,10 @@ class OneStateRNNCell(layers.Layer):


 class TwoStatesRNNCell(layers.Layer):
-    def __init__(self, units, state_size=None, **kwargs):
+    def __init__(self, units, **kwargs):
         super().__init__(**kwargs)
         self.units = units
-        self.state_size = state_size if state_size else [units, units]
+        self.state_size = [units, units]
         self.output_size = units

     def build(self, input_shape):
@@ -72,27 +72,7 @@ class RNNTest(testing.TestCase):
     def test_basics(self):
         self.run_layer_test(
             layers.RNN,
-            init_kwargs={"cell": OneStateRNNCell(5, state_size=5)},
-            input_shape=(3, 2, 4),
-            expected_output_shape=(3, 5),
-            expected_num_trainable_weights=2,
-            expected_num_non_trainable_weights=0,
-            expected_num_seed_generators=0,
-            supports_masking=True,
-        )
-        self.run_layer_test(
-            layers.RNN,
-            init_kwargs={"cell": OneStateRNNCell(5, state_size=[5])},
-            input_shape=(3, 2, 4),
-            expected_output_shape=(3, 5),
-            expected_num_trainable_weights=2,
-            expected_num_non_trainable_weights=0,
-            expected_num_seed_generators=0,
-            supports_masking=True,
-        )
-        self.run_layer_test(
-            layers.RNN,
-            init_kwargs={"cell": OneStateRNNCell(5, state_size=(5,))},
+            init_kwargs={"cell": OneStateRNNCell(5)},
             input_shape=(3, 2, 4),
             expected_output_shape=(3, 5),
             expected_num_trainable_weights=2,
@@ -126,17 +106,7 @@ class RNNTest(testing.TestCase):
         )
         self.run_layer_test(
             layers.RNN,
-            init_kwargs={"cell": TwoStatesRNNCell(5, state_size=[5, 5])},
-            input_shape=(3, 2, 4),
-            expected_output_shape=(3, 5),
-            expected_num_trainable_weights=3,
-            expected_num_non_trainable_weights=0,
-            expected_num_seed_generators=0,
-            supports_masking=True,
-        )
-        self.run_layer_test(
-            layers.RNN,
-            init_kwargs={"cell": TwoStatesRNNCell(5, state_size=(5, 5))},
+            init_kwargs={"cell": TwoStatesRNNCell(5)},
             input_shape=(3, 2, 4),
             expected_output_shape=(3, 5),
             expected_num_trainable_weights=3,

@@ -2,13 +2,17 @@ from keras_core import backend
 from keras_core import operations as ops


-def float_dtype_size(dtype):
+def dtype_size(dtype):
     if dtype in ("bfloat16", "float16"):
         return 16
-    if dtype == "float32":
+    if dtype in ("float32", "int32"):
         return 32
-    if dtype == "float64":
+    if dtype in ("float64", "int64"):
         return 64
+    if dtype == "uint8":
+        return 8
+    if dtype == "bool":
+        return 1
     raise ValueError(f"Invalid dtype: {dtype}")
@@ -32,7 +36,7 @@ def cast_to_common_dtype(tensors):
     for x in tensors:
         dtype = backend.standardize_dtype(x.dtype)
         if is_float(dtype):
-            if highest_float is None or float_dtype_size(dtype) > highest_float:
+            if highest_float is None or dtype_size(dtype) > highest_float:
                 highest_float = dtype
             elif dtype == "float16" and highest_float == "bfloat16":
                 highest_float = "float32"
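
The rename from `float_dtype_size` to `dtype_size` reflects that the helper now covers integer, uint8, and bool dtypes as well, returning a width in bits. A standalone behavior sketch reproducing the patched logic (copied from the diff above, so it runs without keras_core):

    def dtype_size(dtype):
        if dtype in ("bfloat16", "float16"):
            return 16
        if dtype in ("float32", "int32"):
            return 32
        if dtype in ("float64", "int64"):
            return 64
        if dtype == "uint8":
            return 8
        if dtype == "bool":
            return 1
        raise ValueError(f"Invalid dtype: {dtype}")

    assert dtype_size("int32") == 32  # previously raised ValueError
    assert dtype_size("bool") == 1    # width in bits, not bytes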

@@ -30,7 +30,7 @@ def weight_memory_size(weights):
     for w in unique_weights:
         weight_shape = math.prod(w.shape)
         dtype = backend.standardize_dtype(w.dtype)
-        per_param_size = dtype_utils.float_dtype_size(dtype)
+        per_param_size = dtype_utils.dtype_size(dtype)
         total_memory_size += weight_shape * per_param_size
     return total_memory_size
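
Because `weight_memory_size` multiplies each weight's element count by `dtype_size`, the generalization lets this accounting cover non-float weights too. A rough standalone sketch of the same arithmetic (dtype table inlined from the diff; the example shapes are made up):

    import math

    DTYPE_BITS = {"bfloat16": 16, "float16": 16, "float32": 32, "int32": 32,
                  "float64": 64, "int64": 64, "uint8": 8, "bool": 1}

    def weight_memory_bits(shapes_and_dtypes):
        # Bits per parameter times parameter count, summed over all weights.
        return sum(math.prod(shape) * DTYPE_BITS[dtype]
                   for shape, dtype in shapes_and_dtypes)

    # e.g. a (1000, 64) float32 kernel plus a (64,) int64 step counter:
    assert weight_memory_bits([((1000, 64), "float32"), ((64,), "int64")]) \
        == 1000 * 64 * 32 + 64 * 64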