keras/keras_core/operations/symbolic_arguments.py
Chen Qian eabdb87f9f Add some numpy ops (#1)
* Add numpy ops (initial batch) and some config

* Add unit test

* fix call

* Revert "fix call"

This reverts commit 6748ad183029ff4b97317b77ceed8661916bb9a0.

* full unit test coverage

* fix setup.py
2023-04-12 11:31:58 -07:00

50 lines
1.6 KiB
Python

from tensorflow import nest
from keras_core.backend import KerasTensor
class SymbolicArguments:
    """Container for the (possibly nested) arguments of a symbolic call.

    Stores structural copies of the positional and keyword arguments,
    records every `KerasTensor` leaf they contain, and can later substitute
    concrete values for those tensors (see `fill_in`).
    """

    def __init__(self, *args, **kwargs):
        # TODO: validation
        # Identity `map_structure` rebuilds the nested containers, presumably
        # so later mutation of the caller's lists/dicts does not leak in.
        self.args = nest.map_structure(lambda x: x, args)
        self.kwargs = nest.map_structure(lambda x: x, kwargs)
        self._flat_arguments = nest.flatten((self.args, self.kwargs))

        # Used to avoid expensive `nest` operations in the most common case:
        # exactly one positional KerasTensor and no keyword arguments.
        if (
            not self.kwargs
            and len(self.args) == 1
            and isinstance(self.args[0], KerasTensor)
        ):
            self._single_positional_tensor = self.args[0]
        else:
            self._single_positional_tensor = None

        # Every KerasTensor leaf, in the order `nest.flatten` yields them.
        self.keras_tensors = [
            arg for arg in self._flat_arguments if isinstance(arg, KerasTensor)
        ]

    def convert(self, conversion_fn):
        """Applies `conversion_fn` to every leaf of the stored arguments.

        Args:
            conversion_fn: Callable applied to each leaf of `self.args` and
                `self.kwargs`.

        Returns:
            A tuple `(args, kwargs)` with the same nested structure as the
            stored arguments, where each leaf is replaced by
            `conversion_fn(leaf)`.
        """
        args = nest.map_structure(conversion_fn, self.args)
        kwargs = nest.map_structure(conversion_fn, self.kwargs)
        return args, kwargs

    def fill_in(self, tensor_dict):
        """Maps KerasTensors to computed values using `tensor_dict`.

        `tensor_dict` maps `KerasTensor` instances to their current values.

        Args:
            tensor_dict: Mapping from `KerasTensor` to its computed value.

        Returns:
            A tuple `(args, kwargs)` mirroring the stored arguments. Each
            `KerasTensor` leaf that maps to a non-`None` value in
            `tensor_dict` is replaced by that value. All other leaves,
            including tensors absent from `tensor_dict`, are returned
            unchanged.

        Raises:
            KeyError: only in the single-positional-tensor fast path, if that
                tensor is missing from `tensor_dict`. The general path leaves
                missing tensors in place instead.
        """
        if self._single_positional_tensor is not None:
            # Performance optimization for most common case.
            # Approx. 70x faster.
            return (tensor_dict[self._single_positional_tensor],), {}

        def switch_fn(x):
            # Only KerasTensor leaves can be keys of `tensor_dict`. Other
            # leaves (e.g. NumPy arrays or backend tensors) may be unhashable,
            # and a dict lookup on them would raise `TypeError`, so they are
            # passed through untouched.
            if isinstance(x, KerasTensor):
                val = tensor_dict.get(x, None)
                if val is not None:
                    return val
            return x

        return self.convert(switch_fn)