fix tests (#14)

This commit is contained in:
Chen Qian 2023-04-18 15:10:25 -07:00 committed by Francois Chollet
parent 89b2ea5977
commit dbda705cda
2 changed files with 629 additions and 507 deletions

@ -746,16 +746,12 @@ class CountNonzero(Operation):
def __init__(self, axis=None): def __init__(self, axis=None):
super().__init__() super().__init__()
if isinstance(axis, int): if isinstance(axis, int):
self.axis = [axis] self.axis = (axis,)
else: else:
self.axis = axis self.axis = axis
def call(self, x): def call(self, x):
return backend.execute( return backend.execute("count_nonzero", x, axis=self.axis)
"count_nonzero",
x,
axis=self.axis,
)
def compute_output_spec(self, x): def compute_output_spec(self, x):
return KerasTensor( return KerasTensor(
@ -1883,16 +1879,7 @@ class Nonzero(Operation):
def call(self, x): def call(self, x):
return backend.execute("nonzero", x) return backend.execute("nonzero", x)
def compute_output_spec(self, x):
output = []
for _ in range(len(x.shape)):
output.append(KerasTensor([None]))
return tuple(output)
def nonzero(x): def nonzero(x):
if any_symbolic_tensors((x,)):
return Nonzero().symbolic_call(x)
return backend.execute("nonzero", x) return backend.execute("nonzero", x)

File diff suppressed because it is too large Load Diff