Merge topic 'invoker_support_scatter'
18a0cd35b vtkm::worklet::Invoker now supports scatter types afc3f5304 Remove unneeded ScatterType as it was the default a1ea509f8 All scatter types now inherit from a common base 77378993f DispatcherBase: Simplify remove_cvref and remove_pointer_and_decay. Acked-by: Kitware Robot <kwrobot@kitware.com> Acked-by: Kenneth Moreland <kmorel@sandia.gov> Merge-request: !1673
This commit is contained in:
commit
be5b51fb01
21
docs/changelog/invoker-supports-scatter-types.md
Normal file
21
docs/changelog/invoker-supports-scatter-types.md
Normal file
@ -0,0 +1,21 @@
|
||||
# `vtkm::worklet::Invoker` now able to worklets that have non-default scatter type
|
||||
|
||||
This change allows the `Invoker` class to support launching worklets that require
|
||||
a custom scatter operation. This is done by providing the scatter as the second
|
||||
argument when launch a worklet with the `()` operator.
|
||||
|
||||
The following example shows a scatter being provided with a worklet launch.
|
||||
|
||||
```cpp
|
||||
struct CheckTopology : vtkm::worklet::WorkletMapPointToCell
|
||||
{
|
||||
using ControlSignature = void(CellSetIn cellset, FieldOutCell);
|
||||
using ExecutionSignature = _2(FromIndices);
|
||||
using ScatterType = vtkm::worklet::ScatterPermutation<>;
|
||||
...
|
||||
};
|
||||
|
||||
|
||||
vtkm::worklet::Ivoker invoke;
|
||||
invoke( CheckTopology{}, vtkm::worklet::ScatterPermutation{}, cellset, result );
|
||||
```
|
@ -808,8 +808,6 @@ public:
|
||||
|
||||
using ExecutionSignature = void(_2, _3);
|
||||
|
||||
using ScatterType = vtkm::worklet::ScatterIdentity;
|
||||
|
||||
template <typename MappedValueVecType, typename MappedValueType>
|
||||
VTKM_EXEC void operator()(const MappedValueVecType& toReduce, MappedValueType& centroid) const
|
||||
{
|
||||
|
@ -52,18 +52,44 @@ struct Invoker
|
||||
{
|
||||
}
|
||||
|
||||
/// Launch the worklet that is provided as the first parameter. The additional
|
||||
/// parameters are the ControlSignature arguments for the worklet.
|
||||
/// Launch the worklet that is provided as the first parameter.
|
||||
/// Optional second parameter is the scatter type associated with the worklet.
|
||||
/// Any additional parameters are the ControlSignature arguments for the worklet.
|
||||
///
|
||||
template <typename Worklet, typename... Args>
|
||||
inline void operator()(Worklet&& worklet, Args&&... args) const
|
||||
template <typename Worklet,
|
||||
typename T,
|
||||
typename... Args,
|
||||
typename std::enable_if<
|
||||
std::is_base_of<internal::ScatterBase, internal::detail::remove_cvref<T>>::value,
|
||||
int>::type* = nullptr>
|
||||
inline void operator()(Worklet&& worklet, T&& scatter, Args&&... args) const
|
||||
{
|
||||
using WorkletType = typename std::decay<Worklet>::type;
|
||||
using WorkletType = internal::detail::remove_cvref<Worklet>;
|
||||
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
|
||||
|
||||
DispatcherType dispatcher(worklet, scatter);
|
||||
dispatcher.SetDevice(this->DeviceId);
|
||||
dispatcher.Invoke(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/// Launch the worklet that is provided as the first parameter.
|
||||
/// Optional second parameter is the scatter type associated with the worklet.
|
||||
/// Any additional parameters are the ControlSignature arguments for the worklet.
|
||||
///
|
||||
template <typename Worklet,
|
||||
typename T,
|
||||
typename... Args,
|
||||
typename std::enable_if<
|
||||
!std::is_base_of<internal::ScatterBase, internal::detail::remove_cvref<T>>::value,
|
||||
int>::type* = nullptr>
|
||||
inline void operator()(Worklet&& worklet, T&& t, Args&&... args) const
|
||||
{
|
||||
using WorkletType = internal::detail::remove_cvref<Worklet>;
|
||||
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
|
||||
|
||||
DispatcherType dispatcher(worklet);
|
||||
dispatcher.SetDevice(this->DeviceId);
|
||||
dispatcher.Invoke(std::forward<Args>(args)...);
|
||||
dispatcher.Invoke(std::forward<T>(t), std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/// Get the device adapter that this Invoker is bound too
|
||||
|
@ -10,6 +10,7 @@
|
||||
#ifndef vtk_m_worklet_ScatterCounting_h
|
||||
#define vtk_m_worklet_ScatterCounting_h
|
||||
|
||||
#include <vtkm/worklet/internal/ScatterBase.h>
|
||||
#include <vtkm/worklet/vtkm_worklet_export.h>
|
||||
|
||||
#include <vtkm/cont/VariantArrayHandle.h>
|
||||
@ -40,7 +41,7 @@ struct ScatterCountingBuilder;
|
||||
/// taken in the constructor and the index arrays are derived from that. So
|
||||
/// changing the counts after the scatter is created will have no effect.
|
||||
///
|
||||
struct VTKM_WORKLET_EXPORT ScatterCounting
|
||||
struct VTKM_WORKLET_EXPORT ScatterCounting : internal::ScatterBase
|
||||
{
|
||||
struct CountTypes : vtkm::ListTagBase<vtkm::Int64,
|
||||
vtkm::Int32,
|
||||
|
@ -12,6 +12,7 @@
|
||||
|
||||
#include <vtkm/cont/ArrayHandleConstant.h>
|
||||
#include <vtkm/cont/ArrayHandleIndex.h>
|
||||
#include <vtkm/worklet/internal/ScatterBase.h>
|
||||
|
||||
namespace vtkm
|
||||
{
|
||||
@ -26,7 +27,7 @@ namespace worklet
|
||||
/// element generates one output element associated with it. This is the
|
||||
/// default for basic maps.
|
||||
///
|
||||
struct ScatterIdentity
|
||||
struct ScatterIdentity : internal::ScatterBase
|
||||
{
|
||||
using OutputToInputMapType = vtkm::cont::ArrayHandleIndex;
|
||||
VTKM_CONT
|
||||
|
@ -12,6 +12,7 @@
|
||||
|
||||
#include <vtkm/cont/ArrayHandle.h>
|
||||
#include <vtkm/cont/ArrayHandleConstant.h>
|
||||
#include <vtkm/worklet/internal/ScatterBase.h>
|
||||
|
||||
namespace vtkm
|
||||
{
|
||||
@ -28,7 +29,7 @@ namespace worklet
|
||||
/// can be duplicates. Note that even with duplicates the VistIndex is always 0.
|
||||
///
|
||||
template <typename PermutationStorage = VTKM_DEFAULT_STORAGE_TAG>
|
||||
class ScatterPermutation
|
||||
class ScatterPermutation : public internal::ScatterBase
|
||||
{
|
||||
private:
|
||||
using PermutationArrayHandle = vtkm::cont::ArrayHandle<vtkm::Id, PermutationStorage>;
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <vtkm/cont/ArrayHandle.h>
|
||||
#include <vtkm/cont/ArrayHandleCounting.h>
|
||||
#include <vtkm/cont/ArrayHandleImplicit.h>
|
||||
#include <vtkm/worklet/internal/ScatterBase.h>
|
||||
|
||||
namespace vtkm
|
||||
{
|
||||
@ -49,7 +50,7 @@ struct FunctorDiv
|
||||
/// elements are grouped by the input associated.
|
||||
///
|
||||
template <vtkm::IdComponent NumOutputsPerInput>
|
||||
struct ScatterUniform
|
||||
struct ScatterUniform : internal::ScatterBase
|
||||
{
|
||||
VTKM_CONT ScatterUniform() = default;
|
||||
|
||||
|
@ -76,7 +76,7 @@
|
||||
#include <vtkm/cont/ArrayHandleIndex.h>
|
||||
#include <vtkm/cont/ArrayHandleReverse.h>
|
||||
#include <vtkm/cont/ArrayHandleTransform.h>
|
||||
#include <vtkm/worklet/DispatcherMapField.h>
|
||||
#include <vtkm/worklet/Invoker.h>
|
||||
#include <vtkm/worklet/ScatterCounting.h>
|
||||
|
||||
#include <vtkm/BinaryPredicates.h>
|
||||
|
@ -316,16 +316,15 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
|
||||
|
||||
// Setup the ScatterCounting worklets needed to expand the ReduceByKeyResults
|
||||
vtkm::worklet::ScatterCounting scatter(particlesPerHalo);
|
||||
vtkm::worklet::DispatcherMapField<ScatterWorklet<vtkm::Id>> scatterWorkletIdDispatcher(scatter);
|
||||
vtkm::worklet::DispatcherMapField<ScatterWorklet<T>> scatterWorkletDispatcher(scatter);
|
||||
vtkm::worklet::Invoker invoke;
|
||||
|
||||
// Calculate the minimum particle index per halo id and scatter
|
||||
DeviceAlgorithm::ScanExclusive(particlesPerHalo, tempI);
|
||||
scatterWorkletIdDispatcher.Invoke(tempI, minParticle);
|
||||
invoke(ScatterWorklet<vtkm::Id>{}, scatter, tempI, minParticle);
|
||||
|
||||
// Calculate the maximum particle index per halo id and scatter
|
||||
DeviceAlgorithm::ScanInclusive(particlesPerHalo, tempI);
|
||||
scatterWorkletIdDispatcher.Invoke(tempI, maxParticle);
|
||||
invoke(ScatterWorklet<vtkm::Id>{}, scatter, tempI, maxParticle);
|
||||
|
||||
using IdArrayType = vtkm::cont::ArrayHandle<vtkm::Id>;
|
||||
vtkm::cont::ArrayHandleTransform<IdArrayType, ScaleBiasFunctor<vtkm::Id>> scaleBias =
|
||||
@ -354,7 +353,7 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
|
||||
|
||||
// Find minimum potential for all particles in a halo and scatter
|
||||
DeviceAlgorithm::ReduceByKey(haloId, potential, uniqueHaloIds, tempT, vtkm::Minimum());
|
||||
scatterWorkletDispatcher.Invoke(tempT, minPotential);
|
||||
invoke(ScatterWorklet<T>{}, scatter, tempT, minPotential);
|
||||
#ifdef DEBUG_PRINT
|
||||
DebugPrint("potential", potential);
|
||||
DebugPrint("minPotential", minPotential);
|
||||
@ -371,7 +370,7 @@ void CosmoTools<T, StorageType>::MBPCenterFindingByHalo(vtkm::cont::ArrayHandle<
|
||||
vtkm::cont::ArrayHandle<vtkm::Id> minIndx;
|
||||
minIndx.Allocate(nParticles);
|
||||
DeviceAlgorithm::ReduceByKey(haloId, mbpId, uniqueHaloIds, minIndx, vtkm::Maximum());
|
||||
scatterWorkletIdDispatcher.Invoke(minIndx, mbpId);
|
||||
invoke(ScatterWorklet<vtkm::Id>{}, scatter, minIndx, mbpId);
|
||||
|
||||
// Resort particle ids and mbpId to starting order
|
||||
vtkm::cont::ArrayHandle<vtkm::Id> savePartId;
|
||||
|
@ -10,6 +10,7 @@
|
||||
|
||||
set(headers
|
||||
DispatcherBase.h
|
||||
ScatterBase.h
|
||||
TriangulateTables.h
|
||||
WorkletBase.h
|
||||
)
|
||||
|
@ -153,15 +153,16 @@ struct ReportValueOnError<Value, true> : std::true_type
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct remove_pointer_and_decay : std::remove_pointer<typename std::decay<T>::type>
|
||||
{
|
||||
};
|
||||
using remove_pointer_and_decay = typename std::remove_pointer<typename std::decay<T>::type>::type;
|
||||
|
||||
template <typename T>
|
||||
using remove_cvref = typename std::remove_cv<typename std::remove_reference<T>::type>::type;
|
||||
|
||||
// Is designed as a brigand fold operation.
|
||||
template <typename Type, typename State>
|
||||
struct DetermineIfHasDynamicParameter
|
||||
{
|
||||
using T = typename std::remove_pointer<Type>::type;
|
||||
using T = remove_pointer_and_decay<Type>;
|
||||
using DynamicTag = typename vtkm::cont::internal::DynamicTransformTraits<T>::DynamicTag;
|
||||
using isDynamic =
|
||||
typename std::is_same<DynamicTag, vtkm::cont::internal::DynamicTransformTagCastAndCall>::type;
|
||||
@ -314,7 +315,7 @@ struct DispatcherBaseTransportFunctor
|
||||
{
|
||||
using TransportTag =
|
||||
typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
|
||||
using T = typename remove_pointer_and_decay<ControlParameter>::type;
|
||||
using T = remove_pointer_and_decay<ControlParameter>;
|
||||
using TransportType = typename vtkm::cont::arg::Transport<TransportTag, T, Device>;
|
||||
using type = typename TransportType::ExecObjectType;
|
||||
};
|
||||
@ -326,7 +327,7 @@ struct DispatcherBaseTransportFunctor
|
||||
{
|
||||
using TransportTag =
|
||||
typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
|
||||
using T = typename remove_pointer_and_decay<ControlParameter>::type;
|
||||
using T = remove_pointer_and_decay<ControlParameter>;
|
||||
vtkm::cont::arg::Transport<TransportTag, T, Device> transport;
|
||||
|
||||
not_nullptr(invokeData, Index);
|
||||
@ -412,7 +413,7 @@ struct for_each_dynamic_arg
|
||||
void operator()(const Trampoline& trampoline, ContParams&& sig, T&& t, Args&&... args) const
|
||||
{
|
||||
//Determine that state of T when it is either a `cons&` or a `* const&`
|
||||
using Type = typename std::remove_pointer<typename std::decay<T>::type>::type;
|
||||
using Type = remove_pointer_and_decay<T>;
|
||||
using tag = typename vtkm::cont::internal::DynamicTransformTraits<Type>::DynamicTag;
|
||||
//convert the first item to a known type
|
||||
convert_arg<LeftToProcess>(
|
||||
@ -494,7 +495,7 @@ private:
|
||||
VTKM_CONT void StartInvoke(Args&&... args) const
|
||||
{
|
||||
using ParameterInterface =
|
||||
vtkm::internal::FunctionInterface<void(typename std::decay<Args>::type...)>;
|
||||
vtkm::internal::FunctionInterface<void(detail::remove_cvref<Args>...)>;
|
||||
|
||||
VTKM_STATIC_ASSERT_MSG(ParameterInterface::ARITY == NUM_INVOKE_PARAMS,
|
||||
"Dispatcher Invoke called with wrong number of arguments.");
|
||||
@ -540,7 +541,7 @@ private:
|
||||
VTKM_CONT void StartInvokeDynamic(std::false_type, Args&&... args) const
|
||||
{
|
||||
using ParameterInterface =
|
||||
vtkm::internal::FunctionInterface<void(typename std::decay<Args>::type...)>;
|
||||
vtkm::internal::FunctionInterface<void(detail::remove_cvref<Args>...)>;
|
||||
|
||||
//Nothing requires a conversion from dynamic to static types, so
|
||||
//next we need to verify that each argument's type is correct. If not
|
||||
@ -561,8 +562,7 @@ private:
|
||||
static_assert(isAllValid::value == expectedLen::value,
|
||||
"All arguments failed the TypeCheck pass");
|
||||
|
||||
auto fi =
|
||||
vtkm::internal::make_FunctionInterface<void, typename std::decay<Args>::type...>(args...);
|
||||
auto fi = vtkm::internal::make_FunctionInterface<void, detail::remove_cvref<Args>...>(args...);
|
||||
auto ivc = vtkm::internal::Invocation<ParameterInterface,
|
||||
ControlInterface,
|
||||
ExecutionInterface,
|
||||
|
33
vtkm/worklet/internal/ScatterBase.h
Normal file
33
vtkm/worklet/internal/ScatterBase.h
Normal file
@ -0,0 +1,33 @@
|
||||
//============================================================================
|
||||
// Copyright (c) Kitware, Inc.
|
||||
// All rights reserved.
|
||||
// See LICENSE.txt for details.
|
||||
//
|
||||
// This software is distributed WITHOUT ANY WARRANTY; without even
|
||||
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
|
||||
// PURPOSE. See the above copyright notice for more information.
|
||||
//============================================================================
|
||||
#ifndef vtk_m_worklet_internal_ScatterBase_h
|
||||
#define vtk_m_worklet_internal_ScatterBase_h
|
||||
|
||||
#include <vtkm/internal/ExportMacros.h>
|
||||
|
||||
namespace vtkm
|
||||
{
|
||||
namespace worklet
|
||||
{
|
||||
namespace internal
|
||||
{
|
||||
|
||||
/// Base class for all scatter classes.
|
||||
///
|
||||
/// This allows VTK-m to determine when a parameter
|
||||
/// is a scatter type instead of a worklet parameter.
|
||||
///
|
||||
struct VTKM_ALWAYS_EXPORT ScatterBase
|
||||
{
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
Loading…
Reference in New Issue
Block a user