Generalize float_size

parent 43bded597d
commit 2338b8a6c4

@@ -67,20 +67,12 @@ class InputLayer(Layer):


 @keras_core_export(["keras_core.layers.Input", "keras_core.Input"])
-def Input(
-    shape=None,
-    batch_size=None,
-    dtype=None,
-    batch_shape=None,
-    name=None,
-    tensor=None,
-):
+def Input(shape=None, batch_size=None, dtype=None, batch_shape=None, name=None):
     layer = InputLayer(
         shape=shape,
         batch_size=batch_size,
         dtype=dtype,
         batch_shape=batch_shape,
         name=name,
-        input_tensor=tensor,
     )
     return layer.output

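For orientation, a minimal usage sketch of the collapsed signature (the layer choice, shape, and names here are illustrative assumptions, not taken from the diff):

    from keras_core import Input, layers

    # Functional-style graph entry point; Input() no longer accepts `tensor`.
    inputs = Input(shape=(4,), dtype="float32", name="features")
    outputs = layers.Dense(2)(inputs)
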
@@ -26,7 +26,7 @@ class RNN(Layer):
             section "Note on passing external constants" below.
         - A `state_size` attribute. This can be a single integer
             (single state) in which case it is the size of the recurrent
-            state. This can also be a list/tuple of integers
+            state. This can also be a list of integers
             (one size per state).
         - A `output_size` attribute, a single integer.
         - A `get_initial_state(batch_size=None)`

@@ -227,16 +227,16 @@ class RNN(Layer):
            raise ValueError(
                "state_size must be specified as property on the RNN cell."
            )
-        if not isinstance(state_size, (list, tuple, int)):
+        if not isinstance(state_size, (list, int)):
            raise ValueError(
-                "state_size must be an integer, or a list/tuple of integers "
+                "state_size must be an integer, or a list of integers "
                "(one for each state tensor)."
            )
        if isinstance(state_size, int):
            self.state_size = [state_size]
            self.single_state = True
        else:
-            self.state_size = list(state_size)
+            self.state_size = state_size
            self.single_state = False

    def compute_output_shape(self, sequences_shape, initial_state_shape=None):

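The net effect: an int state_size is still wrapped into a one-element list, while tuples now fail validation. A standalone sketch of the normalization (variable names local to this example):

    state_size = 5                    # as exposed by a cell's state_size property
    if isinstance(state_size, int):
        state_size, single_state = [state_size], True   # -> [5]
    else:
        single_state = False          # e.g. [5, 5] is kept as-is
    # state_size = (5, 5) would now raise: tuples fail isinstance((list, int))
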
@@ -8,10 +8,10 @@ from keras_core import testing


 class OneStateRNNCell(layers.Layer):
-    def __init__(self, units, state_size=None, **kwargs):
+    def __init__(self, units, **kwargs):
         super().__init__(**kwargs)
         self.units = units
-        self.state_size = state_size if state_size else units
+        self.state_size = units

     def build(self, input_shape):
         self.kernel = self.add_weight(

@@ -34,10 +34,10 @@ class OneStateRNNCell(layers.Layer):


 class TwoStatesRNNCell(layers.Layer):
-    def __init__(self, units, state_size=None, **kwargs):
+    def __init__(self, units, **kwargs):
         super().__init__(**kwargs)
         self.units = units
-        self.state_size = state_size if state_size else [units, units]
+        self.state_size = [units, units]
         self.output_size = units

     def build(self, input_shape):

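With state_size fixed by the cells themselves, wiring one into the layer reduces to the following hedged sketch (assumes the TwoStatesRNNCell defined above is in scope; shapes mirror the tests below):

    import numpy as np
    from keras_core import layers

    layer = layers.RNN(TwoStatesRNNCell(5))
    y = layer(np.zeros((3, 2, 4)))   # per the tests, output shape is (3, 5)
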
@@ -72,27 +72,7 @@ class RNNTest(testing.TestCase):
     def test_basics(self):
         self.run_layer_test(
             layers.RNN,
-            init_kwargs={"cell": OneStateRNNCell(5, state_size=5)},
-            input_shape=(3, 2, 4),
-            expected_output_shape=(3, 5),
-            expected_num_trainable_weights=2,
-            expected_num_non_trainable_weights=0,
-            expected_num_seed_generators=0,
-            supports_masking=True,
-        )
-        self.run_layer_test(
-            layers.RNN,
-            init_kwargs={"cell": OneStateRNNCell(5, state_size=[5])},
-            input_shape=(3, 2, 4),
-            expected_output_shape=(3, 5),
-            expected_num_trainable_weights=2,
-            expected_num_non_trainable_weights=0,
-            expected_num_seed_generators=0,
-            supports_masking=True,
-        )
-        self.run_layer_test(
-            layers.RNN,
-            init_kwargs={"cell": OneStateRNNCell(5, state_size=(5,))},
+            init_kwargs={"cell": OneStateRNNCell(5)},
             input_shape=(3, 2, 4),
             expected_output_shape=(3, 5),
             expected_num_trainable_weights=2,

@@ -126,17 +106,7 @@ class RNNTest(testing.TestCase):
         )
         self.run_layer_test(
             layers.RNN,
-            init_kwargs={"cell": TwoStatesRNNCell(5, state_size=[5, 5])},
-            input_shape=(3, 2, 4),
-            expected_output_shape=(3, 5),
-            expected_num_trainable_weights=3,
-            expected_num_non_trainable_weights=0,
-            expected_num_seed_generators=0,
-            supports_masking=True,
-        )
-        self.run_layer_test(
-            layers.RNN,
-            init_kwargs={"cell": TwoStatesRNNCell(5, state_size=(5, 5))},
+            init_kwargs={"cell": TwoStatesRNNCell(5)},
             input_shape=(3, 2, 4),
             expected_output_shape=(3, 5),
             expected_num_trainable_weights=3,

@@ -2,13 +2,17 @@ from keras_core import backend
 from keras_core import operations as ops


-def float_dtype_size(dtype):
+def dtype_size(dtype):
     if dtype in ("bfloat16", "float16"):
         return 16
-    if dtype == "float32":
+    if dtype in ("float32", "int32"):
         return 32
-    if dtype == "float64":
+    if dtype in ("float64", "int64"):
         return 64
+    if dtype == "uint8":
+        return 8
+    if dtype == "bool":
+        return 1
     raise ValueError(f"Invalid dtype: {dtype}")


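The renamed helper now answers for integer, uint8, and bool dtypes too; note that it reports bit widths, not bytes. A quick sanity sketch using only values from the hunk:

    assert dtype_size("bfloat16") == 16
    assert dtype_size("int32") == 32
    assert dtype_size("uint8") == 8
    assert dtype_size("bool") == 1
    # any other string, e.g. "complex64", raises ValueError
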
@@ -32,7 +36,7 @@ def cast_to_common_dtype(tensors):
     for x in tensors:
         dtype = backend.standardize_dtype(x.dtype)
         if is_float(dtype):
-            if highest_float is None or float_dtype_size(dtype) > highest_float:
+            if highest_float is None or dtype_size(dtype) > highest_float:
                 highest_float = dtype
             elif dtype == "float16" and highest_float == "bfloat16":
                 highest_float = "float32"

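For intuition, a hypothetical standalone driver mirroring the intended promotion (assumes dtype_size from the previous hunk is in scope; the size-to-size comparison here is this example's own, not part of keras_core):

    def widest_float(dtypes):
        highest = None
        for dtype in dtypes:
            if highest is None or dtype_size(dtype) > dtype_size(highest):
                highest = dtype
            elif dtype == "float16" and highest == "bfloat16":
                highest = "float32"  # neither 16-bit float subsumes the other
        return highest

    assert widest_float(["float16", "float32"]) == "float32"
    assert widest_float(["bfloat16", "float16"]) == "float32"
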
@@ -30,7 +30,7 @@ def weight_memory_size(weights):
     for w in unique_weights:
         weight_shape = math.prod(w.shape)
         dtype = backend.standardize_dtype(w.dtype)
-        per_param_size = dtype_utils.float_dtype_size(dtype)
+        per_param_size = dtype_utils.dtype_size(dtype)
         total_memory_size += weight_shape * per_param_size
     return total_memory_size

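With the generalized helper, this accounting now covers non-float weights as well; per-parameter size is in bits. A small illustrative calculation (the shape is invented for the example):

    import math

    shape = (100, 10)
    params = math.prod(shape)     # 1000 parameters
    bits_int32 = params * 32      # 32000 bits if stored as int32
    bits_uint8 = params * 8       # 8000 bits if stored as uint8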