Add vectorized_map and scatter to torch backend (#184)
This commit is contained in:
parent
b4de088ee8
commit
04436b0da6
@ -87,7 +87,7 @@ class Variable(KerasVariable):
|
||||
|
||||
def convert_to_tensor(x, dtype=None):
|
||||
# TODO: Need to address device placement arg of `as_tensor`
|
||||
dtype = to_torch_dtype(dtype)
|
||||
dtype = to_torch_dtype(dtype or x.dtype)
|
||||
if isinstance(x, Variable):
|
||||
if dtype and dtype != x.dtype:
|
||||
return x.value.to(dtype)
|
||||
@ -126,8 +126,20 @@ def cond(pred, true_fn, false_fn):
|
||||
|
||||
|
||||
def vectorized_map(function, elements):
|
||||
raise NotImplementedError
|
||||
return torch.vmap(function)(elements)
|
||||
|
||||
|
||||
def scatter(*args, **kwargs):
|
||||
raise NotImplementedError
|
||||
def scatter(indices, values, shape):
|
||||
indices = convert_to_tensor(indices)
|
||||
values = convert_to_tensor(values)
|
||||
zeros = torch.zeros(shape, dtype=values.dtype)
|
||||
|
||||
index_length = indices.shape[-1]
|
||||
value_shape = shape[index_length:]
|
||||
indices = torch.reshape(indices, [-1, index_length])
|
||||
values = torch.reshape(values, [-1] + list(value_shape))
|
||||
|
||||
for i in range(indices.shape[0]):
|
||||
index = indices[i]
|
||||
zeros[tuple(index)] += values[i]
|
||||
return zeros
|
||||
|
Loading…
Reference in New Issue
Block a user