more numpy ops (#6)

Chen Qian 2023-04-15 08:12:15 -07:00 committed by Francois Chollet
parent 2271973560
commit cb7115768a
2 changed files with 1637 additions and 4 deletions

@@ -909,12 +909,14 @@ class Diag(Operation):
if None in x_shape:
output_shape = [None]
else:
shorter_side = min(x_shape[0], x_shape[1])
shorter_side = np.minimum(x_shape[0], x_shape[1])
if self.k > 0:
remaining = x_shape[1] - self.k
else:
remaining = x_shape[0] + self.k
output_shape = [max(0, min(remaining, shorter_side))]
output_shape = [
int(np.maximum(0, np.minimum(remaining, shorter_side)))
]
else:
raise ValueError(
f"`x` must be 1-D or 2-D, but received shape {x.shape}."
@@ -959,12 +961,14 @@ class Diagonal(Operation):
if None in shape_2d:
diag_shape = [None]
else:
shorter_side = min(shape_2d[0], shape_2d[1])
shorter_side = np.minimum(shape_2d[0], shape_2d[1])
if self.offset > 0:
remaining = shape_2d[1] - self.offset
else:
remaining = shape_2d[0] + self.offset
diag_shape = [max(0, min(remaining, shorter_side))]
diag_shape = [
int(np.maximum(0, np.minimum(remaining, shorter_side)))
]
output_shape = output_shape + diag_shape
return KerasTensor(output_shape, dtype=x.dtype)
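# Illustrative sketch (assuming plain NumPy semantics) of the diagonal-length
# rule the Diag/Diagonal shape inference above computes: for a (rows, cols)
# matrix, the k-th diagonal has max(0, min(cols - k, min(rows, cols)))
# elements when k >= 0, and max(0, min(rows + k, min(rows, cols))) when k < 0.
import numpy as np

x = np.ones((3, 5))
print(np.diag(x, k=2).shape)   # (3,)  -> min(5 - 2, 3) = 3
print(np.diag(x, k=-1).shape)  # (2,)  -> min(3 - 1, 3) = 2
print(np.diag(x, k=6).shape)   # (0,)  -> clipped at 0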
@@ -1310,6 +1314,285 @@ def isnan(x):
return backend.execute("isnan", x)
class Less(Operation):
def call(self, x1, x2):
return backend.execute("less", x1, x2)
def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)
return KerasTensor(output_shape, dtype=x1.dtype)
def less(x1, x2):
if any_symbolic_tensors((x1, x2)):
return Less().symbolic_call(x1, x2)
return backend.execute("less", x1, x2)
class LessEqual(Operation):
def call(self, x1, x2):
return backend.execute("less_equal", x1, x2)
def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)
return KerasTensor(output_shape, dtype=x1.dtype)
def less_equal(x1, x2):
if any_symbolic_tensors((x1, x2)):
return LessEqual().symbolic_call(x1, x2)
return backend.execute("less_equal", x1, x2)
class Linspace(Operation):
def __init__(
self, num=50, endpoint=True, retstep=False, dtype=float, axis=0
):
super().__init__()
self.num = num
self.endpoint = endpoint
self.retstep = retstep
self.dtype = dtype
self.axis = axis
def call(self, start, stop):
return backend.execute(
"linspace",
start,
stop,
num=self.num,
endpoint=self.endpoint,
retstep=self.retstep,
dtype=self.dtype,
axis=self.axis,
)
def compute_output_spec(self, start, stop):
start_shape = getattr(start, "shape", [])
stop_shape = getattr(stop, "shape", [])
output_shape = broadcast_shapes(start_shape, stop_shape)
if self.axis == -1:
output_shape = output_shape + [self.num]
elif self.axis >= 0:
output_shape = (
output_shape[: self.axis]
+ [self.num]
+ output_shape[self.axis :]
)
else:
output_shape = (
output_shape[: self.axis + 1]
+ [self.num]
+ output_shape[self.axis + 1 :]
)
dtype = self.dtype if self.dtype is not None else start.dtype
if self.retstep:
return (KerasTensor(output_shape, dtype=dtype), None)
return KerasTensor(output_shape, dtype=dtype)
def linspace(
start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0
):
if any_symbolic_tensors((start, stop)):
return Linspace(num, endpoint, retstep, dtype, axis)(start, stop)
return backend.execute(
"linspace",
start,
stop,
num=num,
endpoint=endpoint,
retstep=retstep,
dtype=dtype,
axis=axis,
)
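# Illustrative sketch of the shape rule Linspace.compute_output_spec mirrors,
# assuming plain NumPy semantics: the broadcast shape of `start` and `stop`
# gains a new dimension of size `num` at position `axis`.
import numpy as np

start = np.zeros([2, 3, 4])
stop = np.ones([2, 3, 4])
print(np.linspace(start, stop, 5, axis=0).shape)   # (5, 2, 3, 4)
print(np.linspace(start, stop, 5, axis=1).shape)   # (2, 5, 3, 4)
print(np.linspace(start, stop, 5, axis=-1).shape)  # (2, 3, 4, 5)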
class Log(Operation):
def call(self, x):
return backend.execute("log", x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def log(x):
if any_symbolic_tensors((x,)):
return Log().symbolic_call(x)
return backend.execute("log", x)
class Log10(Operation):
def call(self, x):
return backend.execute("log10", x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def log10(x):
if any_symbolic_tensors((x,)):
return Log10().symbolic_call(x)
return backend.execute("log10", x)
class Log1p(Operation):
def call(self, x):
return backend.execute("log1p", x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def log1p(x):
if any_symbolic_tensors((x,)):
        return Log1p().symbolic_call(x)
return backend.execute("log1p", x)
class Log2(Operation):
def call(self, x):
return backend.execute("log2", x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def log2(x):
if any_symbolic_tensors((x,)):
return Log2().symbolic_call(x)
return backend.execute("log2", x)
class Logaddexp(Operation):
def call(self, x1, x2):
return backend.execute("logaddexp", x1, x2)
def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)
return KerasTensor(output_shape, dtype=x1.dtype)
def logaddexp(x1, x2):
if any_symbolic_tensors((x1, x2)):
return Logaddexp().symbolic_call(x1, x2)
return backend.execute("logaddexp", x1, x2)
class LogicalAnd(Operation):
def call(self, x1, x2):
return backend.execute("logical_and", x1, x2)
def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)
return KerasTensor(output_shape, dtype=x1.dtype)
def logical_and(x1, x2):
if any_symbolic_tensors((x1, x2)):
return LogicalAnd().symbolic_call(x1, x2)
return backend.execute("logical_and", x1, x2)
class LogicalNot(Operation):
def call(self, x):
return backend.execute("logical_not", x)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def logical_not(x):
if any_symbolic_tensors((x,)):
return LogicalNot().symbolic_call(x)
return backend.execute("logical_not", x)
class LogicalOr(Operation):
def call(self, x1, x2):
return backend.execute("logical_or", x1, x2)
def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)
return KerasTensor(output_shape, dtype=x1.dtype)
def logical_or(x1, x2):
if any_symbolic_tensors((x1, x2)):
return LogicalOr().symbolic_call(x1, x2)
return backend.execute("logical_or", x1, x2)
class Logspace(Operation):
def __init__(self, num=50, endpoint=True, base=10, dtype=float, axis=0):
super().__init__()
self.num = num
self.endpoint = endpoint
self.base = base
self.dtype = dtype
self.axis = axis
def call(self, start, stop):
return backend.execute(
"logspace",
start,
stop,
num=self.num,
endpoint=self.endpoint,
base=self.base,
dtype=self.dtype,
axis=self.axis,
)
def compute_output_spec(self, start, stop):
start_shape = getattr(start, "shape", [])
stop_shape = getattr(stop, "shape", [])
output_shape = broadcast_shapes(start_shape, stop_shape)
if self.axis == -1:
output_shape = output_shape + [self.num]
elif self.axis >= 0:
output_shape = (
output_shape[: self.axis]
+ [self.num]
+ output_shape[self.axis :]
)
else:
output_shape = (
output_shape[: self.axis + 1]
+ [self.num]
+ output_shape[self.axis + 1 :]
)
dtype = self.dtype if self.dtype is not None else start.dtype
return KerasTensor(output_shape, dtype=dtype)
def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
if any_symbolic_tensors((start, stop)):
return Logspace(num, endpoint, base, dtype, axis)(start, stop)
return backend.execute(
"logspace",
start,
stop,
num=num,
endpoint=endpoint,
base=base,
dtype=dtype,
axis=axis,
)
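# Illustrative sketch of logspace semantics, assuming plain NumPy:
# logspace(start, stop, num, base=b) is b ** linspace(start, stop, num).
import numpy as np

print(np.logspace(0, 2, 3))                            # [  1.  10. 100.]
print(np.logspace(0, 2, 4, endpoint=False, base=2.0))  # [1.  1.414...  2.  2.828...]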
class Matmul(Operation):
def call(self, x1, x2):
return backend.execute("matmul", x1, x2)
@@ -1351,6 +1634,494 @@ def matmul(x1, x2):
return backend.execute("matmul", x1, x2)
class Max(Operation):
def __init__(self, axis=None, keepdims=False):
super().__init__()
if isinstance(axis, int):
self.axis = [axis]
else:
self.axis = axis
self.keepdims = keepdims
def call(self, x):
return backend.execute("max", x, axis=self.axis, keepdims=self.keepdims)
def compute_output_spec(self, x):
return KerasTensor(
reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),
dtype=x.dtype,
)
def max(x, axis=None, keepdims=False):
if any_symbolic_tensors((x,)):
return Max(axis=axis, keepdims=keepdims).symbolic_call(x)
return backend.execute("max", x, axis=axis, keepdims=keepdims)
class Maximum(Operation):
def call(self, x1, x2):
return backend.execute("maximum", x1, x2)
def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)
return KerasTensor(output_shape, dtype=x1.dtype)
def maximum(x1, x2):
if any_symbolic_tensors((x1, x2)):
return Maximum().symbolic_call(x1, x2)
return backend.execute("maximum", x1, x2)
class Meshgrid(Operation):
def __init__(self, indexing="xy"):
super().__init__()
if indexing not in ("xy", "ij"):
            raise ValueError(
                "Valid values for `indexing` are 'xy' and 'ij', "
                f"but received {indexing}."
            )
self.indexing = indexing
def call(self, *x):
return backend.execute("meshgrid", *x, indexing=self.indexing)
def compute_output_spec(self, *x):
output_shape = []
for xi in x:
if len(xi.shape) == 0:
size = 1
else:
if None in xi.shape:
size = None
else:
size = int(np.prod(xi.shape))
output_shape.append(size)
if self.indexing == "ij":
return [KerasTensor(output_shape) for _ in range(len(x))]
tmp = output_shape[0]
output_shape[0] = output_shape[1]
output_shape[1] = tmp
return [KerasTensor(output_shape) for _ in range(len(x))]
def meshgrid(*x, indexing="xy"):
if any_symbolic_tensors(x):
return Meshgrid(indexing=indexing).symbolic_call(*x)
return backend.execute("meshgrid", *x, indexing=indexing)
class Min(Operation):
    def __init__(self, axis=None, keepdims=False):
        super().__init__()
if isinstance(axis, int):
self.axis = [axis]
else:
self.axis = axis
self.keepdims = keepdims
def call(self, x):
return backend.execute("min", x, axis=self.axis, keepdims=self.keepdims)
def compute_output_spec(self, x):
return KerasTensor(
reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),
dtype=x.dtype,
)
def min(x, axis=None, keepdims=False):
if any_symbolic_tensors((x,)):
return Min(axis=axis, keepdims=keepdims).symbolic_call(x)
return backend.execute("min", x, axis=axis, keepdims=keepdims)
class Minimum(Operation):
def call(self, x1, x2):
return backend.execute("minimum", x1, x2)
def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)
return KerasTensor(output_shape, dtype=x1.dtype)
def minimum(x1, x2):
if any_symbolic_tensors((x1, x2)):
return Minimum().symbolic_call(x1, x2)
return backend.execute("minimum", x1, x2)
class Mod(Operation):
def call(self, x1, x2):
return backend.execute("mod", x1, x2)
def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)
return KerasTensor(output_shape, dtype=x1.dtype)
def mod(x1, x2):
if any_symbolic_tensors((x1, x2)):
return Mod().symbolic_call(x1, x2)
return backend.execute("mod", x1, x2)
class Moveaxis(Operation):
def __init__(self, source, destination):
super().__init__()
if isinstance(source, int):
self.source = [source]
else:
self.source = source
if isinstance(destination, int):
self.destination = [destination]
else:
self.destination = destination
if len(self.source) != len(self.destination):
raise ValueError(
"`source` and `destination` arguments must have the same "
f"number of elements, but received `source={source}` and "
f"`destination={destination}`."
)
def call(self, x):
return backend.execute("moveaxis", x, self.source, self.destination)
def compute_output_spec(self, x):
x_shape = list(x.shape)
output_shape = [-1 for _ in range(len(x.shape))]
for sc, dst in zip(self.source, self.destination):
output_shape[dst] = x_shape[sc]
x_shape[sc] = -1
i, j = 0, 0
while i < len(output_shape):
while i < len(output_shape) and output_shape[i] != -1:
# Find the first dim unset.
i += 1
while j < len(output_shape) and x_shape[j] == -1:
# Find the first dim not being passed.
j += 1
if i == len(output_shape):
break
output_shape[i] = x_shape[j]
i += 1
j += 1
return KerasTensor(output_shape, dtype=x.dtype)
def moveaxis(x, source, destination):
if any_symbolic_tensors((x,)):
return Moveaxis(source, destination).symbolic_call(x)
return backend.execute(
"moveaxis", x, source=source, destination=destination
)
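# Illustrative sketch of the permutation Moveaxis.compute_output_spec builds,
# assuming plain NumPy semantics: moved axes land at their `destination`
# positions, and the remaining axes fill the leftover slots in original order.
import numpy as np

x = np.zeros([2, 3, 4, 5])
print(np.moveaxis(x, 0, -1).shape)             # (3, 4, 5, 2)
print(np.moveaxis(x, [0, 1], [-1, -2]).shape)  # (4, 5, 3, 2)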
class Ndim(Operation):
def call(self, x):
return backend.execute(
"ndim",
x,
)
def compute_output_spec(self, x):
return KerasTensor([len(x.shape)])
def ndim(x):
if any_symbolic_tensors((x,)):
return Ndim().symbolic_call(x)
return backend.execute("ndim", x)
class Nonzero(Operation):
def call(self, x):
return backend.execute("nonzero", x)
def compute_output_spec(self, x):
output = []
for _ in range(len(x.shape)):
output.append(KerasTensor([None]))
return tuple(output)
def nonzero(x):
if any_symbolic_tensors((x,)):
return Nonzero().symbolic_call(x)
return backend.execute("nonzero", x)
class NotEqual(Operation):
def call(self, x1, x2):
return backend.execute("not_equal", x1, x2)
def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [])
x2_shape = getattr(x2, "shape", [])
output_shape = broadcast_shapes(x1_shape, x2_shape)
return KerasTensor(output_shape, dtype=x1.dtype)
def not_equal(x1, x2):
if any_symbolic_tensors((x1, x2)):
return NotEqual().symbolic_call(x1, x2)
return backend.execute("not_equal", x1, x2)
class OnesLike(Operation):
def call(self, x, dtype=None):
return backend.execute("ones_like", x, dtype=dtype)
def compute_output_spec(self, x, dtype=None):
return KerasTensor(x.shape, dtype=dtype)
def ones_like(x, dtype=None):
if any_symbolic_tensors((x,)):
return OnesLike().symbolic_call(x, dtype=dtype)
return backend.execute("ones_like", x, dtype=dtype)
class Outer(Operation):
def call(self, x1, x2):
return backend.execute("outer", x1, x2)
def compute_output_spec(self, x1, x2):
x1_shape = getattr(x1, "shape", [1])
x2_shape = getattr(x2, "shape", [1])
if None in x1_shape:
x1_flatten_shape = None
else:
x1_flatten_shape = int(np.prod(x1_shape))
if None in x2_shape:
x2_flatten_shape = None
else:
x2_flatten_shape = int(np.prod(x2_shape))
output_shape = [x1_flatten_shape, x2_flatten_shape]
return KerasTensor(output_shape, dtype=x1.dtype)
def outer(x1, x2):
if any_symbolic_tensors((x1, x2)):
return Outer().symbolic_call(x1, x2)
return backend.execute("outer", x1, x2)
class Pad(Operation):
def __init__(self, pad_width, mode="constant"):
super().__init__()
self.pad_width = self._process_pad_width(pad_width)
self.mode = mode
def _process_pad_width(self, pad_width):
if isinstance(pad_width, int):
return ((pad_width, pad_width),)
if isinstance(pad_width, (tuple, list)) and isinstance(
pad_width[0], int
):
return (pad_width,)
first_len = len(pad_width[0])
for i, pw in enumerate(pad_width):
if len(pw) != first_len:
raise ValueError(
"`pad_width` should be a list of tuples of length 2 or "
f"1, but received {pad_width}."
)
if len(pw) == 1:
pad_width[i] = (pw[0], pw[0])
return pad_width
def call(self, x):
return backend.execute(
"pad", x, pad_width=self.pad_width, mode=self.mode
)
def compute_output_spec(self, x):
output_shape = list(x.shape)
if len(self.pad_width) == 1:
pad_width = [self.pad_width[0] for _ in range(len(output_shape))]
elif len(self.pad_width) == len(output_shape):
pad_width = self.pad_width
else:
raise ValueError(
"`pad_width` must have the same length as `x.shape`, but "
f"received {len(self.pad_width)} and {len(x.shape)}."
)
for i in range(len(output_shape)):
            if output_shape[i] is not None:
                output_shape[i] += pad_width[i][0] + pad_width[i][1]
return KerasTensor(output_shape, dtype=x.dtype)
def pad(x, pad_width, mode="constant"):
if any_symbolic_tensors((x,)):
return Pad(pad_width, mode=mode).symbolic_call(x)
return backend.execute("pad", x, pad_width, mode=mode)
class Prod(Operation):
def __init__(self, axis=None, keepdims=False, dtype=None):
super().__init__()
if isinstance(axis, int):
self.axis = [axis]
else:
self.axis = axis
self.keepdims = keepdims
self.dtype = dtype
def call(self, x):
return backend.execute(
"prod",
x,
axis=self.axis,
keepdims=self.keepdims,
dtype=self.dtype,
)
def compute_output_spec(self, x):
return KerasTensor(
reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),
dtype=self.dtype,
)
def prod(x, axis=None, keepdims=False, dtype=None):
if any_symbolic_tensors((x,)):
return Prod(axis=axis, keepdims=keepdims, dtype=dtype).symbolic_call(x)
return backend.execute("prod", x, axis=axis, keepdims=keepdims, dtype=dtype)
class Ravel(Operation):
def call(self, x):
return backend.execute("ravel", x)
def compute_output_spec(self, x):
if None in x.shape:
output_shape = [
None,
]
else:
output_shape = [int(np.prod(x.shape))]
return KerasTensor(output_shape, dtype=x.dtype)
def ravel(x):
if any_symbolic_tensors((x,)):
return Ravel().symbolic_call(x)
return backend.execute("ravel", x)
class Real(Operation):
def call(self, x):
return backend.execute("real", x)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
def real(x):
if any_symbolic_tensors((x,)):
return Real().symbolic_call(x)
return backend.execute("real", x)
class Reciprocal(Operation):
def call(self, x):
return backend.execute("reciprocal", x)
def compute_output_spec(self, x):
return KerasTensor(x.shape)
def reciprocal(x):
if any_symbolic_tensors((x,)):
return Reciprocal().symbolic_call(x)
return backend.execute("reciprocal", x)
class Repeat(Operation):
def __init__(self, repeats, axis=None):
super().__init__()
self.axis = axis
self.repeats = repeats
def call(self, x):
return backend.execute("repeat", x, self.repeats, axis=self.axis)
def compute_output_spec(self, x):
x_shape = list(x.shape)
if self.axis is None:
if None in x_shape:
return KerasTensor([None], dtype=x.dtype)
x_flatten_size = int(np.prod(x_shape))
if isinstance(self.repeats, int):
output_shape = [x_flatten_size * self.repeats]
else:
output_shape = [int(np.sum(self.repeats))]
return KerasTensor(output_shape, dtype=x.dtype)
size_on_ax = x_shape[self.axis]
output_shape = x_shape
if isinstance(self.repeats, int):
output_shape[self.axis] = size_on_ax * self.repeats
else:
output_shape[self.axis] = int(np.sum(self.repeats))
return KerasTensor(output_shape, dtype=x.dtype)
def repeat(x, repeats, axis=None):
if any_symbolic_tensors((x,)):
return Repeat(repeats, axis=axis).symbolic_call(x)
return backend.execute("repeat", x, repeats, axis=axis)
class Reshape(Operation):
def __init__(self, new_shape):
super().__init__()
self.new_shape = new_shape
def call(self, x):
return backend.execute("reshape", x, self.new_shape)
def compute_output_spec(self, x):
return KerasTensor(self.new_shape, dtype=x.dtype)
def reshape(x, new_shape):
if any_symbolic_tensors((x,)):
return Reshape(new_shape).symbolic_call(x)
return backend.execute("reshape", x, new_shape)
class Roll(Operation):
def __init__(self, shift, axis=None):
super().__init__()
self.shift = shift
self.axis = axis
def call(self, x):
return backend.execute("roll", x, self.shift, self.axis)
def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype=x.dtype)
def roll(x, shift, axis=None):
if any_symbolic_tensors((x,)):
return Roll(shift, axis=axis).symbolic_call(x)
return backend.execute("roll", x, shift, axis=axis)
class Subtract(Operation):
def call(self, x1, x2):
return backend.execute("subtract", x1, x2)

@@ -242,6 +242,202 @@ class NumpyTwoInputOpsShapeTest(testing.TestCase):
y = KerasTensor([2, 3, 4])
knp.isclose(x, y)
def test_less(self):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3])
self.assertEqual(knp.less(x, y).shape, (2, 3))
x = KerasTensor([None, 3])
y = KerasTensor([2, None])
self.assertEqual(knp.less(x, y).shape, (2, 3))
x = KerasTensor([2, 3])
self.assertEqual(knp.less(x, 2).shape, (2, 3))
with self.assertRaises(ValueError):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3, 4])
knp.less(x, y)
def test_less_equal(self):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3])
self.assertEqual(knp.less_equal(x, y).shape, (2, 3))
x = KerasTensor([None, 3])
y = KerasTensor([2, None])
self.assertEqual(knp.less_equal(x, y).shape, (2, 3))
x = KerasTensor([2, 3])
self.assertEqual(knp.less_equal(x, 2).shape, (2, 3))
with self.assertRaises(ValueError):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3, 4])
knp.less_equal(x, y)
def test_linspace(self):
start = KerasTensor([2, 3, 4])
stop = KerasTensor([2, 3, 4])
self.assertEqual(knp.linspace(start, stop, 10).shape, (10, 2, 3, 4))
start = KerasTensor([None, 3, 4])
stop = KerasTensor([2, 3, 4])
self.assertEqual(
knp.linspace(start, stop, 10, axis=1).shape, (2, 10, 3, 4)
)
start = KerasTensor([None, 3])
stop = 2
self.assertEqual(
knp.linspace(start, stop, 10, axis=1).shape, (None, 10, 3)
)
with self.assertRaises(ValueError):
start = KerasTensor([2, 3])
stop = KerasTensor([2, 3, 4])
knp.linspace(start, stop)
def test_logical_and(self):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3])
self.assertEqual(knp.logical_and(x, y).shape, (2, 3))
x = KerasTensor([None, 3])
y = KerasTensor([2, None])
self.assertEqual(knp.logical_and(x, y).shape, (2, 3))
x = KerasTensor([2, 3])
self.assertEqual(knp.logical_and(x, 2).shape, (2, 3))
with self.assertRaises(ValueError):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3, 4])
knp.logical_and(x, y)
def test_logical_or(self):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3])
self.assertEqual(knp.logical_or(x, y).shape, (2, 3))
x = KerasTensor([None, 3])
y = KerasTensor([2, None])
self.assertEqual(knp.logical_or(x, y).shape, (2, 3))
x = KerasTensor([2, 3])
self.assertEqual(knp.logical_or(x, 2).shape, (2, 3))
with self.assertRaises(ValueError):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3, 4])
knp.logical_or(x, y)
def test_logspace(self):
start = KerasTensor([2, 3, 4])
stop = KerasTensor([2, 3, 4])
self.assertEqual(knp.logspace(start, stop, 10).shape, (10, 2, 3, 4))
start = KerasTensor([None, 3, 4])
stop = KerasTensor([2, 3, 4])
self.assertEqual(
knp.logspace(start, stop, 10, axis=1).shape, (2, 10, 3, 4)
)
start = KerasTensor([None, 3])
stop = 2
self.assertEqual(
knp.logspace(start, stop, 10, axis=1).shape, (None, 10, 3)
)
with self.assertRaises(ValueError):
start = KerasTensor([2, 3])
stop = KerasTensor([2, 3, 4])
knp.logspace(start, stop)
def test_maximum(self):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3])
self.assertEqual(knp.maximum(x, y).shape, (2, 3))
x = KerasTensor([None, 3])
y = KerasTensor([2, None])
self.assertEqual(knp.maximum(x, y).shape, (2, 3))
x = KerasTensor([2, 3])
self.assertEqual(knp.maximum(x, 2).shape, (2, 3))
with self.assertRaises(ValueError):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3, 4])
knp.maximum(x, y)
def test_minimum(self):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3])
self.assertEqual(knp.minimum(x, y).shape, (2, 3))
x = KerasTensor([None, 3])
y = KerasTensor([2, None])
self.assertEqual(knp.minimum(x, y).shape, (2, 3))
x = KerasTensor([2, 3])
self.assertEqual(knp.minimum(x, 2).shape, (2, 3))
with self.assertRaises(ValueError):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3, 4])
knp.minimum(x, y)
def test_mod(self):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3])
self.assertEqual(knp.mod(x, y).shape, (2, 3))
x = KerasTensor([None, 3])
y = KerasTensor([2, None])
self.assertEqual(knp.mod(x, y).shape, (2, 3))
x = KerasTensor([2, 3])
self.assertEqual(knp.mod(x, 2).shape, (2, 3))
with self.assertRaises(ValueError):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3, 4])
knp.mod(x, y)
def test_not_equal(self):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3])
self.assertEqual(knp.not_equal(x, y).shape, (2, 3))
x = KerasTensor([None, 3])
y = KerasTensor([2, None])
self.assertEqual(knp.not_equal(x, y).shape, (2, 3))
x = KerasTensor([2, 3])
self.assertEqual(knp.not_equal(x, 2).shape, (2, 3))
with self.assertRaises(ValueError):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3, 4])
knp.not_equal(x, y)
def test_outer(self):
x = KerasTensor([3])
y = KerasTensor([4])
self.assertEqual(knp.outer(x, y).shape, (3, 4))
x = KerasTensor([2, 3])
y = KerasTensor([4, 5])
self.assertEqual(knp.outer(x, y).shape, (6, 20))
x = KerasTensor([None, 3])
y = KerasTensor([2, None])
self.assertEqual(knp.outer(x, y).shape, (None, None))
x = KerasTensor([2, 3])
self.assertEqual(knp.outer(x, 2).shape, (6, 1))
class NumpyOneInputOpsShapeTest(testing.TestCase):
def test_mean(self):
@@ -660,6 +856,199 @@ class NumpyOneInputOpsShapeTest(testing.TestCase):
x = KerasTensor([None, 3])
self.assertEqual(knp.isnan(x).shape, (None, 3))
def test_log(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.log(x).shape, (2, 3))
x = KerasTensor([None, 3])
self.assertEqual(knp.log(x).shape, (None, 3))
def test_log10(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.log10(x).shape, (2, 3))
x = KerasTensor([None, 3])
self.assertEqual(knp.log10(x).shape, (None, 3))
def test_log1p(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.log1p(x).shape, (2, 3))
x = KerasTensor([None, 3])
self.assertEqual(knp.log1p(x).shape, (None, 3))
def test_log2(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.log2(x).shape, (2, 3))
x = KerasTensor([None, 3])
self.assertEqual(knp.log2(x).shape, (None, 3))
def test_logaddexp(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.logaddexp(x, x).shape, (2, 3))
x = KerasTensor([None, 3])
self.assertEqual(knp.logaddexp(x, x).shape, (None, 3))
def test_logical_not(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.logical_not(x).shape, (2, 3))
x = KerasTensor([None, 3])
self.assertEqual(knp.logical_not(x).shape, (None, 3))
def test_max(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.max(x).shape, ())
x = KerasTensor([None, 3])
self.assertEqual(knp.max(x).shape, ())
def test_meshgrid(self):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3, 4])
z = KerasTensor([2, 3, 4, 5])
self.assertEqual(knp.meshgrid(x, y)[0].shape, (24, 6))
self.assertEqual(knp.meshgrid(x, y)[1].shape, (24, 6))
self.assertEqual(knp.meshgrid(x, y, indexing="ij")[0].shape, (6, 24))
self.assertEqual(
knp.meshgrid(x, y, z, indexing="ij")[0].shape, (6, 24, 120)
)
x = KerasTensor([None, 3])
y = KerasTensor([None, 3])
self.assertEqual(knp.meshgrid(x, y)[0].shape, (None, None))
self.assertEqual(knp.meshgrid(x, y)[1].shape, (None, None))
with self.assertRaises(ValueError):
knp.meshgrid(x, y, indexing="kk")
def test_moveaxis(self):
x = KerasTensor([2, 3, 4, 5])
self.assertEqual(knp.moveaxis(x, 0, -1).shape, (3, 4, 5, 2))
self.assertEqual(knp.moveaxis(x, -1, 0).shape, (5, 2, 3, 4))
self.assertEqual(knp.moveaxis(x, [0, 1], [-1, -2]).shape, (4, 5, 3, 2))
self.assertEqual(knp.moveaxis(x, [0, 1], [1, 0]).shape, (3, 2, 4, 5))
self.assertEqual(knp.moveaxis(x, [0, 1], [-2, -1]).shape, (4, 5, 2, 3))
x = KerasTensor([None, 3, 4, 5])
self.assertEqual(knp.moveaxis(x, 0, -1).shape, (3, 4, 5, None))
self.assertEqual(knp.moveaxis(x, -1, 0).shape, (5, None, 3, 4))
self.assertEqual(
knp.moveaxis(x, [0, 1], [-1, -2]).shape, (4, 5, 3, None)
)
self.assertEqual(knp.moveaxis(x, [0, 1], [1, 0]).shape, (3, None, 4, 5))
self.assertEqual(
knp.moveaxis(x, [0, 1], [-2, -1]).shape, (4, 5, None, 3)
)
def test_ndim(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.ndim(x).shape, (2,))
x = KerasTensor([None, 3])
self.assertEqual(knp.ndim(x).shape, (2,))
def test_nonzero(self):
x = KerasTensor([2, 3])
self.assertEqual(len(knp.nonzero(x)), 2)
self.assertEqual(knp.nonzero(x)[0].shape, (None,))
x = KerasTensor([None, 3])
self.assertEqual(len(knp.nonzero(x)), 2)
self.assertEqual(knp.nonzero(x)[0].shape, (None,))
def test_ones_like(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.ones_like(x).shape, (2, 3))
x = KerasTensor([None, 3])
self.assertEqual(knp.ones_like(x).shape, (None, 3))
def test_pad(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.pad(x, 1).shape, (4, 5))
self.assertEqual(knp.pad(x, (1, 2)).shape, (5, 6))
self.assertEqual(knp.pad(x, ((1, 2), (3, 4))).shape, (5, 10))
x = KerasTensor([None, 3])
self.assertEqual(knp.pad(x, 1).shape, (None, 5))
self.assertEqual(knp.pad(x, (1, 2)).shape, (None, 6))
self.assertEqual(knp.pad(x, ((1, 2), (3, 4))).shape, (None, 10))
x = KerasTensor([None, 3, 3])
self.assertEqual(knp.pad(x, 1).shape, (None, 5, 5))
self.assertEqual(knp.pad(x, (1, 2)).shape, (None, 6, 6))
self.assertEqual(
knp.pad(x, ((1, 2), (3, 4), (5, 6))).shape, (None, 10, 14)
)
with self.assertRaises(ValueError):
x = KerasTensor([2, 3])
knp.pad(x, ((1, 2), (3, 4), (5, 6)))
def test_prod(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.prod(x).shape, ())
self.assertEqual(knp.prod(x, axis=0).shape, (3,))
self.assertEqual(knp.prod(x, axis=1).shape, (2,))
x = KerasTensor([None, 3])
self.assertEqual(knp.prod(x).shape, ())
self.assertEqual(knp.prod(x, axis=0).shape, (3,))
self.assertEqual(knp.prod(x, axis=1, keepdims=True).shape, (None, 1))
def test_ravel(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.ravel(x).shape, (6,))
x = KerasTensor([None, 3])
self.assertEqual(knp.ravel(x).shape, (None,))
def test_real(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.real(x).shape, (2, 3))
x = KerasTensor([None, 3])
self.assertEqual(knp.real(x).shape, (None, 3))
def test_reciprocal(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.reciprocal(x).shape, (2, 3))
x = KerasTensor([None, 3])
self.assertEqual(knp.reciprocal(x).shape, (None, 3))
def test_repeat(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.repeat(x, 2).shape, (12,))
self.assertEqual(knp.repeat(x, 3, axis=1).shape, (2, 9))
self.assertEqual(knp.repeat(x, [1, 2], axis=0).shape, (3, 3))
x = KerasTensor([None, 3])
self.assertEqual(knp.repeat(x, 2).shape, (None,))
self.assertEqual(knp.repeat(x, 3, axis=1).shape, (None, 9))
self.assertEqual(knp.repeat(x, [1, 2], axis=0).shape, (3, 3))
def test_reshape(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.reshape(x, (3, 2)).shape, (3, 2))
x = KerasTensor([None, 3])
self.assertEqual(knp.reshape(x, (3, 2)).shape, (3, 2))
def test_roll(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.roll(x, 1).shape, (2, 3))
self.assertEqual(knp.roll(x, 1, axis=1).shape, (2, 3))
self.assertEqual(knp.roll(x, 1, axis=0).shape, (2, 3))
x = KerasTensor([None, 3])
self.assertEqual(knp.roll(x, 1).shape, (None, 3))
self.assertEqual(knp.roll(x, 1, axis=1).shape, (None, 3))
self.assertEqual(knp.roll(x, 1, axis=0).shape, (None, 3))
class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
def test_add(self):
@@ -852,6 +1241,229 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
self.assertAllClose(np.array(knp.Isclose()(x, 2)), np.isclose(x, 2))
self.assertAllClose(np.array(knp.Isclose()(2, x)), np.isclose(2, x))
def test_less(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
y = np.array([[4, 5, 6], [3, 2, 1]])
self.assertAllClose(np.array(knp.less(x, y)), np.less(x, y))
self.assertAllClose(np.array(knp.less(x, 2)), np.less(x, 2))
self.assertAllClose(np.array(knp.less(2, x)), np.less(2, x))
self.assertAllClose(np.array(knp.Less()(x, y)), np.less(x, y))
self.assertAllClose(np.array(knp.Less()(x, 2)), np.less(x, 2))
self.assertAllClose(np.array(knp.Less()(2, x)), np.less(2, x))
def test_less_equal(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
y = np.array([[4, 5, 6], [3, 2, 1]])
self.assertAllClose(np.array(knp.less_equal(x, y)), np.less_equal(x, y))
self.assertAllClose(np.array(knp.less_equal(x, 2)), np.less_equal(x, 2))
self.assertAllClose(np.array(knp.less_equal(2, x)), np.less_equal(2, x))
self.assertAllClose(
np.array(knp.LessEqual()(x, y)), np.less_equal(x, y)
)
self.assertAllClose(
np.array(knp.LessEqual()(x, 2)), np.less_equal(x, 2)
)
self.assertAllClose(
np.array(knp.LessEqual()(2, x)), np.less_equal(2, x)
)
def test_linspace(self):
self.assertAllClose(
np.array(knp.linspace(0, 10, 5)), np.linspace(0, 10, 5)
)
self.assertAllClose(
np.array(knp.linspace(0, 10, 5, endpoint=False)),
np.linspace(0, 10, 5, endpoint=False),
)
self.assertAllClose(
np.array(knp.Linspace(num=5)(0, 10)), np.linspace(0, 10, 5)
)
self.assertAllClose(
np.array(knp.Linspace(num=5, endpoint=False)(0, 10)),
np.linspace(0, 10, 5, endpoint=False),
)
start = np.zeros([2, 3, 4])
stop = np.ones([2, 3, 4])
self.assertAllClose(
np.array(knp.linspace(start, stop, 5, retstep=True)[0]),
np.linspace(start, stop, 5, retstep=True)[0],
)
self.assertAllClose(
np.array(
knp.linspace(start, stop, 5, endpoint=False, retstep=True)[0]
),
np.linspace(start, stop, 5, endpoint=False, retstep=True)[0],
)
self.assertAllClose(
np.array(
knp.linspace(
start, stop, 5, endpoint=False, retstep=True, dtype="int32"
)[0]
),
np.linspace(
start, stop, 5, endpoint=False, retstep=True, dtype="int32"
)[0],
)
self.assertAllClose(
np.array(knp.Linspace(5, retstep=True)(start, stop)[0]),
np.linspace(start, stop, 5, retstep=True)[0],
)
self.assertAllClose(
np.array(
knp.Linspace(5, endpoint=False, retstep=True)(start, stop)[0]
),
np.linspace(start, stop, 5, endpoint=False, retstep=True)[0],
)
self.assertAllClose(
np.array(
knp.Linspace(5, endpoint=False, retstep=True, dtype="int32")(
start, stop
)[0]
),
np.linspace(
start, stop, 5, endpoint=False, retstep=True, dtype="int32"
)[0],
)
def test_logical_and(self):
x = np.array([[True, False], [True, True]])
y = np.array([[False, False], [True, False]])
self.assertAllClose(
np.array(knp.logical_and(x, y)), np.logical_and(x, y)
)
self.assertAllClose(
np.array(knp.logical_and(x, True)), np.logical_and(x, True)
)
self.assertAllClose(
np.array(knp.logical_and(True, x)), np.logical_and(True, x)
)
self.assertAllClose(
np.array(knp.LogicalAnd()(x, y)), np.logical_and(x, y)
)
self.assertAllClose(
np.array(knp.LogicalAnd()(x, True)), np.logical_and(x, True)
)
self.assertAllClose(
np.array(knp.LogicalAnd()(True, x)), np.logical_and(True, x)
)
def test_logical_or(self):
x = np.array([[True, False], [True, True]])
y = np.array([[False, False], [True, False]])
self.assertAllClose(np.array(knp.logical_or(x, y)), np.logical_or(x, y))
self.assertAllClose(
np.array(knp.logical_or(x, True)), np.logical_or(x, True)
)
self.assertAllClose(
np.array(knp.logical_or(True, x)), np.logical_or(True, x)
)
self.assertAllClose(
np.array(knp.LogicalOr()(x, y)), np.logical_or(x, y)
)
self.assertAllClose(
np.array(knp.LogicalOr()(x, True)), np.logical_or(x, True)
)
self.assertAllClose(
np.array(knp.LogicalOr()(True, x)), np.logical_or(True, x)
)
def test_logspace(self):
self.assertAllClose(
np.array(knp.logspace(0, 10, 5)), np.logspace(0, 10, 5)
)
self.assertAllClose(
np.array(knp.logspace(0, 10, 5, endpoint=False)),
np.logspace(0, 10, 5, endpoint=False),
)
self.assertAllClose(
np.array(knp.Logspace(num=5)(0, 10)), np.logspace(0, 10, 5)
)
self.assertAllClose(
np.array(knp.Logspace(num=5, endpoint=False)(0, 10)),
np.logspace(0, 10, 5, endpoint=False),
)
start = np.zeros([2, 3, 4])
stop = np.ones([2, 3, 4])
self.assertAllClose(
np.array(knp.logspace(start, stop, 5, base=10)),
np.logspace(start, stop, 5, base=10),
)
self.assertAllClose(
np.array(knp.logspace(start, stop, 5, endpoint=False, base=10)),
np.logspace(start, stop, 5, endpoint=False, base=10),
)
self.assertAllClose(
np.array(knp.Logspace(5, base=10)(start, stop)),
np.logspace(start, stop, 5, base=10),
)
self.assertAllClose(
np.array(knp.Logspace(5, endpoint=False, base=10)(start, stop)),
np.logspace(start, stop, 5, endpoint=False, base=10),
)
def test_maximum(self):
x = np.array([[1, 2], [3, 4]])
y = np.array([[5, 6], [7, 8]])
self.assertAllClose(np.array(knp.maximum(x, y)), np.maximum(x, y))
self.assertAllClose(np.array(knp.maximum(x, 1)), np.maximum(x, 1))
self.assertAllClose(np.array(knp.maximum(1, x)), np.maximum(1, x))
self.assertAllClose(np.array(knp.Maximum()(x, y)), np.maximum(x, y))
self.assertAllClose(np.array(knp.Maximum()(x, 1)), np.maximum(x, 1))
self.assertAllClose(np.array(knp.Maximum()(1, x)), np.maximum(1, x))
def test_minimum(self):
x = np.array([[1, 2], [3, 4]])
y = np.array([[5, 6], [7, 8]])
self.assertAllClose(np.array(knp.minimum(x, y)), np.minimum(x, y))
self.assertAllClose(np.array(knp.minimum(x, 1)), np.minimum(x, 1))
self.assertAllClose(np.array(knp.minimum(1, x)), np.minimum(1, x))
self.assertAllClose(np.array(knp.Minimum()(x, y)), np.minimum(x, y))
self.assertAllClose(np.array(knp.Minimum()(x, 1)), np.minimum(x, 1))
self.assertAllClose(np.array(knp.Minimum()(1, x)), np.minimum(1, x))
def test_mod(self):
x = np.array([[1, 2], [3, 4]])
y = np.array([[5, 6], [7, 8]])
self.assertAllClose(np.array(knp.mod(x, y)), np.mod(x, y))
self.assertAllClose(np.array(knp.mod(x, 1)), np.mod(x, 1))
self.assertAllClose(np.array(knp.mod(1, x)), np.mod(1, x))
self.assertAllClose(np.array(knp.Mod()(x, y)), np.mod(x, y))
self.assertAllClose(np.array(knp.Mod()(x, 1)), np.mod(x, 1))
self.assertAllClose(np.array(knp.Mod()(1, x)), np.mod(1, x))
def test_not_equal(self):
x = np.array([[1, 2], [3, 4]])
y = np.array([[5, 6], [7, 8]])
self.assertAllClose(np.array(knp.not_equal(x, y)), np.not_equal(x, y))
self.assertAllClose(np.array(knp.not_equal(x, 1)), np.not_equal(x, 1))
self.assertAllClose(np.array(knp.not_equal(1, x)), np.not_equal(1, x))
self.assertAllClose(np.array(knp.NotEqual()(x, y)), np.not_equal(x, y))
self.assertAllClose(np.array(knp.NotEqual()(x, 1)), np.not_equal(x, 1))
self.assertAllClose(np.array(knp.NotEqual()(1, x)), np.not_equal(1, x))
def test_outer(self):
x = np.array([1, 2, 3])
y = np.array([4, 5, 6])
self.assertAllClose(np.array(knp.outer(x, y)), np.outer(x, y))
self.assertAllClose(np.array(knp.Outer()(x, y)), np.outer(x, y))
x = np.ones([2, 3, 4])
y = np.ones([2, 3, 4, 5, 6])
self.assertAllClose(np.array(knp.outer(x, y)), np.outer(x, y))
self.assertAllClose(np.array(knp.Outer()(x, y)), np.outer(x, y))
class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
def test_mean(self):
@@ -1354,6 +1966,256 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(np.array(knp.isnan(x)), np.isnan(x))
self.assertAllClose(np.array(knp.Isnan()(x)), np.isnan(x))
def test_log(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.log(x)), np.log(x))
self.assertAllClose(np.array(knp.Log()(x)), np.log(x))
def test_log10(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.log10(x)), np.log10(x))
self.assertAllClose(np.array(knp.Log10()(x)), np.log10(x))
def test_log1p(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.log1p(x)), np.log1p(x))
self.assertAllClose(np.array(knp.Log1p()(x)), np.log1p(x))
def test_log2(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.log2(x)), np.log2(x))
self.assertAllClose(np.array(knp.Log2()(x)), np.log2(x))
def test_logaddexp(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
y = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.logaddexp(x, y)), np.logaddexp(x, y))
self.assertAllClose(np.array(knp.Logaddexp()(x, y)), np.logaddexp(x, y))
def test_logical_not(self):
x = np.array([[True, False], [False, True]])
self.assertAllClose(np.array(knp.logical_not(x)), np.logical_not(x))
self.assertAllClose(np.array(knp.LogicalNot()(x)), np.logical_not(x))
def test_max(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.max(x)), np.max(x))
self.assertAllClose(np.array(knp.Max()(x)), np.max(x))
self.assertAllClose(np.array(knp.max(x, 0)), np.max(x, 0))
self.assertAllClose(np.array(knp.Max(0)(x)), np.max(x, 0))
self.assertAllClose(np.array(knp.max(x, 1)), np.max(x, 1))
self.assertAllClose(np.array(knp.Max(1)(x)), np.max(x, 1))
def test_min(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.min(x)), np.min(x))
self.assertAllClose(np.array(knp.Min()(x)), np.min(x))
self.assertAllClose(np.array(knp.min(x, 0)), np.min(x, 0))
self.assertAllClose(np.array(knp.Min(0)(x)), np.min(x, 0))
self.assertAllClose(np.array(knp.min(x, 1)), np.min(x, 1))
self.assertAllClose(np.array(knp.Min(1)(x)), np.min(x, 1))
def test_meshgrid(self):
x = np.array([1, 2, 3])
y = np.array([4, 5, 6])
z = np.array([7, 8, 9])
self.assertAllClose(np.array(knp.meshgrid(x, y)), np.meshgrid(x, y))
self.assertAllClose(np.array(knp.meshgrid(x, z)), np.meshgrid(x, z))
self.assertAllClose(
np.array(knp.meshgrid(x, y, z, indexing="ij")),
np.meshgrid(x, y, z, indexing="ij"),
)
self.assertAllClose(np.array(knp.Meshgrid()(x, y)), np.meshgrid(x, y))
self.assertAllClose(np.array(knp.Meshgrid()(x, z)), np.meshgrid(x, z))
self.assertAllClose(
np.array(knp.Meshgrid(indexing="ij")(x, y, z)),
np.meshgrid(x, y, z, indexing="ij"),
)
x = np.ones([1, 2, 3])
y = np.ones([4, 5, 6, 6])
z = np.ones([7, 8])
self.assertAllClose(np.array(knp.meshgrid(x, y)), np.meshgrid(x, y))
self.assertAllClose(np.array(knp.meshgrid(x, z)), np.meshgrid(x, z))
self.assertAllClose(
np.array(knp.meshgrid(x, y, z, indexing="ij")),
np.meshgrid(x, y, z, indexing="ij"),
)
self.assertAllClose(np.array(knp.Meshgrid()(x, y)), np.meshgrid(x, y))
self.assertAllClose(np.array(knp.Meshgrid()(x, z)), np.meshgrid(x, z))
self.assertAllClose(
np.array(knp.Meshgrid(indexing="ij")(x, y, z)),
np.meshgrid(x, y, z, indexing="ij"),
)
def test_moveaxis(self):
x = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
self.assertAllClose(
np.array(knp.moveaxis(x, 0, -1)), np.moveaxis(x, 0, -1)
)
self.assertAllClose(
np.array(knp.moveaxis(x, -1, 0)), np.moveaxis(x, -1, 0)
)
self.assertAllClose(
np.array(knp.moveaxis(x, (0, 1), (1, 0))),
np.moveaxis(x, (0, 1), (1, 0)),
)
self.assertAllClose(
np.array(knp.moveaxis(x, [0, 1, 2], [2, 0, 1])),
np.moveaxis(x, [0, 1, 2], [2, 0, 1]),
)
self.assertAllClose(
np.array(knp.Moveaxis(-1, 0)(x)), np.moveaxis(x, -1, 0)
)
self.assertAllClose(
np.array(knp.Moveaxis((0, 1), (1, 0))(x)),
np.moveaxis(x, (0, 1), (1, 0)),
)
self.assertAllClose(
np.array(knp.Moveaxis([0, 1, 2], [2, 0, 1])(x)),
np.moveaxis(x, [0, 1, 2], [2, 0, 1]),
)
def test_ndim(self):
x = np.array([1, 2, 3])
self.assertEqual(knp.ndim(x), np.ndim(x))
self.assertEqual(knp.Ndim()(x), np.ndim(x))
def test_nonzero(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.nonzero(x)), np.nonzero(x))
self.assertAllClose(np.array(knp.Nonzero()(x)), np.nonzero(x))
def test_ones_like(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.ones_like(x)), np.ones_like(x))
self.assertAllClose(np.array(knp.OnesLike()(x)), np.ones_like(x))
def test_pad(self):
x = np.array([[1, 2], [3, 4]])
self.assertAllClose(
np.array(knp.pad(x, ((1, 1), (1, 1)))), np.pad(x, ((1, 1), (1, 1)))
)
self.assertAllClose(
np.array(knp.pad(x, ((1, 1), (1, 1)))),
np.pad(x, ((1, 1), (1, 1))),
)
self.assertAllClose(
np.array(knp.pad(x, ((1, 1), (1, 1)), mode="reflect")),
np.pad(x, ((1, 1), (1, 1)), mode="reflect"),
)
self.assertAllClose(
np.array(knp.pad(x, ((1, 1), (1, 1)), mode="symmetric")),
np.pad(x, ((1, 1), (1, 1)), mode="symmetric"),
)
self.assertAllClose(
np.array(knp.Pad(((1, 1), (1, 1)))(x)), np.pad(x, ((1, 1), (1, 1)))
)
self.assertAllClose(
np.array(knp.Pad(((1, 1), (1, 1)))(x)),
np.pad(x, ((1, 1), (1, 1))),
)
self.assertAllClose(
np.array(knp.Pad(((1, 1), (1, 1)), mode="reflect")(x)),
np.pad(x, ((1, 1), (1, 1)), mode="reflect"),
)
self.assertAllClose(
np.array(knp.Pad(((1, 1), (1, 1)), mode="symmetric")(x)),
np.pad(x, ((1, 1), (1, 1)), mode="symmetric"),
)
x = np.ones([2, 3, 4, 5])
self.assertAllClose(
np.array(knp.pad(x, ((2, 3), (1, 1), (1, 1), (1, 1)))),
np.pad(x, ((2, 3), (1, 1), (1, 1), (1, 1))),
)
self.assertAllClose(
np.array(knp.Pad(((2, 3), (1, 1), (1, 1), (1, 1)))(x)),
np.pad(x, ((2, 3), (1, 1), (1, 1), (1, 1))),
)
def test_prod(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.prod(x)), np.prod(x))
self.assertAllClose(np.array(knp.prod(x, axis=1)), np.prod(x, axis=1))
self.assertAllClose(
np.array(knp.prod(x, axis=1, keepdims=True)),
np.prod(x, axis=1, keepdims=True),
)
self.assertAllClose(np.array(knp.Prod()(x)), np.prod(x))
self.assertAllClose(np.array(knp.Prod(axis=1)(x)), np.prod(x, axis=1))
self.assertAllClose(
np.array(knp.Prod(axis=1, keepdims=True)(x)),
np.prod(x, axis=1, keepdims=True),
)
def test_ravel(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.ravel(x)), np.ravel(x))
self.assertAllClose(np.array(knp.Ravel()(x)), np.ravel(x))
def test_real(self):
x = np.array([[1, 2, 3 - 3j], [3, 2, 1 + 5j]])
self.assertAllClose(np.array(knp.real(x)), np.real(x))
self.assertAllClose(np.array(knp.Real()(x)), np.real(x))
def test_reciprocal(self):
x = np.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]])
self.assertAllClose(np.array(knp.reciprocal(x)), np.reciprocal(x))
self.assertAllClose(np.array(knp.Reciprocal()(x)), np.reciprocal(x))
def test_repeat(self):
x = np.array([[1, 2], [3, 4]])
self.assertAllClose(np.array(knp.repeat(x, 2)), np.repeat(x, 2))
self.assertAllClose(
np.array(knp.repeat(x, 3, axis=1)), np.repeat(x, 3, axis=1)
)
self.assertAllClose(
np.array(knp.repeat(x, [1, 2], axis=-1)),
np.repeat(x, [1, 2], axis=-1),
)
self.assertAllClose(np.array(knp.Repeat(2)(x)), np.repeat(x, 2))
self.assertAllClose(
np.array(knp.Repeat(3, axis=1)(x)), np.repeat(x, 3, axis=1)
)
self.assertAllClose(
np.array(knp.Repeat([1, 2], axis=0)(x)),
np.repeat(x, [1, 2], axis=0),
)
def test_reshape(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(
np.array(knp.reshape(x, [3, 2])), np.reshape(x, [3, 2])
)
self.assertAllClose(
np.array(knp.Reshape([3, 2])(x)), np.reshape(x, [3, 2])
)
def test_roll(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.roll(x, 1)), np.roll(x, 1))
self.assertAllClose(
np.array(knp.roll(x, 1, axis=1)), np.roll(x, 1, axis=1)
)
self.assertAllClose(
np.array(knp.roll(x, -1, axis=0)), np.roll(x, -1, axis=0)
)
self.assertAllClose(np.array(knp.Roll(1)(x)), np.roll(x, 1))
self.assertAllClose(
np.array(knp.Roll(1, axis=1)(x)), np.roll(x, 1, axis=1)
)
self.assertAllClose(
np.array(knp.Roll(-1, axis=0)(x)), np.roll(x, -1, axis=0)
)
class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
def test_ones(self):