Merge branch 'main' of github.com:keras-team/keras-core

This commit is contained in:
Francois Chollet 2023-06-01 18:49:07 -07:00
parent 572ca0eb82
commit 75aa8fff5e
4 changed files with 94 additions and 37 deletions

@ -1,8 +1,9 @@
from keras_core import backend
from keras_core import operations as ops
from keras_core.backend import image_data_format
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer
from keras_core.operations.operation_utils import compute_pooling_output_shape
from keras_core.utils import argument_validation
class BasePooling(Layer):
@ -21,13 +22,16 @@ class BasePooling(Layer):
):
super().__init__(name=name, **kwargs)
self.pool_size = pool_size
self.strides = pool_size if strides is None else strides
self.pool_size = argument_validation.standardize_tuple(
pool_size, pool_dimensions, "pool_size"
)
strides = pool_size if strides is None else strides
self.strides = argument_validation.standardize_tuple(
strides, pool_dimensions, "strides", allow_zero=True
)
self.pool_mode = pool_mode
self.padding = padding
self.data_format = (
image_data_format() if data_format is None else data_format
)
self.data_format = backend.standardize_data_format(data_format)
self.input_spec = InputSpec(ndim=pool_dimensions + 2)

@ -1514,33 +1514,71 @@ def full_like(x, fill_value, dtype=None):
class GetItem(Operation):
def call(self, x, key):
if not isinstance(key, int):
# TODO: support slicing.
raise ValueError(
"Only scalar int keys are supported at this time. Cannot "
f"process key {key}"
)
return x[key]
def compute_output_spec(self, x, key):
if not isinstance(key, int):
# TODO: support slicing.
remaining_shape = list(x.shape)
new_shape = []
if isinstance(key, int):
remaining_key = [key]
elif isinstance(key, tuple):
remaining_key = list(key)
else:
raise ValueError(
"Only scalar int keys are supported at this time. Cannot "
f"process key {key}"
f"Unsupported key type for array slice. Recieved: `{key}`"
)
if len(x.shape) == 0:
num_ellipses = remaining_key.count(Ellipsis)
if num_ellipses > 1:
raise ValueError(
f"Too many indices for array: array is scalar "
f"but index {key} was requested. A scalar array "
"cannot be indexed."
f"Slice should only have one ellipsis. Recieved: `{key}`"
)
if x.shape[0] is not None and key >= x.shape[0]:
raise ValueError(
f"Array has shape {x.shape} "
f"but out-of-bound index {key} was requested."
)
return KerasTensor(x.shape[1:], dtype=x.dtype)
elif num_ellipses == 0:
# Add an implicit final ellipsis.
remaining_key.append(Ellipsis)
# Consume slice key element by element.
while True:
if not remaining_key:
break
subkey = remaining_key.pop(0)
# Check for `newaxis` and `Ellipsis`.
if subkey == Ellipsis:
# Keep as many slices remain in our key, omitting `newaxis`.
needed = len(remaining_key) - remaining_key.count(np.newaxis)
consumed = len(remaining_shape) - needed
new_shape += remaining_shape[:consumed]
remaining_shape = remaining_shape[consumed:]
continue
# All frameworks follow numpy for newaxis. `np.newaxis == None`.
if subkey == np.newaxis:
new_shape.append(1)
continue
# At this point, we need to consume a new axis from the shape.
if not remaining_shape:
raise ValueError(
f"Array has shape {x.shape} but slice "
f"has to many indices. Recieved: `{key}`"
)
length = remaining_shape.pop(0)
if isinstance(subkey, int):
if length is not None:
index = subkey if subkey >= 0 else subkey + length
if index < 0 or index >= length:
raise ValueError(
f"Array has shape {x.shape} but out-of-bounds "
f"index {key} was requested."
)
elif isinstance(subkey, slice):
if length is not None:
# python3 friendly way to compute a slice length.
new_length = len(range(*subkey.indices(length)))
new_shape.append(new_length)
else:
new_shape.append(length)
else:
raise ValueError(
f"Unsupported key type for array slice. Recieved: `{key}`"
)
return KerasTensor(tuple(new_shape), dtype=x.dtype)
@keras_core_export(
@ -1549,12 +1587,6 @@ class GetItem(Operation):
def get_item(x, key):
if any_symbolic_tensors((x,)):
return GetItem().symbolic_call(x, key)
if not isinstance(key, int):
# TODO: support slicing.
raise ValueError(
"Only scalar int keys are supported at this time. Cannot "
f"process key {key}"
)
return x[key]

@ -882,8 +882,26 @@ class NumpyOneInputOpsDynamicShapeTest(testing.TestCase):
self.assertEqual(knp.floor(x).shape, (None, 3))
def test_get_item(self):
x = KerasTensor([None, None])
self.assertEqual(knp.get_item(x, 5).shape, (None,))
x = KerasTensor([None, 5, 16])
# Simple slice.
sliced = knp.get_item(x, 5)
self.assertEqual(sliced.shape, (5, 16))
# Ellipsis slice.
sliced = knp.get_item(x, np.s_[..., -1])
self.assertEqual(sliced.shape, (None, 5))
# `newaxis` slice.
sliced = knp.get_item(x, np.s_[:, np.newaxis, ...])
self.assertEqual(sliced.shape, (None, 1, 5, 16))
# Strided slice.
sliced = knp.get_item(x, np.s_[:5, 3:, 3:12:2])
self.assertEqual(sliced.shape, (None, 2, 5))
# Error states.
with self.assertRaises(ValueError):
sliced = knp.get_item(x, np.s_[:, 17, :])
with self.assertRaises(ValueError):
sliced = knp.get_item(x, np.s_[..., 5, ...])
with self.assertRaises(ValueError):
sliced = knp.get_item(x, np.s_[:, :, :, :])
def test_hstack(self):
x = KerasTensor([None, 3])

@ -1,3 +1,5 @@
import typing
from tensorflow import nest
from keras_core.backend import KerasTensor
@ -41,9 +43,10 @@ class SymbolicArguments:
return (tensor_dict[self._single_positional_tensor],), {}
def switch_fn(x):
val = tensor_dict.get(x, None)
if val is not None:
return val
if isinstance(x, typing.Hashable):
val = tensor_dict.get(x, None)
if val is not None:
return val
return x
return self.convert(switch_fn)