mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-09-16 17:22:55 +00:00
parent
afc3f5304e
commit
18a0cd35b0
21
docs/changelog/invoker-supports-scatter-types.md
Normal file
21
docs/changelog/invoker-supports-scatter-types.md
Normal file
@ -0,0 +1,21 @@
|
||||
# `vtkm::worklet::Invoker` now able to worklets that have non-default scatter type
|
||||
|
||||
This change allows the `Invoker` class to support launching worklets that require
|
||||
a custom scatter operation. This is done by providing the scatter as the second
|
||||
argument when launch a worklet with the `()` operator.
|
||||
|
||||
The following example shows a scatter being provided with a worklet launch.
|
||||
|
||||
```cpp
|
||||
struct CheckTopology : vtkm::worklet::WorkletMapPointToCell
|
||||
{
|
||||
using ControlSignature = void(CellSetIn cellset, FieldOutCell);
|
||||
using ExecutionSignature = _2(FromIndices);
|
||||
using ScatterType = vtkm::worklet::ScatterPermutation<>;
|
||||
...
|
||||
};
|
||||
|
||||
|
||||
vtkm::worklet::Ivoker invoke;
|
||||
invoke( CheckTopology{}, vtkm::worklet::ScatterPermutation{}, cellset, result );
|
||||
```
|
@ -52,18 +52,44 @@ struct Invoker
|
||||
{
|
||||
}
|
||||
|
||||
/// Launch the worklet that is provided as the first parameter. The additional
|
||||
/// parameters are the ControlSignature arguments for the worklet.
|
||||
/// Launch the worklet that is provided as the first parameter.
|
||||
/// Optional second parameter is the scatter type associated with the worklet.
|
||||
/// Any additional parameters are the ControlSignature arguments for the worklet.
|
||||
///
|
||||
template <typename Worklet, typename... Args>
|
||||
inline void operator()(Worklet&& worklet, Args&&... args) const
|
||||
template <typename Worklet,
|
||||
typename T,
|
||||
typename... Args,
|
||||
typename std::enable_if<
|
||||
std::is_base_of<internal::ScatterBase, internal::detail::remove_cvref<T>>::value,
|
||||
int>::type* = nullptr>
|
||||
inline void operator()(Worklet&& worklet, T&& scatter, Args&&... args) const
|
||||
{
|
||||
using WorkletType = typename std::decay<Worklet>::type;
|
||||
using WorkletType = internal::detail::remove_cvref<Worklet>;
|
||||
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
|
||||
|
||||
DispatcherType dispatcher(worklet, scatter);
|
||||
dispatcher.SetDevice(this->DeviceId);
|
||||
dispatcher.Invoke(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/// Launch the worklet that is provided as the first parameter.
|
||||
/// Optional second parameter is the scatter type associated with the worklet.
|
||||
/// Any additional parameters are the ControlSignature arguments for the worklet.
|
||||
///
|
||||
template <typename Worklet,
|
||||
typename T,
|
||||
typename... Args,
|
||||
typename std::enable_if<
|
||||
!std::is_base_of<internal::ScatterBase, internal::detail::remove_cvref<T>>::value,
|
||||
int>::type* = nullptr>
|
||||
inline void operator()(Worklet&& worklet, T&& t, Args&&... args) const
|
||||
{
|
||||
using WorkletType = internal::detail::remove_cvref<Worklet>;
|
||||
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
|
||||
|
||||
DispatcherType dispatcher(worklet);
|
||||
dispatcher.SetDevice(this->DeviceId);
|
||||
dispatcher.Invoke(std::forward<Args>(args)...);
|
||||
dispatcher.Invoke(std::forward<T>(t), std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/// Get the device adapter that this Invoker is bound too
|
||||
|
@ -76,7 +76,7 @@
|
||||
#include <vtkm/cont/ArrayHandleIndex.h>
|
||||
#include <vtkm/cont/ArrayHandleReverse.h>
|
||||
#include <vtkm/cont/ArrayHandleTransform.h>
|
||||
#include <vtkm/worklet/DispatcherMapField.h>
|
||||
#include <vtkm/worklet/Invoker.h>
|
||||
#include <vtkm/worklet/ScatterCounting.h>
|
||||
|
||||
#include <vtkm/BinaryPredicates.h>
|
||||
|
@ -316,16 +316,15 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
|
||||
|
||||
// Setup the ScatterCounting worklets needed to expand the ReduceByKeyResults
|
||||
vtkm::worklet::ScatterCounting scatter(particlesPerHalo);
|
||||
vtkm::worklet::DispatcherMapField<ScatterWorklet<vtkm::Id>> scatterWorkletIdDispatcher(scatter);
|
||||
vtkm::worklet::DispatcherMapField<ScatterWorklet<T>> scatterWorkletDispatcher(scatter);
|
||||
vtkm::worklet::Invoker invoke;
|
||||
|
||||
// Calculate the minimum particle index per halo id and scatter
|
||||
DeviceAlgorithm::ScanExclusive(particlesPerHalo, tempI);
|
||||
scatterWorkletIdDispatcher.Invoke(tempI, minParticle);
|
||||
invoke(ScatterWorklet<vtkm::Id>{}, scatter, tempI, minParticle);
|
||||
|
||||
// Calculate the maximum particle index per halo id and scatter
|
||||
DeviceAlgorithm::ScanInclusive(particlesPerHalo, tempI);
|
||||
scatterWorkletIdDispatcher.Invoke(tempI, maxParticle);
|
||||
invoke(ScatterWorklet<vtkm::Id>{}, scatter, tempI, maxParticle);
|
||||
|
||||
using IdArrayType = vtkm::cont::ArrayHandle<vtkm::Id>;
|
||||
vtkm::cont::ArrayHandleTransform<IdArrayType, ScaleBiasFunctor<vtkm::Id>> scaleBias =
|
||||
@ -354,7 +353,7 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
|
||||
|
||||
// Find minimum potential for all particles in a halo and scatter
|
||||
DeviceAlgorithm::ReduceByKey(haloId, potential, uniqueHaloIds, tempT, vtkm::Minimum());
|
||||
scatterWorkletDispatcher.Invoke(tempT, minPotential);
|
||||
invoke(ScatterWorklet<T>{}, scatter, tempT, minPotential);
|
||||
#ifdef DEBUG_PRINT
|
||||
DebugPrint("potential", potential);
|
||||
DebugPrint("minPotential", minPotential);
|
||||
@ -371,7 +370,7 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
|
||||
vtkm::cont::ArrayHandle<vtkm::Id> minIndx;
|
||||
minIndx.Allocate(nParticles);
|
||||
DeviceAlgorithm::ReduceByKey(haloId, mbpId, uniqueHaloIds, minIndx, vtkm::Maximum());
|
||||
scatterWorkletIdDispatcher.Invoke(minIndx, mbpId);
|
||||
invoke(ScatterWorklet<vtkm::Id>{}, scatter, minIndx, mbpId);
|
||||
|
||||
// Resort particle ids and mbpId to starting order
|
||||
vtkm::cont::ArrayHandle<vtkm::Id> savePartId;
|
||||
|
@ -158,7 +158,6 @@ using remove_pointer_and_decay = typename std::remove_pointer<typename std::deca
|
||||
template <typename T>
|
||||
using remove_cvref = typename std::remove_cv<typename std::remove_reference<T>::type>::type;
|
||||
|
||||
|
||||
// Is designed as a brigand fold operation.
|
||||
template <typename Type, typename State>
|
||||
struct DetermineIfHasDynamicParameter
|
||||
|
Loading…
Reference in New Issue
Block a user