From 186cbf6b7ce1a0264f2ce8d499e81f06c095c23b Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 24 May 2023 21:42:06 -0700 Subject: [PATCH] Merge branches 'main' and 'main' of github.com:keras-team/keras-core --- keras_core/backend/torch/numpy.py | 84 +++++++++++++++++------------ keras_core/operations/numpy_test.py | 58 ++++---------------- 2 files changed, 59 insertions(+), 83 deletions(-) diff --git a/keras_core/backend/torch/numpy.py b/keras_core/backend/torch/numpy.py index 53883ca01..c69906d87 100644 --- a/keras_core/backend/torch/numpy.py +++ b/keras_core/backend/torch/numpy.py @@ -294,7 +294,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1): def dot(x, y): x, y = convert_to_tensor(x), convert_to_tensor(y) - if len(x.shape) == 0 or len(y.shape) == 0: + if x.ndim == 0 or y.ndim == 0: return torch.multiply(x, y) return torch.matmul(x, y) @@ -327,7 +327,7 @@ def expm1(x): def flip(x, axis=None): x = convert_to_tensor(x) if axis is None: - axis = tuple(range(len(x.shape))) + axis = tuple(range(x.ndim)) if isinstance(axis, int): axis = (axis,) return torch.flip(x, dims=axis) @@ -341,23 +341,21 @@ def floor(x): def full(shape, fill_value, dtype=None): dtype = to_torch_dtype(dtype) if hasattr(fill_value, "__len__"): - raise NotImplementedError( - "`torch.full()` only accepts scalars for `fill_value`. " - f"Received: fill_value={fill_value}" - ) - # TODO: implement conversion of shape into repetitions for `torch.tile`` - # return torch.tile(fill_value, reps) + fill_value = convert_to_tensor(fill_value) + reps = shape[-1] // fill_value.shape[-1] # channels-last reduction + reps_by_dim = (*shape[:-1], reps) + return torch.tile(fill_value, reps_by_dim) return torch.full(size=shape, fill_value=fill_value, dtype=dtype) def full_like(x, fill_value, dtype=None): dtype = to_torch_dtype(dtype) if hasattr(fill_value, "__len__"): - raise NotImplementedError( - "`torch.full()` only accepts scalars for `fill_value`." + fill_value = convert_to_tensor(fill_value) + reps_by_dim = tuple( + [x.shape[i] // fill_value.shape[i] for i in range(x.ndim)] ) - # TODO: implement conversion of shape into repetitions for `torch.tile`` - # return torch.tile(fill_value, reps) + return torch.tile(fill_value, reps_by_dim) x = convert_to_tensor(x) return torch.full_like(input=x, fill_value=fill_value, dtype=dtype) @@ -430,15 +428,27 @@ def linspace( "torch.linspace does not support an `axis` argument. 
" f"Received axis={axis}" ) - dtype = to_torch_dtype(dtype) + dtype = to_torch_dtype(dtype) if endpoint is False: stop = stop - ((stop - start) / num) - linspace = torch.linspace( - start=start, - end=stop, - steps=num, - dtype=dtype, - ) + if hasattr(start, "__len__") and hasattr(stop, "__len__"): + start, stop = convert_to_tensor(start), convert_to_tensor(stop) + stop = cast(stop, dtype) if endpoint is False and dtype else stop + steps = torch.arange(num, dtype=dtype) / (num - 1) + + # reshape `steps` to allow for broadcasting + for i in range(start.ndim): + steps = steps.unsqueeze(-1) + + # increments from `start` to `stop` in each dimension + linspace = start[None] + steps * (stop - start)[None] + else: + linspace = torch.linspace( + start=start, + end=stop, + steps=num, + dtype=dtype, + ) if retstep is True: return (linspace, num) return linspace @@ -495,13 +505,26 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): dtype = to_torch_dtype(dtype) if endpoint is False: stop = stop - ((stop - start) / num) - logspace = torch.logspace( - start=start, - end=stop, - steps=num, - base=base, - dtype=dtype, - ) + if hasattr(start, "__len__") and hasattr(stop, "__len__"): + start, stop = convert_to_tensor(start), convert_to_tensor(stop) + stop = cast(stop, dtype) if endpoint is False and dtype else stop + steps = torch.arange(num, dtype=dtype) / (num - 1) + + # reshape `steps` to allow for broadcasting + for i in range(start.ndim): + steps = steps.unsqueeze(-1) + + # increments from `start` to `stop` in each dimension + linspace = start[None] + steps * (stop - start)[None] + logspace = base**linspace + else: + logspace = torch.logspace( + start=start, + end=stop, + steps=num, + base=base, + dtype=dtype, + ) return logspace @@ -715,14 +738,7 @@ def tile(x, repeats): def trace(x, offset=None, axis1=None, axis2=None): x = convert_to_tensor(x) - # TODO: implement support for these arguments - # API divergence between `np.trace()` and `torch.trace()` - if offset or axis1 or axis2: - raise NotImplementedError( - "Arguments not supported by `torch.trace: " - f"offset={offset}, axis1={axis1}, axis2={axis2}" - ) - return torch.trace(x) + return torch.sum(torch.diagonal(x, offset, axis1, axis2), dim=-1) def tri(N, M=None, k=0, dtype="float32"): diff --git a/keras_core/operations/numpy_test.py b/keras_core/operations/numpy_test.py index 323be0b76..c509fbe13 100644 --- a/keras_core/operations/numpy_test.py +++ b/keras_core/operations/numpy_test.py @@ -1713,27 +1713,16 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): np.array(knp.full_like(x, 2, dtype="float32")), np.full_like(x, 2, dtype="float32"), ) + self.assertAllClose( + np.array(knp.full_like(x, np.ones([2, 3]))), + np.full_like(x, np.ones([2, 3])), + ) self.assertAllClose(np.array(knp.FullLike()(x, 2)), np.full_like(x, 2)) self.assertAllClose( np.array(knp.FullLike()(x, 2, dtype="float32")), np.full_like(x, 2, dtype="float32"), ) - - # TODO: implement conversion of shape into repetitions, pass to - # `torch.tile`, since `torch.full()` only accepts scalars - # for `fill_value`." 
- @pytest.mark.skipif( - backend.backend() == "torch", - reason="`torch.full` only accepts scalars for `fill_value`.", - ) - def test_full_like_without_torch(self): - x = np.array([[1, 2, 3], [3, 2, 1]]) - self.assertAllClose( - np.array(knp.full_like(x, np.ones([2, 3]))), - np.full_like(x, np.ones([2, 3])), - ) - self.assertAllClose( np.array(knp.FullLike()(x, np.ones([2, 3]))), np.full_like(x, np.ones([2, 3])), @@ -1834,13 +1823,6 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): np.linspace(0, 10, 5, endpoint=False), ) - # TODO: torch.linspace does not support tensor or array - # for start/stop, create manual implementation - @pytest.mark.skipif( - backend.backend() == "torch", - reason="`torch.linspace` has no support for array start/stop.", - ) - def test_linspace_without_torch(self): start = np.zeros([2, 3, 4]) stop = np.ones([2, 3, 4]) self.assertAllClose( @@ -1945,15 +1927,9 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): np.logspace(0, 10, 5, endpoint=False), ) - # TODO: torch.logspace does not support tensor or array - # for start/stop, create manual implementation - @pytest.mark.skipif( - backend.backend() == "torch", - reason="`torch.logspace` has no support for array start/stop.", - ) - def test_logspace_without_torch(self): start = np.zeros([2, 3, 4]) stop = np.ones([2, 3, 4]) + self.assertAllClose( np.array(knp.logspace(start, stop, 5, base=10)), np.logspace(start, stop, 5, base=10), @@ -3016,13 +2992,7 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): self.assertAllClose(np.array(knp.tile(x, [2, 3])), np.tile(x, [2, 3])) self.assertAllClose(np.array(knp.Tile([2, 3])(x)), np.tile(x, [2, 3])) - @pytest.mark.skipif( - backend.backend() == "torch", - reason="`torch.split` does not support args `offset`, `axis1`, `axis2`", - ) def test_trace(self): - # TODO: implement `torch.trace` support for arguments `offset`, - # `axis1`, `axis2` and delete NotImplementedError x = np.arange(24).reshape([1, 2, 3, 4]) self.assertAllClose(np.array(knp.trace(x)), np.trace(x)) self.assertAllClose( @@ -3092,25 +3062,15 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase): self.assertAllClose( np.array(knp.full([2, 3], 0.1)), np.full([2, 3], 0.1) ) - - self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0)) - self.assertAllClose( - np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1) - ) - - # TODO: implement conversion of shape into repetitions, pass to - # `torch.tile`, since `torch.full()` only accepts scalars - # for `fill_value`." - @pytest.mark.skipif( - backend.backend() == "torch", - reason="`torch.full` only accepts scalars for `fill_value`.", - ) - def test_full_without_torch(self): self.assertAllClose( np.array(knp.full([2, 3], np.array([1, 4, 5]))), np.full([2, 3], np.array([1, 4, 5])), ) + self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0)) + self.assertAllClose( + np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1) + ) self.assertAllClose( np.array(knp.Full()([2, 3], np.array([1, 4, 5]))), np.full([2, 3], np.array([1, 4, 5])),
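
Sanity-check sketch for the three behavioral changes above: array `fill_value`
via `torch.tile`, array `start`/`stop` for linspace/logspace via manual
broadcasting, and `trace` via `torch.diagonal` + `torch.sum`. This is a minimal
standalone script, assuming only that `torch` and `numpy` are installed; it
mirrors the patched backend logic rather than calling the keras_core API, and
the shapes are illustrative, not exhaustive.

    import numpy as np
    import torch

    # 1. full(): tile an array-valued fill_value over the requested shape.
    #    Assumes shape[-1] is a multiple of fill_value.shape[-1], as in the
    #    patched backend.
    fill = torch.tensor([1, 4, 5])
    reps_by_dim = (2, 3 // fill.shape[-1])                # -> (2, 1)
    full_out = torch.tile(fill, reps_by_dim)              # shape (2, 3)
    np.testing.assert_allclose(
        full_out.numpy(), np.full([2, 3], np.array([1, 4, 5]))
    )

    # 2. linspace() with array start/stop: build a (num,) vector of
    #    interpolation fractions and broadcast it along a new leading axis.
    start, stop, num = torch.zeros(2, 3, 4), torch.ones(2, 3, 4), 5
    steps = torch.arange(num, dtype=torch.float32) / (num - 1)
    for _ in range(start.ndim):                           # (num,) -> (num, 1, 1, 1)
        steps = steps.unsqueeze(-1)
    lin_out = start[None] + steps * (stop - start)[None]  # shape (5, 2, 3, 4)
    np.testing.assert_allclose(
        lin_out.numpy(),
        np.linspace(start.numpy(), stop.numpy(), num),
        rtol=1e-6,
    )
    # logspace() is the same construction followed by base**lin_out.

    # 3. trace() with offset/axis1/axis2: sum the extracted diagonal,
    #    matching np.trace for arbitrary axes.
    x = torch.arange(24, dtype=torch.float32).reshape(1, 2, 3, 4)
    trace_out = torch.sum(torch.diagonal(x, offset=0, dim1=0, dim2=1), dim=-1)
    np.testing.assert_allclose(trace_out.numpy(), np.trace(x.numpy()))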