Fix vectorized map for numpy backend

This commit is contained in:
Francois Chollet 2023-11-02 12:42:28 -07:00
parent 2ff6f13c01
commit 11cbb29b30

@ -61,15 +61,11 @@ def cond(pred, true_fn, false_fn):
return false_fn()
def vectorized_map(function, elements):
if len(elements) == 1:
return function(elements)
else:
batch_size = elements[0].shape[0]
output_store = list()
for index in range(batch_size):
output_store.append(function([x[index] for x in elements]))
return np.stack(output_store)
def vectorized_map(function, x):
output_store = []
for element in x:
output_store.append(function(element))
return np.stack(output_store)
# Shape / dtype inference util