fix tests (#14)
This commit is contained in:
parent
89b2ea5977
commit
dbda705cda
@ -746,16 +746,12 @@ class CountNonzero(Operation):
|
|||||||
def __init__(self, axis=None):
|
def __init__(self, axis=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if isinstance(axis, int):
|
if isinstance(axis, int):
|
||||||
self.axis = [axis]
|
self.axis = (axis,)
|
||||||
else:
|
else:
|
||||||
self.axis = axis
|
self.axis = axis
|
||||||
|
|
||||||
def call(self, x):
|
def call(self, x):
|
||||||
return backend.execute(
|
return backend.execute("count_nonzero", x, axis=self.axis)
|
||||||
"count_nonzero",
|
|
||||||
x,
|
|
||||||
axis=self.axis,
|
|
||||||
)
|
|
||||||
|
|
||||||
def compute_output_spec(self, x):
|
def compute_output_spec(self, x):
|
||||||
return KerasTensor(
|
return KerasTensor(
|
||||||
@ -1883,16 +1879,7 @@ class Nonzero(Operation):
|
|||||||
def call(self, x):
|
def call(self, x):
|
||||||
return backend.execute("nonzero", x)
|
return backend.execute("nonzero", x)
|
||||||
|
|
||||||
def compute_output_spec(self, x):
|
|
||||||
output = []
|
|
||||||
for _ in range(len(x.shape)):
|
|
||||||
output.append(KerasTensor([None]))
|
|
||||||
return tuple(output)
|
|
||||||
|
|
||||||
|
|
||||||
def nonzero(x):
|
def nonzero(x):
|
||||||
if any_symbolic_tensors((x,)):
|
|
||||||
return Nonzero().symbolic_call(x)
|
|
||||||
return backend.execute("nonzero", x)
|
return backend.execute("nonzero", x)
|
||||||
|
|
||||||
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user