Merge branches 'main' and 'main' of github.com:keras-team/keras-core
This commit is contained in:
parent
a9f4ccd0d7
commit
186cbf6b7c
@ -294,7 +294,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
|
||||
|
||||
def dot(x, y):
|
||||
x, y = convert_to_tensor(x), convert_to_tensor(y)
|
||||
if len(x.shape) == 0 or len(y.shape) == 0:
|
||||
if x.ndim == 0 or y.ndim == 0:
|
||||
return torch.multiply(x, y)
|
||||
return torch.matmul(x, y)
|
||||
|
||||
@ -327,7 +327,7 @@ def expm1(x):
|
||||
def flip(x, axis=None):
|
||||
x = convert_to_tensor(x)
|
||||
if axis is None:
|
||||
axis = tuple(range(len(x.shape)))
|
||||
axis = tuple(range(x.ndim))
|
||||
if isinstance(axis, int):
|
||||
axis = (axis,)
|
||||
return torch.flip(x, dims=axis)
|
||||
@ -341,23 +341,21 @@ def floor(x):
|
||||
def full(shape, fill_value, dtype=None):
|
||||
dtype = to_torch_dtype(dtype)
|
||||
if hasattr(fill_value, "__len__"):
|
||||
raise NotImplementedError(
|
||||
"`torch.full()` only accepts scalars for `fill_value`. "
|
||||
f"Received: fill_value={fill_value}"
|
||||
)
|
||||
# TODO: implement conversion of shape into repetitions for `torch.tile``
|
||||
# return torch.tile(fill_value, reps)
|
||||
fill_value = convert_to_tensor(fill_value)
|
||||
reps = shape[-1] // fill_value.shape[-1] # channels-last reduction
|
||||
reps_by_dim = (*shape[:-1], reps)
|
||||
return torch.tile(fill_value, reps_by_dim)
|
||||
return torch.full(size=shape, fill_value=fill_value, dtype=dtype)
|
||||
|
||||
|
||||
def full_like(x, fill_value, dtype=None):
|
||||
dtype = to_torch_dtype(dtype)
|
||||
if hasattr(fill_value, "__len__"):
|
||||
raise NotImplementedError(
|
||||
"`torch.full()` only accepts scalars for `fill_value`."
|
||||
fill_value = convert_to_tensor(fill_value)
|
||||
reps_by_dim = tuple(
|
||||
[x.shape[i] // fill_value.shape[i] for i in range(x.ndim)]
|
||||
)
|
||||
# TODO: implement conversion of shape into repetitions for `torch.tile``
|
||||
# return torch.tile(fill_value, reps)
|
||||
return torch.tile(fill_value, reps_by_dim)
|
||||
x = convert_to_tensor(x)
|
||||
return torch.full_like(input=x, fill_value=fill_value, dtype=dtype)
|
||||
|
||||
@ -430,15 +428,27 @@ def linspace(
|
||||
"torch.linspace does not support an `axis` argument. "
|
||||
f"Received axis={axis}"
|
||||
)
|
||||
dtype = to_torch_dtype(dtype)
|
||||
dtype = to_torch_dtype(dtype)
|
||||
if endpoint is False:
|
||||
stop = stop - ((stop - start) / num)
|
||||
linspace = torch.linspace(
|
||||
start=start,
|
||||
end=stop,
|
||||
steps=num,
|
||||
dtype=dtype,
|
||||
)
|
||||
if hasattr(start, "__len__") and hasattr(stop, "__len__"):
|
||||
start, stop = convert_to_tensor(start), convert_to_tensor(stop)
|
||||
stop = cast(stop, dtype) if endpoint is False and dtype else stop
|
||||
steps = torch.arange(num, dtype=dtype) / (num - 1)
|
||||
|
||||
# reshape `steps` to allow for broadcasting
|
||||
for i in range(start.ndim):
|
||||
steps = steps.unsqueeze(-1)
|
||||
|
||||
# increments from `start` to `stop` in each dimension
|
||||
linspace = start[None] + steps * (stop - start)[None]
|
||||
else:
|
||||
linspace = torch.linspace(
|
||||
start=start,
|
||||
end=stop,
|
||||
steps=num,
|
||||
dtype=dtype,
|
||||
)
|
||||
if retstep is True:
|
||||
return (linspace, num)
|
||||
return linspace
|
||||
@ -495,13 +505,26 @@ def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
|
||||
dtype = to_torch_dtype(dtype)
|
||||
if endpoint is False:
|
||||
stop = stop - ((stop - start) / num)
|
||||
logspace = torch.logspace(
|
||||
start=start,
|
||||
end=stop,
|
||||
steps=num,
|
||||
base=base,
|
||||
dtype=dtype,
|
||||
)
|
||||
if hasattr(start, "__len__") and hasattr(stop, "__len__"):
|
||||
start, stop = convert_to_tensor(start), convert_to_tensor(stop)
|
||||
stop = cast(stop, dtype) if endpoint is False and dtype else stop
|
||||
steps = torch.arange(num, dtype=dtype) / (num - 1)
|
||||
|
||||
# reshape `steps` to allow for broadcasting
|
||||
for i in range(start.ndim):
|
||||
steps = steps.unsqueeze(-1)
|
||||
|
||||
# increments from `start` to `stop` in each dimension
|
||||
linspace = start[None] + steps * (stop - start)[None]
|
||||
logspace = base**linspace
|
||||
else:
|
||||
logspace = torch.logspace(
|
||||
start=start,
|
||||
end=stop,
|
||||
steps=num,
|
||||
base=base,
|
||||
dtype=dtype,
|
||||
)
|
||||
return logspace
|
||||
|
||||
|
||||
@ -715,14 +738,7 @@ def tile(x, repeats):
|
||||
|
||||
def trace(x, offset=None, axis1=None, axis2=None):
|
||||
x = convert_to_tensor(x)
|
||||
# TODO: implement support for these arguments
|
||||
# API divergence between `np.trace()` and `torch.trace()`
|
||||
if offset or axis1 or axis2:
|
||||
raise NotImplementedError(
|
||||
"Arguments not supported by `torch.trace: "
|
||||
f"offset={offset}, axis1={axis1}, axis2={axis2}"
|
||||
)
|
||||
return torch.trace(x)
|
||||
return torch.sum(torch.diagonal(x, offset, axis1, axis2), dim=-1)
|
||||
|
||||
|
||||
def tri(N, M=None, k=0, dtype="float32"):
|
||||
|
@ -1713,27 +1713,16 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
|
||||
np.array(knp.full_like(x, 2, dtype="float32")),
|
||||
np.full_like(x, 2, dtype="float32"),
|
||||
)
|
||||
self.assertAllClose(
|
||||
np.array(knp.full_like(x, np.ones([2, 3]))),
|
||||
np.full_like(x, np.ones([2, 3])),
|
||||
)
|
||||
|
||||
self.assertAllClose(np.array(knp.FullLike()(x, 2)), np.full_like(x, 2))
|
||||
self.assertAllClose(
|
||||
np.array(knp.FullLike()(x, 2, dtype="float32")),
|
||||
np.full_like(x, 2, dtype="float32"),
|
||||
)
|
||||
|
||||
# TODO: implement conversion of shape into repetitions, pass to
|
||||
# `torch.tile`, since `torch.full()` only accepts scalars
|
||||
# for `fill_value`."
|
||||
@pytest.mark.skipif(
|
||||
backend.backend() == "torch",
|
||||
reason="`torch.full` only accepts scalars for `fill_value`.",
|
||||
)
|
||||
def test_full_like_without_torch(self):
|
||||
x = np.array([[1, 2, 3], [3, 2, 1]])
|
||||
self.assertAllClose(
|
||||
np.array(knp.full_like(x, np.ones([2, 3]))),
|
||||
np.full_like(x, np.ones([2, 3])),
|
||||
)
|
||||
|
||||
self.assertAllClose(
|
||||
np.array(knp.FullLike()(x, np.ones([2, 3]))),
|
||||
np.full_like(x, np.ones([2, 3])),
|
||||
@ -1834,13 +1823,6 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
|
||||
np.linspace(0, 10, 5, endpoint=False),
|
||||
)
|
||||
|
||||
# TODO: torch.linspace does not support tensor or array
|
||||
# for start/stop, create manual implementation
|
||||
@pytest.mark.skipif(
|
||||
backend.backend() == "torch",
|
||||
reason="`torch.linspace` has no support for array start/stop.",
|
||||
)
|
||||
def test_linspace_without_torch(self):
|
||||
start = np.zeros([2, 3, 4])
|
||||
stop = np.ones([2, 3, 4])
|
||||
self.assertAllClose(
|
||||
@ -1945,15 +1927,9 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
|
||||
np.logspace(0, 10, 5, endpoint=False),
|
||||
)
|
||||
|
||||
# TODO: torch.logspace does not support tensor or array
|
||||
# for start/stop, create manual implementation
|
||||
@pytest.mark.skipif(
|
||||
backend.backend() == "torch",
|
||||
reason="`torch.logspace` has no support for array start/stop.",
|
||||
)
|
||||
def test_logspace_without_torch(self):
|
||||
start = np.zeros([2, 3, 4])
|
||||
stop = np.ones([2, 3, 4])
|
||||
|
||||
self.assertAllClose(
|
||||
np.array(knp.logspace(start, stop, 5, base=10)),
|
||||
np.logspace(start, stop, 5, base=10),
|
||||
@ -3016,13 +2992,7 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
|
||||
self.assertAllClose(np.array(knp.tile(x, [2, 3])), np.tile(x, [2, 3]))
|
||||
self.assertAllClose(np.array(knp.Tile([2, 3])(x)), np.tile(x, [2, 3]))
|
||||
|
||||
@pytest.mark.skipif(
|
||||
backend.backend() == "torch",
|
||||
reason="`torch.split` does not support args `offset`, `axis1`, `axis2`",
|
||||
)
|
||||
def test_trace(self):
|
||||
# TODO: implement `torch.trace` support for arguments `offset`,
|
||||
# `axis1`, `axis2` and delete NotImplementedError
|
||||
x = np.arange(24).reshape([1, 2, 3, 4])
|
||||
self.assertAllClose(np.array(knp.trace(x)), np.trace(x))
|
||||
self.assertAllClose(
|
||||
@ -3092,25 +3062,15 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
|
||||
self.assertAllClose(
|
||||
np.array(knp.full([2, 3], 0.1)), np.full([2, 3], 0.1)
|
||||
)
|
||||
|
||||
self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0))
|
||||
self.assertAllClose(
|
||||
np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1)
|
||||
)
|
||||
|
||||
# TODO: implement conversion of shape into repetitions, pass to
|
||||
# `torch.tile`, since `torch.full()` only accepts scalars
|
||||
# for `fill_value`."
|
||||
@pytest.mark.skipif(
|
||||
backend.backend() == "torch",
|
||||
reason="`torch.full` only accepts scalars for `fill_value`.",
|
||||
)
|
||||
def test_full_without_torch(self):
|
||||
self.assertAllClose(
|
||||
np.array(knp.full([2, 3], np.array([1, 4, 5]))),
|
||||
np.full([2, 3], np.array([1, 4, 5])),
|
||||
)
|
||||
|
||||
self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0))
|
||||
self.assertAllClose(
|
||||
np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1)
|
||||
)
|
||||
self.assertAllClose(
|
||||
np.array(knp.Full()([2, 3], np.array([1, 4, 5]))),
|
||||
np.full([2, 3], np.array([1, 4, 5])),
|
||||
|
Loading…
Reference in New Issue
Block a user