Publicly expose vectorized_map
.
This commit is contained in:
parent
a6d1fe1c71
commit
2ff6f13c01
@ -592,3 +592,20 @@ def cond(pred, true_fn, false_fn):
|
||||
The output of either `true_fn` or `false_fn` depending on pred.
|
||||
"""
|
||||
return Cond()(pred, true_fn, false_fn)
|
||||
|
||||
|
||||
# TODO: also create an Op subclass VectorizedMap.
|
||||
@keras_export("keras.ops.vectorized_map")
|
||||
def vectorized_map(function, x):
|
||||
"""Parallel map of `function` on axis 0 of tensor `x`.
|
||||
|
||||
Schematically, `vectorized_map` implements the following:
|
||||
|
||||
```python
|
||||
def vectorized_map(function, x)
|
||||
outputs = []
|
||||
for element in x:
|
||||
outputs.append(function(element))
|
||||
return stack(outputs)
|
||||
"""
|
||||
return backend.core.vectorized_map(function, x)
|
||||
|
@ -387,3 +387,18 @@ class CoreOpsCorrectnessTest(testing.TestCase):
|
||||
self.assertEqual("float16", y.dtype)
|
||||
self.assertEqual(x.shape, y.shape)
|
||||
self.assertTrue(hasattr(y, "_keras_history"))
|
||||
|
||||
def test_vectorized_map(self):
|
||||
def fn(x):
|
||||
return x + 1
|
||||
|
||||
output = ops.vectorized_map(fn, ops.zeros((2, 3), dtype="float32"))
|
||||
self.assertAllClose(backend.convert_to_numpy(output), np.ones((2, 3)))
|
||||
|
||||
def fn(x):
|
||||
return ops.stack([x, x])
|
||||
|
||||
output = ops.vectorized_map(fn, ops.zeros((2, 3), dtype="float32"))
|
||||
self.assertAllClose(
|
||||
backend.convert_to_numpy(output), np.zeros((2, 2, 3))
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user