vtkm::worklet::Invoker now supports scatter types

Fixes #297
This commit is contained in:
Robert Maynard 2019-05-15 10:47:20 -04:00
parent afc3f5304e
commit 18a0cd35b0
5 changed files with 59 additions and 14 deletions

@ -0,0 +1,21 @@
# `vtkm::worklet::Invoker` now able to worklets that have non-default scatter type
This change allows the `Invoker` class to support launching worklets that require
a custom scatter operation. This is done by providing the scatter as the second
argument when launch a worklet with the `()` operator.
The following example shows a scatter being provided with a worklet launch.
```cpp
struct CheckTopology : vtkm::worklet::WorkletMapPointToCell
{
using ControlSignature = void(CellSetIn cellset, FieldOutCell);
using ExecutionSignature = _2(FromIndices);
using ScatterType = vtkm::worklet::ScatterPermutation<>;
...
};
vtkm::worklet::Ivoker invoke;
invoke( CheckTopology{}, vtkm::worklet::ScatterPermutation{}, cellset, result );
```

@ -52,18 +52,44 @@ struct Invoker
{
}
/// Launch the worklet that is provided as the first parameter. The additional
/// parameters are the ControlSignature arguments for the worklet.
/// Launch the worklet that is provided as the first parameter.
/// Optional second parameter is the scatter type associated with the worklet.
/// Any additional parameters are the ControlSignature arguments for the worklet.
///
template <typename Worklet, typename... Args>
inline void operator()(Worklet&& worklet, Args&&... args) const
template <typename Worklet,
typename T,
typename... Args,
typename std::enable_if<
std::is_base_of<internal::ScatterBase, internal::detail::remove_cvref<T>>::value,
int>::type* = nullptr>
inline void operator()(Worklet&& worklet, T&& scatter, Args&&... args) const
{
using WorkletType = typename std::decay<Worklet>::type;
using WorkletType = internal::detail::remove_cvref<Worklet>;
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
DispatcherType dispatcher(worklet, scatter);
dispatcher.SetDevice(this->DeviceId);
dispatcher.Invoke(std::forward<Args>(args)...);
}
/// Launch the worklet that is provided as the first parameter.
/// Optional second parameter is the scatter type associated with the worklet.
/// Any additional parameters are the ControlSignature arguments for the worklet.
///
template <typename Worklet,
typename T,
typename... Args,
typename std::enable_if<
!std::is_base_of<internal::ScatterBase, internal::detail::remove_cvref<T>>::value,
int>::type* = nullptr>
inline void operator()(Worklet&& worklet, T&& t, Args&&... args) const
{
using WorkletType = internal::detail::remove_cvref<Worklet>;
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
DispatcherType dispatcher(worklet);
dispatcher.SetDevice(this->DeviceId);
dispatcher.Invoke(std::forward<Args>(args)...);
dispatcher.Invoke(std::forward<T>(t), std::forward<Args>(args)...);
}
/// Get the device adapter that this Invoker is bound too

@ -76,7 +76,7 @@
#include <vtkm/cont/ArrayHandleIndex.h>
#include <vtkm/cont/ArrayHandleReverse.h>
#include <vtkm/cont/ArrayHandleTransform.h>
#include <vtkm/worklet/DispatcherMapField.h>
#include <vtkm/worklet/Invoker.h>
#include <vtkm/worklet/ScatterCounting.h>
#include <vtkm/BinaryPredicates.h>

@ -316,16 +316,15 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
// Setup the ScatterCounting worklets needed to expand the ReduceByKeyResults
vtkm::worklet::ScatterCounting scatter(particlesPerHalo);
vtkm::worklet::DispatcherMapField<ScatterWorklet<vtkm::Id>> scatterWorkletIdDispatcher(scatter);
vtkm::worklet::DispatcherMapField<ScatterWorklet<T>> scatterWorkletDispatcher(scatter);
vtkm::worklet::Invoker invoke;
// Calculate the minimum particle index per halo id and scatter
DeviceAlgorithm::ScanExclusive(particlesPerHalo, tempI);
scatterWorkletIdDispatcher.Invoke(tempI, minParticle);
invoke(ScatterWorklet<vtkm::Id>{}, scatter, tempI, minParticle);
// Calculate the maximum particle index per halo id and scatter
DeviceAlgorithm::ScanInclusive(particlesPerHalo, tempI);
scatterWorkletIdDispatcher.Invoke(tempI, maxParticle);
invoke(ScatterWorklet<vtkm::Id>{}, scatter, tempI, maxParticle);
using IdArrayType = vtkm::cont::ArrayHandle<vtkm::Id>;
vtkm::cont::ArrayHandleTransform<IdArrayType, ScaleBiasFunctor<vtkm::Id>> scaleBias =
@ -354,7 +353,7 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
// Find minimum potential for all particles in a halo and scatter
DeviceAlgorithm::ReduceByKey(haloId, potential, uniqueHaloIds, tempT, vtkm::Minimum());
scatterWorkletDispatcher.Invoke(tempT, minPotential);
invoke(ScatterWorklet<T>{}, scatter, tempT, minPotential);
#ifdef DEBUG_PRINT
DebugPrint("potential", potential);
DebugPrint("minPotential", minPotential);
@ -371,7 +370,7 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
vtkm::cont::ArrayHandle<vtkm::Id> minIndx;
minIndx.Allocate(nParticles);
DeviceAlgorithm::ReduceByKey(haloId, mbpId, uniqueHaloIds, minIndx, vtkm::Maximum());
scatterWorkletIdDispatcher.Invoke(minIndx, mbpId);
invoke(ScatterWorklet<vtkm::Id>{}, scatter, minIndx, mbpId);
// Resort particle ids and mbpId to starting order
vtkm::cont::ArrayHandle<vtkm::Id> savePartId;

@ -158,7 +158,6 @@ using remove_pointer_and_decay = typename std::remove_pointer<typename std::deca
template <typename T>
using remove_cvref = typename std::remove_cv<typename std::remove_reference<T>::type>::type;
// Is designed as a brigand fold operation.
template <typename Type, typename State>
struct DetermineIfHasDynamicParameter