fix tests (#14)
This commit is contained in:
parent
89b2ea5977
commit
dbda705cda
@ -746,16 +746,12 @@ class CountNonzero(Operation):
|
||||
def __init__(self, axis=None):
|
||||
super().__init__()
|
||||
if isinstance(axis, int):
|
||||
self.axis = [axis]
|
||||
self.axis = (axis,)
|
||||
else:
|
||||
self.axis = axis
|
||||
|
||||
def call(self, x):
|
||||
return backend.execute(
|
||||
"count_nonzero",
|
||||
x,
|
||||
axis=self.axis,
|
||||
)
|
||||
return backend.execute("count_nonzero", x, axis=self.axis)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
return KerasTensor(
|
||||
@ -1883,16 +1879,7 @@ class Nonzero(Operation):
|
||||
def call(self, x):
|
||||
return backend.execute("nonzero", x)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
output = []
|
||||
for _ in range(len(x.shape)):
|
||||
output.append(KerasTensor([None]))
|
||||
return tuple(output)
|
||||
|
||||
|
||||
def nonzero(x):
|
||||
if any_symbolic_tensors((x,)):
|
||||
return Nonzero().symbolic_call(x)
|
||||
return backend.execute("nonzero", x)
|
||||
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user