Fix vectorized map for numpy backend
This commit is contained in:
parent
2ff6f13c01
commit
11cbb29b30
@ -61,15 +61,11 @@ def cond(pred, true_fn, false_fn):
|
||||
return false_fn()
|
||||
|
||||
|
||||
def vectorized_map(function, elements):
|
||||
if len(elements) == 1:
|
||||
return function(elements)
|
||||
else:
|
||||
batch_size = elements[0].shape[0]
|
||||
output_store = list()
|
||||
for index in range(batch_size):
|
||||
output_store.append(function([x[index] for x in elements]))
|
||||
return np.stack(output_store)
|
||||
def vectorized_map(function, x):
|
||||
output_store = []
|
||||
for element in x:
|
||||
output_store.append(function(element))
|
||||
return np.stack(output_store)
|
||||
|
||||
|
||||
# Shape / dtype inference util
|
||||
|
Loading…
Reference in New Issue
Block a user