Merge topic 'invoker_support_scatter'

18a0cd35b vtkm::worklet::Invoker now supports scatter types
afc3f5304 Remove unneeded ScatterType as it was the default
a1ea509f8 All scatter types now inherit from a common base
77378993f DispatcherBase: Simplify remove_cvref and remove_pointer_and_decay.

Acked-by: Kitware Robot <kwrobot@kitware.com>
Acked-by: Kenneth Moreland <kmorel@sandia.gov>
Merge-request: !1673
This commit is contained in:
Robert Maynard 2019-05-15 18:16:55 +00:00 committed by Kitware Robot
commit be5b51fb01
12 changed files with 112 additions and 30 deletions

@ -0,0 +1,21 @@
# `vtkm::worklet::Invoker` now able to worklets that have non-default scatter type
This change allows the `Invoker` class to support launching worklets that require
a custom scatter operation. This is done by providing the scatter as the second
argument when launch a worklet with the `()` operator.
The following example shows a scatter being provided with a worklet launch.
```cpp
struct CheckTopology : vtkm::worklet::WorkletMapPointToCell
{
using ControlSignature = void(CellSetIn cellset, FieldOutCell);
using ExecutionSignature = _2(FromIndices);
using ScatterType = vtkm::worklet::ScatterPermutation<>;
...
};
vtkm::worklet::Ivoker invoke;
invoke( CheckTopology{}, vtkm::worklet::ScatterPermutation{}, cellset, result );
```

