2023-04-09 19:21:45 +00:00
|
|
|
from tensorflow import nest
|
2023-04-12 18:31:58 +00:00
|
|
|
|
2023-04-09 19:21:45 +00:00
|
|
|
from keras_core.backend import KerasTensor
|
|
|
|
|
|
|
|
|
|
|
|
class SymbolicArguments:
    """Holds the positional and keyword arguments of a symbolic call.

    The arguments are copied structurally with `nest`, flattened once, and
    scanned for `KerasTensor` instances so that `fill_in` can later replace
    those tensors with concrete values.

    Attributes:
        args: Tuple of positional arguments, as rebuilt by
            `nest.map_structure`.
        kwargs: Dict of keyword arguments, as rebuilt by
            `nest.map_structure`.
        keras_tensors: List of every `KerasTensor` found among the flattened
            arguments, in flattening order.
    """

    def __init__(self, *args, **kwargs):
        """Stores the arguments and indexes the `KerasTensor`s they contain.

        Args:
            *args: Positional arguments (arbitrarily nested structures).
            **kwargs: Keyword arguments (arbitrarily nested structures).
        """
        # TODO: validation
        # Identity `map_structure` yields a new structure with the same
        # leaves as the one passed in.
        self.args = nest.map_structure(lambda x: x, args)
        self.kwargs = nest.map_structure(lambda x: x, kwargs)
        self._flat_arguments = nest.flatten((self.args, self.kwargs))

        # Used to avoid expensive `nest` operations in the most common case.
        if (
            not self.kwargs
            and len(self.args) == 1
            and isinstance(self.args[0], KerasTensor)
        ):
            self._single_positional_tensor = self.args[0]
        else:
            self._single_positional_tensor = None

        self.keras_tensors = [
            arg for arg in self._flat_arguments if isinstance(arg, KerasTensor)
        ]

    def convert(self, conversion_fn):
        """Applies `conversion_fn` to every leaf of the stored arguments.

        Args:
            conversion_fn: Callable taking one leaf value and returning its
                replacement.

        Returns:
            An `(args, kwargs)` tuple with the same structure as the stored
            arguments and each leaf replaced by `conversion_fn(leaf)`.
        """
        args = nest.map_structure(conversion_fn, self.args)
        kwargs = nest.map_structure(conversion_fn, self.kwargs)
        return args, kwargs

    def fill_in(self, tensor_dict):
        """Maps KerasTensors to computed values using `tensor_dict`.

        `tensor_dict` maps `KerasTensor` instances to their current values.

        Args:
            tensor_dict: Mapping from `KerasTensor` to its computed value.

        Returns:
            An `(args, kwargs)` tuple in which each `KerasTensor` that has a
            non-`None` entry in `tensor_dict` is replaced by that entry.
            Every other leaf is returned unchanged.

        Raises:
            KeyError: On the single-positional-tensor fast path, if that
                tensor is not a key of `tensor_dict`.
        """
        if self._single_positional_tensor is not None:
            # Performance optimization for most common case.
            # Approx. 70x faster.
            return (tensor_dict[self._single_positional_tensor],), {}

        def switch_fn(x):
            # Only KerasTensors are keys of `tensor_dict`. Looking up other
            # leaves is pointless and raises `TypeError` for unhashable ones
            # (e.g. NumPy arrays), so pass them through untouched.
            if not isinstance(x, KerasTensor):
                return x
            val = tensor_dict.get(x)
            return x if val is None else val

        return self.convert(switch_fn)
|