Generalize float_size

parent 43bded597d
commit 2338b8a6c4

@@ -67,20 +67,12 @@ class InputLayer(Layer):
 @keras_core_export(["keras_core.layers.Input", "keras_core.Input"])
-def Input(
-    shape=None,
-    batch_size=None,
-    dtype=None,
-    batch_shape=None,
-    name=None,
-    tensor=None,
-):
+def Input(shape=None, batch_size=None, dtype=None, batch_shape=None, name=None):
     layer = InputLayer(
         shape=shape,
         batch_size=batch_size,
         dtype=dtype,
         batch_shape=batch_shape,
         name=name,
-        input_tensor=tensor,
     )
     return layer.output

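For context, a minimal sketch of the narrowed signature in use (assuming a
working keras_core install); the tensor= argument is gone, leaving only
shape, batch_size, dtype, batch_shape, and name:

    from keras_core import layers

    # Input() builds an InputLayer and returns its symbolic output tensor,
    # not the layer itself.
    x = layers.Input(shape=(32,), dtype="float32", name="features")
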
@@ -26,7 +26,7 @@ class RNN(Layer):
         section "Note on passing external constants" below.
     - A `state_size` attribute. This can be a single integer
         (single state) in which case it is the size of the recurrent
-        state. This can also be a list/tuple of integers
+        state. This can also be a list of integers
         (one size per state).
     - A `output_size` attribute, a single integer.
     - A `get_initial_state(batch_size=None)`

@@ -227,16 +227,16 @@ class RNN(Layer):
             raise ValueError(
                 "state_size must be specified as property on the RNN cell."
             )
-        if not isinstance(state_size, (list, tuple, int)):
+        if not isinstance(state_size, (list, int)):
             raise ValueError(
-                "state_size must be an integer, or a list/tuple of integers "
+                "state_size must be an integer, or a list of integers "
                 "(one for each state tensor)."
             )
         if isinstance(state_size, int):
             self.state_size = [state_size]
             self.single_state = True
         else:
-            self.state_size = list(state_size)
+            self.state_size = state_size
             self.single_state = False
 
     def compute_output_shape(self, sequences_shape, initial_state_shape=None):

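A standalone sketch of the normalization this hunk performs (a toy
function, not the library code), showing that an int and a list are
accepted while a tuple now raises:

    def normalize_state_size(state_size):
        # Mirrors the diffed validation: tuples are no longer accepted.
        if not isinstance(state_size, (list, int)):
            raise ValueError(
                "state_size must be an integer, or a list of integers "
                "(one for each state tensor)."
            )
        if isinstance(state_size, int):
            return [state_size], True    # single_state=True
        return state_size, False         # the list is now stored as-is

    print(normalize_state_size(5))       # ([5], True)
    print(normalize_state_size([5, 5]))  # ([5, 5], False)
    # normalize_state_size((5, 5)) raises ValueError
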
@@ -8,10 +8,10 @@ from keras_core import testing
 
 
 class OneStateRNNCell(layers.Layer):
-    def __init__(self, units, state_size=None, **kwargs):
+    def __init__(self, units, **kwargs):
         super().__init__(**kwargs)
         self.units = units
-        self.state_size = state_size if state_size else units
+        self.state_size = units
 
     def build(self, input_shape):
         self.kernel = self.add_weight(

@@ -34,10 +34,10 @@ class OneStateRNNCell(layers.Layer):
 
 
 class TwoStatesRNNCell(layers.Layer):
-    def __init__(self, units, state_size=None, **kwargs):
+    def __init__(self, units, **kwargs):
         super().__init__(**kwargs)
         self.units = units
-        self.state_size = state_size if state_size else [units, units]
+        self.state_size = [units, units]
         self.output_size = units
 
     def build(self, input_shape):

@@ -72,27 +72,7 @@ class RNNTest(testing.TestCase):
     def test_basics(self):
         self.run_layer_test(
             layers.RNN,
-            init_kwargs={"cell": OneStateRNNCell(5, state_size=5)},
-            input_shape=(3, 2, 4),
-            expected_output_shape=(3, 5),
-            expected_num_trainable_weights=2,
-            expected_num_non_trainable_weights=0,
-            expected_num_seed_generators=0,
-            supports_masking=True,
-        )
-        self.run_layer_test(
-            layers.RNN,
-            init_kwargs={"cell": OneStateRNNCell(5, state_size=[5])},
-            input_shape=(3, 2, 4),
-            expected_output_shape=(3, 5),
-            expected_num_trainable_weights=2,
-            expected_num_non_trainable_weights=0,
-            expected_num_seed_generators=0,
-            supports_masking=True,
-        )
-        self.run_layer_test(
-            layers.RNN,
-            init_kwargs={"cell": OneStateRNNCell(5, state_size=(5,))},
+            init_kwargs={"cell": OneStateRNNCell(5)},
             input_shape=(3, 2, 4),
             expected_output_shape=(3, 5),
             expected_num_trainable_weights=2,

@@ -126,17 +106,7 @@ class RNNTest(testing.TestCase):
         )
         self.run_layer_test(
             layers.RNN,
-            init_kwargs={"cell": TwoStatesRNNCell(5, state_size=[5, 5])},
-            input_shape=(3, 2, 4),
-            expected_output_shape=(3, 5),
-            expected_num_trainable_weights=3,
-            expected_num_non_trainable_weights=0,
-            expected_num_seed_generators=0,
-            supports_masking=True,
-        )
-        self.run_layer_test(
-            layers.RNN,
-            init_kwargs={"cell": TwoStatesRNNCell(5, state_size=(5, 5))},
+            init_kwargs={"cell": TwoStatesRNNCell(5)},
             input_shape=(3, 2, 4),
             expected_output_shape=(3, 5),
             expected_num_trainable_weights=3,

@@ -2,13 +2,17 @@ from keras_core import backend
 from keras_core import operations as ops
 
 
-def float_dtype_size(dtype):
+def dtype_size(dtype):
     if dtype in ("bfloat16", "float16"):
         return 16
-    if dtype == "float32":
+    if dtype in ("float32", "int32"):
         return 32
-    if dtype == "float64":
+    if dtype in ("float64", "int64"):
         return 64
+    if dtype == "uint8":
+        return 8
+    if dtype == "bool":
+        return 1
     raise ValueError(f"Invalid dtype: {dtype}")

@@ -32,7 +36,7 @@ def cast_to_common_dtype(tensors):
     for x in tensors:
         dtype = backend.standardize_dtype(x.dtype)
         if is_float(dtype):
-            if highest_float is None or float_dtype_size(dtype) > float_dtype_size(highest_float):
+            if highest_float is None or dtype_size(dtype) > dtype_size(highest_float):
                 highest_float = dtype
             elif dtype == "float16" and highest_float == "bfloat16":
                 highest_float = "float32"

@@ -30,7 +30,7 @@ def weight_memory_size(weights):
     for w in unique_weights:
         weight_shape = math.prod(w.shape)
         dtype = backend.standardize_dtype(w.dtype)
-        per_param_size = dtype_utils.float_dtype_size(dtype)
+        per_param_size = dtype_utils.dtype_size(dtype)
         total_memory_size += weight_shape * per_param_size
     return total_memory_size

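With dtype_size returning bit widths, the total accumulated here is a bit
count. A worked instance (hypothetical weight shape, arithmetic only):

    import math

    weight_shape = math.prod((1000, 100))  # 100,000 parameters
    per_param_size = 32                    # dtype_size("float32")
    print(weight_shape * per_param_size)   # 3,200,000 bits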