Add vectorized_map and scatter to torch backend (#184)
This commit is contained in:
parent
b4de088ee8
commit
04436b0da6
@ -87,7 +87,7 @@ class Variable(KerasVariable):
|
|||||||
|
|
||||||
def convert_to_tensor(x, dtype=None):
|
def convert_to_tensor(x, dtype=None):
|
||||||
# TODO: Need to address device placement arg of `as_tensor`
|
# TODO: Need to address device placement arg of `as_tensor`
|
||||||
dtype = to_torch_dtype(dtype)
|
dtype = to_torch_dtype(dtype or x.dtype)
|
||||||
if isinstance(x, Variable):
|
if isinstance(x, Variable):
|
||||||
if dtype and dtype != x.dtype:
|
if dtype and dtype != x.dtype:
|
||||||
return x.value.to(dtype)
|
return x.value.to(dtype)
|
||||||
@ -126,8 +126,20 @@ def cond(pred, true_fn, false_fn):
|
|||||||
|
|
||||||
|
|
||||||
def vectorized_map(function, elements):
|
def vectorized_map(function, elements):
|
||||||
raise NotImplementedError
|
return torch.vmap(function)(elements)
|
||||||
|
|
||||||
|
|
||||||
def scatter(*args, **kwargs):
|
def scatter(indices, values, shape):
|
||||||
raise NotImplementedError
|
indices = convert_to_tensor(indices)
|
||||||
|
values = convert_to_tensor(values)
|
||||||
|
zeros = torch.zeros(shape, dtype=values.dtype)
|
||||||
|
|
||||||
|
index_length = indices.shape[-1]
|
||||||
|
value_shape = shape[index_length:]
|
||||||
|
indices = torch.reshape(indices, [-1, index_length])
|
||||||
|
values = torch.reshape(values, [-1] + list(value_shape))
|
||||||
|
|
||||||
|
for i in range(indices.shape[0]):
|
||||||
|
index = indices[i]
|
||||||
|
zeros[tuple(index)] += values[i]
|
||||||
|
return zeros
|
||||||
|
Loading…
Reference in New Issue
Block a user