@ -808,8 +808,6 @@ public:
using ExecutionSignature = void(_2, _3);
using ScatterType = vtkm::worklet::ScatterIdentity;
template <typename MappedValueVecType, typename MappedValueType>
VTKM_EXEC void operator()(const MappedValueVecType& toReduce, MappedValueType& centroid) const
{

@ -52,18 +52,44 @@ struct Invoker
{
}
/// Launch the worklet that is provided as the first parameter. The additional
/// parameters are the ControlSignature arguments for the worklet.
/// Launch the worklet that is provided as the first parameter.
/// Optional second parameter is the scatter type associated with the worklet.
/// Any additional parameters are the ControlSignature arguments for the worklet.
///
template <typename Worklet, typename... Args>
inline void operator()(Worklet&& worklet, Args&&... args) const
template <typename Worklet,
typename T,
typename... Args,
typename std::enable_if<
std::is_base_of<internal::ScatterBase, internal::detail::remove_cvref<T>>::value,
int>::type* = nullptr>
inline void operator()(Worklet&& worklet, T&& scatter, Args&&... args) const
{
using WorkletType = typename std::decay<Worklet>::type;
using WorkletType = internal::detail::remove_cvref<Worklet>;
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
DispatcherType dispatcher(worklet, scatter);
dispatcher.SetDevice(this->DeviceId);
dispatcher.Invoke(std::forward<Args>(args)...);
}
/// Launch the worklet that is provided as the first parameter.
/// Optional second parameter is the scatter type associated with the worklet.
/// Any additional parameters are the ControlSignature arguments for the worklet.
///
template <typename Worklet,
typename T,
typename... Args,
typename std::enable_if<
!std::is_base_of<internal::ScatterBase, internal::detail::remove_cvref<T>>::value,
int>::type* = nullptr>
inline void operator()(Worklet&& worklet, T&& t, Args&&... args) const
{
using WorkletType = internal::detail::remove_cvref<Worklet>;
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
DispatcherType dispatcher(worklet);
dispatcher.SetDevice(this->DeviceId);
dispatcher.Invoke(std::forward<Args>(args)...);
dispatcher.Invoke(std::forward<T>(t), std::forward<Args>(args)...);
}
/// Get the device adapter that this Invoker is bound too

@ -10,6 +10,7 @@
#ifndef vtk_m_worklet_ScatterCounting_h
#define vtk_m_worklet_ScatterCounting_h
#include <vtkm/worklet/internal/ScatterBase.h>
#include <vtkm/worklet/vtkm_worklet_export.h>
#include <vtkm/cont/VariantArrayHandle.h>
@ -40,7 +41,7 @@ struct ScatterCountingBuilder;
/// taken in the constructor and the index arrays are derived from that. So
/// changing the counts after the scatter is created will have no effect.
///
struct VTKM_WORKLET_EXPORT ScatterCounting
struct VTKM_WORKLET_EXPORT ScatterCounting : internal::ScatterBase
{
struct CountTypes : vtkm::ListTagBase<vtkm::Int64,
vtkm::Int32,

@ -12,6 +12,7 @@
#include <vtkm/cont/ArrayHandleConstant.h>
#include <vtkm/cont/ArrayHandleIndex.h>
#include <vtkm/worklet/internal/ScatterBase.h>
namespace vtkm
{
@ -26,7 +27,7 @@ namespace worklet
/// element generates one output element associated with it. This is the
/// default for basic maps.
///
struct ScatterIdentity
struct ScatterIdentity : internal::ScatterBase
{
using OutputToInputMapType = vtkm::cont::ArrayHandleIndex;
VTKM_CONT

@ -12,6 +12,7 @@
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleConstant.h>
#include <vtkm/worklet/internal/ScatterBase.h>
namespace vtkm
{
@ -28,7 +29,7 @@ namespace worklet
/// can be duplicates. Note that even with duplicates the VistIndex is always 0.
///
template <typename PermutationStorage = VTKM_DEFAULT_STORAGE_TAG>
class ScatterPermutation
class ScatterPermutation : public internal::ScatterBase
{
private:
using PermutationArrayHandle = vtkm::cont::ArrayHandle<vtkm::Id, PermutationStorage>;

@ -13,6 +13,7 @@
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleCounting.h>
#include <vtkm/cont/ArrayHandleImplicit.h>
#include <vtkm/worklet/internal/ScatterBase.h>
namespace vtkm
{
@ -49,7 +50,7 @@ struct FunctorDiv
/// elements are grouped by the input associated.
///
template <vtkm::IdComponent NumOutputsPerInput>
struct ScatterUniform
struct ScatterUniform : internal::ScatterBase
{
VTKM_CONT ScatterUniform() = default;

@ -76,7 +76,7 @@
#include <vtkm/cont/ArrayHandleIndex.h>
#include <vtkm/cont/ArrayHandleReverse.h>
#include <vtkm/cont/ArrayHandleTransform.h>
#include <vtkm/worklet/DispatcherMapField.h>
#include <vtkm/worklet/Invoker.h>
#include <vtkm/worklet/ScatterCounting.h>
#include <vtkm/BinaryPredicates.h>

@ -316,16 +316,15 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
// Setup the ScatterCounting worklets needed to expand the ReduceByKeyResults
vtkm::worklet::ScatterCounting scatter(particlesPerHalo);
vtkm::worklet::DispatcherMapField<ScatterWorklet<vtkm::Id>> scatterWorkletIdDispatcher(scatter);
vtkm::worklet::DispatcherMapField<ScatterWorklet<T>> scatterWorkletDispatcher(scatter);
vtkm::worklet::Invoker invoke;
// Calculate the minimum particle index per halo id and scatter
DeviceAlgorithm::ScanExclusive(particlesPerHalo, tempI);
scatterWorkletIdDispatcher.Invoke(tempI, minParticle);
invoke(ScatterWorklet<vtkm::Id>{}, scatter, tempI, minParticle);
// Calculate the maximum particle index per halo id and scatter
DeviceAlgorithm::ScanInclusive(particlesPerHalo, tempI);
scatterWorkletIdDispatcher.Invoke(tempI, maxParticle);
invoke(ScatterWorklet<vtkm::Id>{}, scatter, tempI, maxParticle);
using IdArrayType = vtkm::cont::ArrayHandle<vtkm::Id>;
vtkm::cont::ArrayHandleTransform<IdArrayType, ScaleBiasFunctor<vtkm::Id>> scaleBias =
@ -354,7 +353,7 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
// Find minimum potential for all particles in a halo and scatter
DeviceAlgorithm::ReduceByKey(haloId, potential, uniqueHaloIds, tempT, vtkm::Minimum());
scatterWorkletDispatcher.Invoke(tempT, minPotential);
invoke(ScatterWorklet<T>{}, scatter, tempT, minPotential);
#ifdef DEBUG_PRINT
DebugPrint("potential", potential);
DebugPrint("minPotential", minPotential);
@ -371,7 +370,7 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
vtkm::cont::ArrayHandle<vtkm::Id> minIndx;
minIndx.Allocate(nParticles);
DeviceAlgorithm::ReduceByKey(haloId, mbpId, uniqueHaloIds, minIndx, vtkm::Maximum());
scatterWorkletIdDispatcher.Invoke(minIndx, mbpId);
invoke(ScatterWorklet<vtkm::Id>{}, scatter, minIndx, mbpId);
// Resort particle ids and mbpId to starting order
vtkm::cont::ArrayHandle<vtkm::Id> savePartId;

@ -10,6 +10,7 @@
set(headers
DispatcherBase.h
ScatterBase.h
TriangulateTables.h
WorkletBase.h
)

@ -153,15 +153,16 @@ struct ReportValueOnError<Value, true> : std::true_type
};
template <typename T>
struct remove_pointer_and_decay : std::remove_pointer<typename std::decay<T>::type>
{
};
using remove_pointer_and_decay = typename std::remove_pointer<typename std::decay<T>::type>::type;
template <typename T>
using remove_cvref = typename std::remove_cv<typename std::remove_reference<T>::type>::type;
// Is designed as a brigand fold operation.
template <typename Type, typename State>
struct DetermineIfHasDynamicParameter
{
using T = typename std::remove_pointer<Type>::type;
using T = remove_pointer_and_decay<Type>;
using DynamicTag = typename vtkm::cont::internal::DynamicTransformTraits<T>::DynamicTag;
using isDynamic =
typename std::is_same<DynamicTag, vtkm::cont::internal::DynamicTransformTagCastAndCall>::type;
@ -314,7 +315,7 @@ struct DispatcherBaseTransportFunctor
{
using TransportTag =
typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
using T = typename remove_pointer_and_decay<ControlParameter>::type;
using T = remove_pointer_and_decay<ControlParameter>;
using TransportType = typename vtkm::cont::arg::Transport<TransportTag, T, Device>;
using type = typename TransportType::ExecObjectType;
};
@ -326,7 +327,7 @@ struct DispatcherBaseTransportFunctor
{
using TransportTag =
typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
using T = typename remove_pointer_and_decay<ControlParameter>::type;
using T = remove_pointer_and_decay<ControlParameter>;
vtkm::cont::arg::Transport<TransportTag, T, Device> transport;
not_nullptr(invokeData, Index);
@ -412,7 +413,7 @@ struct for_each_dynamic_arg
void operator()(const Trampoline& trampoline, ContParams&& sig, T&& t, Args&&... args) const
{
//Determine that state of T when it is either a `cons&` or a `* const&`
using Type = typename std::remove_pointer<typename std::decay<T>::type>::type;
using Type = remove_pointer_and_decay<T>;
using tag = typename vtkm::cont::internal::DynamicTransformTraits<Type>::DynamicTag;
//convert the first item to a known type
convert_arg<LeftToProcess>(
@ -494,7 +495,7 @@ private:
VTKM_CONT void StartInvoke(Args&&... args) const
{
using ParameterInterface =
vtkm::internal::FunctionInterface<void(typename std::decay<Args>::type...)>;
vtkm::internal::FunctionInterface<void(detail::remove_cvref<Args>...)>;
VTKM_STATIC_ASSERT_MSG(ParameterInterface::ARITY == NUM_INVOKE_PARAMS,
"Dispatcher Invoke called with wrong number of arguments.");
@ -540,7 +541,7 @@ private:
VTKM_CONT void StartInvokeDynamic(std::false_type, Args&&... args) const
{
using ParameterInterface =
vtkm::internal::FunctionInterface<void(typename std::decay<Args>::type...)>;
vtkm::internal::FunctionInterface<void(detail::remove_cvref<Args>...)>;
//Nothing requires a conversion from dynamic to static types, so
//next we need to verify that each argument's type is correct. If not
@ -561,8 +562,7 @@ private:
static_assert(isAllValid::value == expectedLen::value,
"All arguments failed the TypeCheck pass");
auto fi =
vtkm::internal::make_FunctionInterface<void, typename std::decay<Args>::type...>(args...);
auto fi = vtkm::internal::make_FunctionInterface<void, detail::remove_cvref<Args>...>(args...);
auto ivc = vtkm::internal::Invocation<ParameterInterface,
ControlInterface,
ExecutionInterface,

@ -0,0 +1,33 @@
//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//============================================================================
#ifndef vtk_m_worklet_internal_ScatterBase_h
#define vtk_m_worklet_internal_ScatterBase_h
#include <vtkm/internal/ExportMacros.h>
namespace vtkm
{
namespace worklet
{
namespace internal
{
/// Base class for all scatter classes.
///
/// This allows VTK-m to determine when a parameter
/// is a scatter type instead of a worklet parameter.
///
struct VTKM_ALWAYS_EXPORT ScatterBase
{
};
}
}
}
#endif