Add vectorized_map and scatter to torch backend (#184)

This commit is contained in:
Chen Qian 2023-05-17 18:41:12 -07:00 committed by Francois Chollet
parent b4de088ee8
commit 04436b0da6

@ -87,7 +87,7 @@ class Variable(KerasVariable):
def convert_to_tensor(x, dtype=None): def convert_to_tensor(x, dtype=None):
# TODO: Need to address device placement arg of `as_tensor` # TODO: Need to address device placement arg of `as_tensor`
dtype = to_torch_dtype(dtype) dtype = to_torch_dtype(dtype or x.dtype)
if isinstance(x, Variable): if isinstance(x, Variable):
if dtype and dtype != x.dtype: if dtype and dtype != x.dtype:
return x.value.to(dtype) return x.value.to(dtype)
@ -126,8 +126,20 @@ def cond(pred, true_fn, false_fn):
def vectorized_map(function, elements): def vectorized_map(function, elements):
raise NotImplementedError return torch.vmap(function)(elements)
def scatter(*args, **kwargs): def scatter(indices, values, shape):
raise NotImplementedError indices = convert_to_tensor(indices)
values = convert_to_tensor(values)
zeros = torch.zeros(shape, dtype=values.dtype)
index_length = indices.shape[-1]
value_shape = shape[index_length:]
indices = torch.reshape(indices, [-1, index_length])
values = torch.reshape(values, [-1] + list(value_shape))
for i in range(indices.shape[0]):
index = indices[i]
zeros[tuple(index)] += values[i]
return zeros