Merge branches 'main' and 'main' of github.com:keras-team/keras-core

This commit is contained in:
Francois Chollet 2023-05-24 21:42:06 -07:00
parent a9f4ccd0d7
commit 186cbf6b7c
2 changed files with 59 additions and 83 deletions

@ -294,7 +294,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
def dot(x, y):
x, y = convert_to_tensor(x), convert_to_tensor(y)
if len(x.shape) == 0 or len(y.shape) == 0:
if x.ndim == 0 or y.ndim == 0:
return torch.multiply(x, y)
return torch.matmul(x, y)
@ -327,7 +327,7 @@ def expm1(x):
def flip(x, axis=None):
x = convert_to_tensor(x)
if axis is None:
axis = tuple(range(len(x.shape)))
axis = tuple(range(x.ndim))
if isinstance(axis, int):
axis = (axis,)
return torch.flip(x, dims=axis)
@ -341,23 +341,21 @@ def floor(x):
def full(shape, fill_value, dtype=None):
dtype = to_torch_dtype(dtype)
if hasattr(fill_value, "__len__"):
raise NotImplementedError(
"`torch.full()` only accepts scalars for `fill_value`. "
f"Received: fill_value={fill_value}"
)
# TODO: implement conversion of shape into repetitions for `torch.tile``
# return torch.tile(fill_value, reps)
fill_value = convert_to_tensor(fill_value)
reps = shape[-1] // fill_value.shape[-1] # channels-last reduction
reps_by_dim = (*shape[:-1], reps)
return torch.tile(fill_value, reps_by_dim)
return torch.full(size=shape, fill_value=fill_value, dtype=dtype)
def full_like(x, fill_value, dtype=None):
dtype = to_torch_dtype(dtype)
if hasattr(fill_value, "__len__"):
raise NotImplementedError(
"`torch.full()` only accepts scalars for `fill_value`."
fill_value = convert_to_tensor(fill_value)
reps_by_dim = tuple(
[x.shape[i] // fill_value.shape[i] for i in range(x.ndim)]
)
# TODO: implement conversion of shape into repetitions for `torch.tile``
# return torch.tile(fill_value, reps)
return torch.tile(fill_value, reps_by_dim)
x = convert_to_tensor(x)
return torch.full_like(input=x, fill_value=fill_value, dtype=dtype)
@ -430,15 +428,27 @@ def linspace(
"torch.linspace does not support an `axis` argument. "
f"Received axis={axis}"
)
dtype = to_torch_dtype(dtype)
dtype = to_torch_dtype(dtype)
if endpoint is False:
stop = stop - ((stop - start) / num)
linspace = torch.linspace(
start=start,
end=stop,
steps=num,
dtype=dtype,
)
if hasattr(start, "__len__") and hasattr(stop, "__len__"):
start, stop = convert_to_tensor(start), convert_to_tensor(stop)
stop = cast(stop, dtype) if endpoint is False and dtype else stop
steps = torch.arange(num, dtype=dtype) / (num - 1)
# reshape `steps` to allow for broadcasting
for i in range(start.ndim):
steps = steps.unsqueeze(-1)
# increments from `start` to `stop` in each dimension
linspace = start[None] + steps * (stop - start)[None]
else:
linspace = torch.linspace(
start=start,
end=stop,
steps=num,
dtype=dtype,
)
if retstep is True:
return (linspace, num)
return linspace
@ -495,13 +505,26 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
dtype = to_torch_dtype(dtype)
if endpoint is False:
stop = stop - ((stop - start) / num)
logspace = torch.logspace(
start=start,
end=stop,
steps=num,
base=base,
dtype=dtype,
)
if hasattr(start, "__len__") and hasattr(stop, "__len__"):
start, stop = convert_to_tensor(start), convert_to_tensor(stop)
stop = cast(stop, dtype) if endpoint is False and dtype else stop
steps = torch.arange(num, dtype=dtype) / (num - 1)
# reshape `steps` to allow for broadcasting
for i in range(start.ndim):
steps = steps.unsqueeze(-1)
# increments from `start` to `stop` in each dimension
linspace = start[None] + steps * (stop - start)[None]
logspace = base**linspace
else:
logspace = torch.logspace(
start=start,
end=stop,
steps=num,
base=base,
dtype=dtype,
)
return logspace
@ -715,14 +738,7 @@ def tile(x, repeats):
def trace(x, offset=None, axis1=None, axis2=None):
x = convert_to_tensor(x)
# TODO: implement support for these arguments
# API divergence between `np.trace()` and `torch.trace()`
if offset or axis1 or axis2:
raise NotImplementedError(
"Arguments not supported by `torch.trace: "
f"offset={offset}, axis1={axis1}, axis2={axis2}"
)
return torch.trace(x)
return torch.sum(torch.diagonal(x, offset, axis1, axis2), dim=-1)
def tri(N, M=None, k=0, dtype="float32"):

@ -1713,27 +1713,16 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
np.array(knp.full_like(x, 2, dtype="float32")),
np.full_like(x, 2, dtype="float32"),
)
self.assertAllClose(
np.array(knp.full_like(x, np.ones([2, 3]))),
np.full_like(x, np.ones([2, 3])),
)
self.assertAllClose(np.array(knp.FullLike()(x, 2)), np.full_like(x, 2))
self.assertAllClose(
np.array(knp.FullLike()(x, 2, dtype="float32")),
np.full_like(x, 2, dtype="float32"),
)
# TODO: implement conversion of shape into repetitions, pass to
# `torch.tile`, since `torch.full()` only accepts scalars
# for `fill_value`."
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.full` only accepts scalars for `fill_value`.",
)
def test_full_like_without_torch(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(
np.array(knp.full_like(x, np.ones([2, 3]))),
np.full_like(x, np.ones([2, 3])),
)
self.assertAllClose(
np.array(knp.FullLike()(x, np.ones([2, 3]))),
np.full_like(x, np.ones([2, 3])),
@ -1834,13 +1823,6 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
np.linspace(0, 10, 5, endpoint=False),
)
# TODO: torch.linspace does not support tensor or array
# for start/stop, create manual implementation
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.linspace` has no support for array start/stop.",
)
def test_linspace_without_torch(self):
start = np.zeros([2, 3, 4])
stop = np.ones([2, 3, 4])
self.assertAllClose(
@ -1945,15 +1927,9 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
np.logspace(0, 10, 5, endpoint=False),
)
# TODO: torch.logspace does not support tensor or array
# for start/stop, create manual implementation
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.logspace` has no support for array start/stop.",
)
def test_logspace_without_torch(self):
start = np.zeros([2, 3, 4])
stop = np.ones([2, 3, 4])
self.assertAllClose(
np.array(knp.logspace(start, stop, 5, base=10)),
np.logspace(start, stop, 5, base=10),
@ -3016,13 +2992,7 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(np.array(knp.tile(x, [2, 3])), np.tile(x, [2, 3]))
self.assertAllClose(np.array(knp.Tile([2, 3])(x)), np.tile(x, [2, 3]))
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.split` does not support args `offset`, `axis1`, `axis2`",
)
def test_trace(self):
# TODO: implement `torch.trace` support for arguments `offset`,
# `axis1`, `axis2` and delete NotImplementedError
x = np.arange(24).reshape([1, 2, 3, 4])
self.assertAllClose(np.array(knp.trace(x)), np.trace(x))
self.assertAllClose(
@ -3092,25 +3062,15 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(
np.array(knp.full([2, 3], 0.1)), np.full([2, 3], 0.1)
)
self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0))
self.assertAllClose(
np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1)
)
# TODO: implement conversion of shape into repetitions, pass to
# `torch.tile`, since `torch.full()` only accepts scalars
# for `fill_value`."
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.full` only accepts scalars for `fill_value`.",
)
def test_full_without_torch(self):
self.assertAllClose(
np.array(knp.full([2, 3], np.array([1, 4, 5]))),
np.full([2, 3], np.array([1, 4, 5])),
)
self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0))
self.assertAllClose(
np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1)
)
self.assertAllClose(
np.array(knp.Full()([2, 3], np.array([1, 4, 5]))),
np.full([2, 3], np.array([1, 4, 5])),