Generalize float_size

Francois Chollet 2023-06-05 16:04:39 -07:00
parent 43bded597d
commit 2338b8a6c4
5 changed files with 20 additions and 54 deletions

@@ -67,20 +67,12 @@ class InputLayer(Layer):


 @keras_core_export(["keras_core.layers.Input", "keras_core.Input"])
-def Input(
-    shape=None,
-    batch_size=None,
-    dtype=None,
-    batch_shape=None,
-    name=None,
-    tensor=None,
-):
+def Input(shape=None, batch_size=None, dtype=None, batch_shape=None, name=None):
     layer = InputLayer(
         shape=shape,
         batch_size=batch_size,
         dtype=dtype,
         batch_shape=batch_shape,
         name=name,
-        input_tensor=tensor,
     )
     return layer.output

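For reference, a usage sketch of the simplified `Input` signature after this change; the layer sizes and names here are illustrative assumptions, not part of the commit:

# Hypothetical usage of the trimmed-down signature above (no more `tensor=` argument).
import keras_core
from keras_core import layers

inputs = keras_core.Input(shape=(4,), batch_size=3, dtype="float32", name="my_input")
outputs = layers.Dense(5)(inputs)  # symbolic output with shape (3, 5)
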
@@ -26,7 +26,7 @@ class RNN(Layer):
         section "Note on passing external constants" below.
     - A `state_size` attribute. This can be a single integer
         (single state) in which case it is the size of the recurrent
-        state. This can also be a list/tuple of integers
+        state. This can also be a list of integers
         (one size per state).
     - A `output_size` attribute, a single integer.
     - A `get_initial_state(batch_size=None)`
@@ -227,16 +227,16 @@ class RNN(Layer):
             raise ValueError(
                 "state_size must be specified as property on the RNN cell."
             )
-        if not isinstance(state_size, (list, tuple, int)):
+        if not isinstance(state_size, (list, int)):
            raise ValueError(
-                "state_size must be an integer, or a list/tuple of integers "
+                "state_size must be an integer, or a list of integers "
                "(one for each state tensor)."
            )
         if isinstance(state_size, int):
             self.state_size = [state_size]
             self.single_state = True
         else:
-            self.state_size = list(state_size)
+            self.state_size = state_size
             self.single_state = False

     def compute_output_shape(self, sequences_shape, initial_state_shape=None):

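The docstring and validation changes above tighten the cell contract: `state_size` is now an integer or a list of integers (tuples are no longer accepted), and the RNN layer normalizes an integer into a one-element list. A minimal illustrative cell following that contract is sketched below; the class name and the ops used are assumptions for illustration only, not code from this commit:

# Illustrative-only cell honoring the updated contract (int or list state_size).
from keras_core import layers
from keras_core import operations as ops

class MinimalRNNCell(layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = units   # a single int; a list like [units, units] is also valid
        self.output_size = units

    def build(self, input_shape):
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units))
        self.recurrent_kernel = self.add_weight(shape=(self.units, self.units))

    def get_initial_state(self, batch_size=None):
        return [ops.zeros((batch_size, self.units))]

    def call(self, inputs, states):
        h = ops.matmul(inputs, self.kernel) + ops.matmul(states[0], self.recurrent_kernel)
        return h, [h]
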
@@ -8,10 +8,10 @@ from keras_core import testing


 class OneStateRNNCell(layers.Layer):
-    def __init__(self, units, state_size=None, **kwargs):
+    def __init__(self, units, **kwargs):
         super().__init__(**kwargs)
         self.units = units
-        self.state_size = state_size if state_size else units
+        self.state_size = units

     def build(self, input_shape):
         self.kernel = self.add_weight(
@@ -34,10 +34,10 @@ class OneStateRNNCell(layers.Layer):


 class TwoStatesRNNCell(layers.Layer):
-    def __init__(self, units, state_size=None, **kwargs):
+    def __init__(self, units, **kwargs):
         super().__init__(**kwargs)
         self.units = units
-        self.state_size = state_size if state_size else [units, units]
+        self.state_size = [units, units]
         self.output_size = units

     def build(self, input_shape):
@@ -72,27 +72,7 @@ class RNNTest(testing.TestCase):
     def test_basics(self):
         self.run_layer_test(
             layers.RNN,
-            init_kwargs={"cell": OneStateRNNCell(5, state_size=5)},
-            input_shape=(3, 2, 4),
-            expected_output_shape=(3, 5),
-            expected_num_trainable_weights=2,
-            expected_num_non_trainable_weights=0,
-            expected_num_seed_generators=0,
-            supports_masking=True,
-        )
-        self.run_layer_test(
-            layers.RNN,
-            init_kwargs={"cell": OneStateRNNCell(5, state_size=[5])},
-            input_shape=(3, 2, 4),
-            expected_output_shape=(3, 5),
-            expected_num_trainable_weights=2,
-            expected_num_non_trainable_weights=0,
-            expected_num_seed_generators=0,
-            supports_masking=True,
-        )
-        self.run_layer_test(
-            layers.RNN,
-            init_kwargs={"cell": OneStateRNNCell(5, state_size=(5,))},
+            init_kwargs={"cell": OneStateRNNCell(5)},
             input_shape=(3, 2, 4),
             expected_output_shape=(3, 5),
             expected_num_trainable_weights=2,
@@ -126,17 +106,7 @@ class RNNTest(testing.TestCase):
         )
         self.run_layer_test(
             layers.RNN,
-            init_kwargs={"cell": TwoStatesRNNCell(5, state_size=[5, 5])},
-            input_shape=(3, 2, 4),
-            expected_output_shape=(3, 5),
-            expected_num_trainable_weights=3,
-            expected_num_non_trainable_weights=0,
-            expected_num_seed_generators=0,
-            supports_masking=True,
-        )
-        self.run_layer_test(
-            layers.RNN,
-            init_kwargs={"cell": TwoStatesRNNCell(5, state_size=(5, 5))},
+            init_kwargs={"cell": TwoStatesRNNCell(5)},
             input_shape=(3, 2, 4),
             expected_output_shape=(3, 5),
             expected_num_trainable_weights=3,

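A quick manual check mirroring the simplified tests above might look like this (a sketch assuming numpy and the `OneStateRNNCell` defined in this test file; not part of the commit):

# Rough sanity check: the cell now always derives its state size from `units`.
import numpy as np
from keras_core import layers

layer = layers.RNN(OneStateRNNCell(5))
x = np.random.random((3, 2, 4))
y = layer(x)
print(y.shape)  # expected (3, 5), matching expected_output_shape in the tests
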
@@ -2,13 +2,17 @@ from keras_core import backend
 from keras_core import operations as ops


-def float_dtype_size(dtype):
+def dtype_size(dtype):
     if dtype in ("bfloat16", "float16"):
         return 16
-    if dtype == "float32":
+    if dtype in ("float32", "int32"):
         return 32
-    if dtype == "float64":
+    if dtype in ("float64", "int64"):
         return 64
+    if dtype == "uint8":
+        return 8
+    if dtype == "bool":
+        return 1
     raise ValueError(f"Invalid dtype: {dtype}")
@@ -32,7 +36,7 @@ def cast_to_common_dtype(tensors):
     for x in tensors:
         dtype = backend.standardize_dtype(x.dtype)
         if is_float(dtype):
-            if highest_float is None or float_dtype_size(dtype) > highest_float:
+            if highest_float is None or dtype_size(dtype) > highest_float:
                 highest_float = dtype
             elif dtype == "float16" and highest_float == "bfloat16":
                 highest_float = "float32"

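With the rename, the helper is no longer float-only: it now reports a bit width for common integer, uint8, and bool dtypes as well, and still raises for anything it does not know. A small usage sketch (the import path is an assumption):

# Hedged usage sketch of the generalized helper.
from keras_core.utils import dtype_utils

dtype_utils.dtype_size("float32")  # 32
dtype_utils.dtype_size("int64")    # 64, newly supported
dtype_utils.dtype_size("bool")     # 1, newly supported
# Anything else, e.g. "complex64", still raises ValueError("Invalid dtype: ...").
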
@@ -30,7 +30,7 @@ def weight_memory_size(weights):
     for w in unique_weights:
         weight_shape = math.prod(w.shape)
         dtype = backend.standardize_dtype(w.dtype)
-        per_param_size = dtype_utils.float_dtype_size(dtype)
+        per_param_size = dtype_utils.dtype_size(dtype)
         total_memory_size += weight_shape * per_param_size
     return total_memory_size
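Because `weight_memory_size` multiplies the parameter count by this per-parameter size, the generalization means non-float weights (for example int or bool) can now be accounted for instead of raising. A back-of-the-envelope illustration, assuming the sizes above are bit widths:

# Rough arithmetic sketch only; mirrors the accounting in the hunk above.
import math

weight_shape = math.prod((1000, 64))        # 64,000 parameters in a (1000, 64) kernel
per_param_size = 32                         # dtype_size("float32") -> 32 bits
total_bits = weight_shape * per_param_size  # 2,048,000 bits
total_bytes = total_bits / 8                # 256,000 bytes, i.e. ~250 KB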