Avoid some tree operations (#620)

* avoid some tree operations

* addressing comments

---------

Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
This commit is contained in:
Haifeng Jin 2023-07-26 15:38:45 -07:00 committed by Francois Chollet
parent 7344c66228
commit 30121f4b5f
2 changed files with 8 additions and 5 deletions

@ -624,9 +624,6 @@ class Layer(BackendLayer, Operation):
# 1. Convert any array arguments to tensors of correct dtype. # 1. Convert any array arguments to tensors of correct dtype.
def maybe_convert(x): def maybe_convert(x):
if backend.is_tensor(x): if backend.is_tensor(x):
# Handle Torch device placement.
if backend.backend() == "torch":
x = backend.convert_to_tensor(x)
if ( if (
self.autocast self.autocast
and backend.is_float_dtype(x.dtype) and backend.is_float_dtype(x.dtype)
@ -646,7 +643,13 @@ class Layer(BackendLayer, Operation):
return backend.convert_to_tensor(x, dtype=self.compute_dtype) return backend.convert_to_tensor(x, dtype=self.compute_dtype)
return x return x
if self._convert_input_args: # Used to avoid expensive `tree` operations in the most common case.
if (
kwargs
or len(args) != 1
or not backend.is_tensor(args[0])
or backend.standardize_dtype(args[0].dtype) != self.compute_dtype
) and self._convert_input_args:
args = tree.map_structure(maybe_convert, args) args = tree.map_structure(maybe_convert, args)
kwargs = tree.map_structure(maybe_convert, kwargs) kwargs = tree.map_structure(maybe_convert, kwargs)

@ -9,7 +9,7 @@ class SymbolicArguments:
self.kwargs = tree.map_structure(lambda x: x, kwargs) self.kwargs = tree.map_structure(lambda x: x, kwargs)
self._flat_arguments = tree.flatten((self.args, self.kwargs)) self._flat_arguments = tree.flatten((self.args, self.kwargs))
# Used to avoid expensive `nest` operations in the most common case. # Used to avoid expensive `tree` operations in the most common case.
if ( if (
not self.kwargs not self.kwargs
and len(self.args) == 1 and len(self.args) == 1