Publicly expose vectorized_map.

This commit is contained in:
Francois Chollet 2023-11-02 11:57:10 -07:00
parent a6d1fe1c71
commit 2ff6f13c01
2 changed files with 32 additions and 0 deletions

@ -592,3 +592,20 @@ def cond(pred, true_fn, false_fn):
The output of either `true_fn` or `false_fn` depending on pred.
"""
return Cond()(pred, true_fn, false_fn)
# TODO: also create an Op subclass VectorizedMap.
@keras_export("keras.ops.vectorized_map")
def vectorized_map(function, x):
"""Parallel map of `function` on axis 0 of tensor `x`.
Schematically, `vectorized_map` implements the following:
```python
def vectorized_map(function, x)
outputs = []
for element in x:
outputs.append(function(element))
return stack(outputs)
"""
return backend.core.vectorized_map(function, x)

@ -387,3 +387,18 @@ class CoreOpsCorrectnessTest(testing.TestCase):
self.assertEqual("float16", y.dtype)
self.assertEqual(x.shape, y.shape)
self.assertTrue(hasattr(y, "_keras_history"))
def test_vectorized_map(self):
def fn(x):
return x + 1
output = ops.vectorized_map(fn, ops.zeros((2, 3), dtype="float32"))
self.assertAllClose(backend.convert_to_numpy(output), np.ones((2, 3)))
def fn(x):
return ops.stack([x, x])
output = ops.vectorized_map(fn, ops.zeros((2, 3), dtype="float32"))
self.assertAllClose(
backend.convert_to_numpy(output), np.zeros((2, 2, 3))
)