diff --git a/keras/backend/numpy/core.py b/keras/backend/numpy/core.py index 1b3e8e408..2e6ea07f0 100644 --- a/keras/backend/numpy/core.py +++ b/keras/backend/numpy/core.py @@ -61,15 +61,11 @@ def cond(pred, true_fn, false_fn): return false_fn() -def vectorized_map(function, elements): - if len(elements) == 1: - return function(elements) - else: - batch_size = elements[0].shape[0] - output_store = list() - for index in range(batch_size): - output_store.append(function([x[index] for x in elements])) - return np.stack(output_store) +def vectorized_map(function, x): + output_store = [] + for element in x: + output_store.append(function(element)) + return np.stack(output_store) # Shape / dtype inference util