Work around JAX bug.
This commit is contained in:
parent
3a4b682b00
commit
b804a5b608
@ -2778,6 +2778,7 @@ def square(x):
|
|||||||
|
|
||||||
class Sqrt(Operation):
|
class Sqrt(Operation):
|
||||||
def call(self, x):
|
def call(self, x):
|
||||||
|
x = backend.convert_to_tensor(x)
|
||||||
return backend.execute("sqrt", x)
|
return backend.execute("sqrt", x)
|
||||||
|
|
||||||
def compute_output_spec(self, x):
|
def compute_output_spec(self, x):
|
||||||
@ -2787,6 +2788,7 @@ class Sqrt(Operation):
|
|||||||
def sqrt(x):
|
def sqrt(x):
|
||||||
if any_symbolic_tensors((x,)):
|
if any_symbolic_tensors((x,)):
|
||||||
return Sqrt().symbolic_call(x)
|
return Sqrt().symbolic_call(x)
|
||||||
|
x = backend.convert_to_tensor(x)
|
||||||
return backend.execute("sqrt", x)
|
return backend.execute("sqrt", x)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user