From 11cbb29b30a52ee6ea0087eebf4837ce016b97d7 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Thu, 2 Nov 2023 12:42:28 -0700 Subject: [PATCH] Fix vectorized map for numpy backend --- keras/backend/numpy/core.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/keras/backend/numpy/core.py b/keras/backend/numpy/core.py index 1b3e8e408..2e6ea07f0 100644 --- a/keras/backend/numpy/core.py +++ b/keras/backend/numpy/core.py @@ -61,15 +61,11 @@ def cond(pred, true_fn, false_fn): return false_fn() -def vectorized_map(function, elements): - if len(elements) == 1: - return function(elements) - else: - batch_size = elements[0].shape[0] - output_store = list() - for index in range(batch_size): - output_store.append(function([x[index] for x in elements])) - return np.stack(output_store) +def vectorized_map(function, x): + output_store = [] + for element in x: + output_store.append(function(element)) + return np.stack(output_store) # Shape / dtype inference util