From 04436b0da6fddeb177a93f15a4c6c290d016bf27 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Wed, 17 May 2023 18:41:12 -0700 Subject: [PATCH] Add vectorized_map and scatter to torch backend (#184) --- keras_core/backend/torch/core.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/keras_core/backend/torch/core.py b/keras_core/backend/torch/core.py index f2cc1eb8d..20b448eaa 100644 --- a/keras_core/backend/torch/core.py +++ b/keras_core/backend/torch/core.py @@ -87,7 +87,7 @@ class Variable(KerasVariable): def convert_to_tensor(x, dtype=None): # TODO: Need to address device placement arg of `as_tensor` - dtype = to_torch_dtype(dtype) + dtype = to_torch_dtype(dtype or x.dtype) if isinstance(x, Variable): if dtype and dtype != x.dtype: return x.value.to(dtype) @@ -126,8 +126,20 @@ def cond(pred, true_fn, false_fn): def vectorized_map(function, elements): - raise NotImplementedError + return torch.vmap(function)(elements) -def scatter(*args, **kwargs): - raise NotImplementedError +def scatter(indices, values, shape): + indices = convert_to_tensor(indices) + values = convert_to_tensor(values) + zeros = torch.zeros(shape, dtype=values.dtype) + + index_length = indices.shape[-1] + value_shape = shape[index_length:] + indices = torch.reshape(indices, [-1, index_length]) + values = torch.reshape(values, [-1] + list(value_shape)) + + for i in range(indices.shape[0]): + index = indices[i] + zeros[tuple(index)] += values[i] + return zeros