Work around JAX bug.

This commit is contained in:
Francois Chollet 2023-04-26 12:15:14 -07:00
parent 3a4b682b00
commit b804a5b608

@ -2778,6 +2778,7 @@ def square(x):
class Sqrt(Operation):
def call(self, x):
x = backend.convert_to_tensor(x)
return backend.execute("sqrt", x)
def compute_output_spec(self, x):
@ -2787,6 +2788,7 @@ class Sqrt(Operation):
def sqrt(x):
if any_symbolic_tensors((x,)):
return Sqrt().symbolic_call(x)
x = backend.convert_to_tensor(x)
return backend.execute("sqrt", x)