Avoid some tree operations (#620)
* avoid some tree operations * addressing comments --------- Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
This commit is contained in:
parent
7344c66228
commit
30121f4b5f
@ -624,9 +624,6 @@ class Layer(BackendLayer, Operation):
|
||||
# 1. Convert any array arguments to tensors of correct dtype.
|
||||
def maybe_convert(x):
|
||||
if backend.is_tensor(x):
|
||||
# Handle Torch device placement.
|
||||
if backend.backend() == "torch":
|
||||
x = backend.convert_to_tensor(x)
|
||||
if (
|
||||
self.autocast
|
||||
and backend.is_float_dtype(x.dtype)
|
||||
@ -646,7 +643,13 @@ class Layer(BackendLayer, Operation):
|
||||
return backend.convert_to_tensor(x, dtype=self.compute_dtype)
|
||||
return x
|
||||
|
||||
if self._convert_input_args:
|
||||
# Used to avoid expensive `tree` operations in the most common case.
|
||||
if (
|
||||
kwargs
|
||||
or len(args) != 1
|
||||
or not backend.is_tensor(args[0])
|
||||
or backend.standardize_dtype(args[0].dtype) != self.compute_dtype
|
||||
) and self._convert_input_args:
|
||||
args = tree.map_structure(maybe_convert, args)
|
||||
kwargs = tree.map_structure(maybe_convert, kwargs)
|
||||
|
||||
|
@ -9,7 +9,7 @@ class SymbolicArguments:
|
||||
self.kwargs = tree.map_structure(lambda x: x, kwargs)
|
||||
self._flat_arguments = tree.flatten((self.args, self.kwargs))
|
||||
|
||||
# Used to avoid expensive `nest` operations in the most common case.
|
||||
# Used to avoid expensive `tree` operations in the most common case.
|
||||
if (
|
||||
not self.kwargs
|
||||
and len(self.args) == 1
|
||||
|
Loading…
Reference in New Issue
Block a user