Work around JAX bug.
This commit is contained in:
parent
3a4b682b00
commit
b804a5b608
@ -2778,6 +2778,7 @@ def square(x):
|
||||
|
||||
class Sqrt(Operation):
|
||||
def call(self, x):
|
||||
x = backend.convert_to_tensor(x)
|
||||
return backend.execute("sqrt", x)
|
||||
|
||||
def compute_output_spec(self, x):
|
||||
@ -2787,6 +2788,7 @@ class Sqrt(Operation):
|
||||
def sqrt(x):
|
||||
if any_symbolic_tensors((x,)):
|
||||
return Sqrt().symbolic_call(x)
|
||||
x = backend.convert_to_tensor(x)
|
||||
return backend.execute("sqrt", x)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user