vtkm::cont::Invoker supports both Masks and Scatter

Fixes #420
This commit is contained in:
Robert Maynard 2019-09-27 16:40:53 -04:00
parent 711efbe3b8
commit 445e4d8186
10 changed files with 156 additions and 58 deletions

@ -22,6 +22,13 @@ namespace vtkm
namespace cont namespace cont
{ {
namespace detail
{
template <typename T>
using scatter_or_mask = std::integral_constant<bool,
vtkm::worklet::internal::is_mask<T>::value ||
vtkm::worklet::internal::is_scatter<T>::value>;
}
/// \brief Allows launching any worklet without a dispatcher. /// \brief Allows launching any worklet without a dispatcher.
/// ///
@ -53,40 +60,59 @@ struct Invoker
} }
/// Launch the worklet that is provided as the first parameter. /// Launch the worklet that is provided as the first parameter.
/// Optional second parameter is the scatter type associated with the worklet. /// Optional second parameter is either the scatter or mask type associated with the worklet.
/// Any additional parameters are the ControlSignature arguments for the worklet. /// Any additional parameters are the ControlSignature arguments for the worklet.
/// ///
template < template <typename Worklet,
typename Worklet, typename T,
typename T, typename... Args,
typename... Args, typename std::enable_if<detail::scatter_or_mask<T>::value, int>::type* = nullptr>
typename std::enable_if<std::is_base_of<worklet::internal::ScatterBase, inline void operator()(Worklet&& worklet, T&& scatterOrMask, Args&&... args) const
worklet::internal::detail::remove_cvref<T>>::value,
int>::type* = nullptr>
inline void operator()(Worklet&& worklet, T&& scatter, Args&&... args) const
{ {
using WorkletType = worklet::internal::detail::remove_cvref<Worklet>; using WorkletType = worklet::internal::remove_cvref<Worklet>;
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>; using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
DispatcherType dispatcher(worklet, scatter); DispatcherType dispatcher(worklet, scatterOrMask);
dispatcher.SetDevice(this->DeviceId); dispatcher.SetDevice(this->DeviceId);
dispatcher.Invoke(std::forward<Args>(args)...); dispatcher.Invoke(std::forward<Args>(args)...);
} }
/// Launch the worklet that is provided as the first parameter. /// Launch the worklet that is provided as the first parameter.
/// Optional second parameter is the scatter type associated with the worklet. /// Optional second parameter is either the scatter or mask type associated with the worklet.
/// Optional third parameter is either the scatter or mask type associated with the worklet.
/// Any additional parameters are the ControlSignature arguments for the worklet. /// Any additional parameters are the ControlSignature arguments for the worklet.
/// ///
template < template <
typename Worklet, typename Worklet,
typename T, typename T,
typename U,
typename... Args, typename... Args,
typename std::enable_if<!std::is_base_of<worklet::internal::ScatterBase, typename std::enable_if<detail::scatter_or_mask<T>::value && detail::scatter_or_mask<U>::value,
worklet::internal::detail::remove_cvref<T>>::value,
int>::type* = nullptr> int>::type* = nullptr>
inline void operator()(Worklet&& worklet,
T&& scatterOrMaskA,
U&& scatterOrMaskB,
Args&&... args) const
{
using WorkletType = worklet::internal::remove_cvref<Worklet>;
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
DispatcherType dispatcher(worklet, scatterOrMaskA, scatterOrMaskB);
dispatcher.SetDevice(this->DeviceId);
dispatcher.Invoke(std::forward<Args>(args)...);
}
/// Launch the worklet that is provided as the first parameter.
/// Optional second parameter is either the scatter or mask type associated with the worklet.
/// Any additional parameters are the ControlSignature arguments for the worklet.
///
template <typename Worklet,
typename T,
typename... Args,
typename std::enable_if<!detail::scatter_or_mask<T>::value, int>::type* = nullptr>
inline void operator()(Worklet&& worklet, T&& t, Args&&... args) const inline void operator()(Worklet&& worklet, T&& t, Args&&... args) const
{ {
using WorkletType = worklet::internal::detail::remove_cvref<Worklet>; using WorkletType = worklet::internal::remove_cvref<Worklet>;
using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>; using DispatcherType = typename WorkletType::template Dispatcher<WorkletType>;
DispatcherType dispatcher(worklet); DispatcherType dispatcher(worklet);

@ -11,6 +11,7 @@
#define vtk_m_worklet_MaskIndices_h #define vtk_m_worklet_MaskIndices_h
#include <vtkm/cont/Algorithm.h> #include <vtkm/cont/Algorithm.h>
#include <vtkm/worklet/internal/MaskBase.h>
namespace vtkm namespace vtkm
{ {
@ -26,7 +27,7 @@ namespace worklet
/// It is OK to give indices that are out of order, but any index must be provided at most one /// It is OK to give indices that are out of order, but any index must be provided at most one
/// time. It is an error to have the same index listed twice. /// time. It is an error to have the same index listed twice.
/// ///
class MaskIndices class MaskIndices : public internal::MaskBase
{ {
public: public:
using ThreadToOutputMapType = vtkm::cont::ArrayHandle<vtkm::Id>; using ThreadToOutputMapType = vtkm::cont::ArrayHandle<vtkm::Id>;

@ -11,6 +11,7 @@
#define vtk_m_worklet_MaskNone_h #define vtk_m_worklet_MaskNone_h
#include <vtkm/cont/ArrayHandleIndex.h> #include <vtkm/cont/ArrayHandleIndex.h>
#include <vtkm/worklet/internal/MaskBase.h>
namespace vtkm namespace vtkm
{ {
@ -23,7 +24,7 @@ namespace worklet
/// domain. This is the default mask object so that the worklet is run for every possible /// domain. This is the default mask object so that the worklet is run for every possible
/// output element. /// output element.
/// ///
struct MaskNone struct MaskNone : public internal::MaskBase
{ {
template <typename RangeType> template <typename RangeType>
VTKM_CONT RangeType GetThreadRange(RangeType outputRange) const VTKM_CONT RangeType GetThreadRange(RangeType outputRange) const

@ -10,6 +10,7 @@
#ifndef vtk_m_worklet_MaskSelect_h #ifndef vtk_m_worklet_MaskSelect_h
#define vtk_m_worklet_MaskSelect_h #define vtk_m_worklet_MaskSelect_h
#include <vtkm/worklet/internal/MaskBase.h>
#include <vtkm/worklet/vtkm_worklet_export.h> #include <vtkm/worklet/vtkm_worklet_export.h>
#include <vtkm/cont/VariantArrayHandle.h> #include <vtkm/cont/VariantArrayHandle.h>
@ -21,7 +22,7 @@ namespace worklet
/// \brief Mask using arrays to select specific elements to suppress. /// \brief Mask using arrays to select specific elements to suppress.
/// ///
/// \c MaskSelect is a worklet mask object that is used to select elementsin the output of a /// \c MaskSelect is a worklet mask object that is used to select elements in the output of a
/// worklet to suppress the invocation. That is, the worklet will only be invoked for elements in /// worklet to suppress the invocation. That is, the worklet will only be invoked for elements in
/// the output that are not masked out by the given array. /// the output that are not masked out by the given array.
/// ///
@ -29,7 +30,7 @@ namespace worklet
/// that should be masked and a 1 for any output that should be generated. It is an error to have /// that should be masked and a 1 for any output that should be generated. It is an error to have
/// any value that is not a 0 or 1. This method is slower than specifying an index array. /// any value that is not a 0 or 1. This method is slower than specifying an index array.
/// ///
class VTKM_WORKLET_EXPORT MaskSelect class VTKM_WORKLET_EXPORT MaskSelect : public internal::MaskBase
{ {
struct MaskTypes : vtkm::ListTagBase<vtkm::Int32, struct MaskTypes : vtkm::ListTagBase<vtkm::Int32,
vtkm::Int64, vtkm::Int64,

@ -30,10 +30,9 @@
#include <vtkm/cont/ArrayHandleTransform.h> #include <vtkm/cont/ArrayHandleTransform.h>
#include <vtkm/cont/ArrayRangeCompute.h> #include <vtkm/cont/ArrayRangeCompute.h>
#include <vtkm/cont/BitField.h> #include <vtkm/cont/BitField.h>
#include <vtkm/cont/Invoker.h>
#include <vtkm/cont/Logging.h> #include <vtkm/cont/Logging.h>
#include <vtkm/worklet/DispatcherMapField.h>
#include <vtkm/worklet/DispatcherMapTopology.h>
#include <vtkm/worklet/MaskIndices.h> #include <vtkm/worklet/MaskIndices.h>
#include <vtkm/worklet/WorkletMapField.h> #include <vtkm/worklet/WorkletMapField.h>
#include <vtkm/worklet/WorkletMapTopology.h> #include <vtkm/worklet/WorkletMapTopology.h>
@ -336,12 +335,6 @@ public:
{ {
using RangeType = vtkm::cont::ArrayHandle<vtkm::Range>; using RangeType = vtkm::cont::ArrayHandle<vtkm::Range>;
using MarkSourcePoints = vtkm::worklet::DispatcherMapField<WorkletMarkSourcePoints>;
using ProcessSourceCells = vtkm::worklet::DispatcherMapTopology<WorkletProcessSourceCells>;
using MarkActivePoints = vtkm::worklet::DispatcherMapTopology<WorkletMarkActivePoints>;
using MarkActiveCells = vtkm::worklet::DispatcherMapTopology<WorkletMarkActiveCells>;
using ProcessCellNormals = vtkm::worklet::DispatcherMapField<WorkletProcessCellNormals>;
const vtkm::Id numPoints = coords.GetNumberOfValues(); const vtkm::Id numPoints = coords.GetNumberOfValues();
const vtkm::Id numCells = cells.GetNumberOfCells(); const vtkm::Id numCells = cells.GetNumberOfCells();
@ -367,6 +360,7 @@ public:
vtkm::cont::Algorithm::Fill(visitedCellBits, false, numCells); vtkm::cont::Algorithm::Fill(visitedCellBits, false, numCells);
auto visitedCells = vtkm::cont::make_ArrayHandleBitField(visitedCellBits); auto visitedCells = vtkm::cont::make_ArrayHandleBitField(visitedCellBits);
vtkm::cont::Invoker invoke;
vtkm::cont::ArrayHandle<vtkm::Id> mask; // Allocated as needed vtkm::cont::ArrayHandle<vtkm::Id> mask; // Allocated as needed
// For each cell, store a reference alignment cell. // For each cell, store a reference alignment cell.
@ -381,10 +375,7 @@ public:
// 2) Locate points on a boundary, since their normal alignment direction // 2) Locate points on a boundary, since their normal alignment direction
// is known. // is known.
{ invoke(WorkletMarkSourcePoints{}, coords, ranges, activePoints);
MarkSourcePoints dispatcher;
dispatcher.Invoke(coords, ranges, activePoints);
}
// 3) For each source point, align the normals of the adjacent cells. // 3) For each source point, align the normals of the adjacent cells.
{ {
@ -392,15 +383,16 @@ public:
(void)numActive; (void)numActive;
VTKM_LOG_S(vtkm::cont::LogLevel::Perf, VTKM_LOG_S(vtkm::cont::LogLevel::Perf,
"ProcessSourceCells from " << numActive << " source points."); "ProcessSourceCells from " << numActive << " source points.");
ProcessSourceCells dispatcher{ vtkm::worklet::MaskIndices{ mask } }; invoke(WorkletProcessSourceCells{},
dispatcher.Invoke(cells, vtkm::worklet::MaskIndices{ mask },
coords, cells,
ranges, coords,
cellNormals, ranges,
activeCellBits, cellNormals,
visitedCellBits, activeCellBits,
activePoints, visitedCellBits,
visitedPoints); activePoints,
visitedPoints);
} }
for (size_t iter = 1;; ++iter) for (size_t iter = 1;; ++iter)
@ -411,8 +403,12 @@ public:
(void)numActive; (void)numActive;
VTKM_LOG_S(vtkm::cont::LogLevel::Perf, VTKM_LOG_S(vtkm::cont::LogLevel::Perf,
"MarkActivePoints from " << numActive << " active cells."); "MarkActivePoints from " << numActive << " active cells.");
MarkActivePoints dispatcher{ vtkm::worklet::MaskIndices{ mask } }; invoke(WorkletMarkActivePoints{},
dispatcher.Invoke(cells, activePointBits, visitedPointBits, activeCells); vtkm::worklet::MaskIndices{ mask },
cells,
activePointBits,
visitedPointBits,
activeCells);
} }
// 5) Mark unvisited cells adjacent to active points // 5) Mark unvisited cells adjacent to active points
@ -421,8 +417,13 @@ public:
(void)numActive; (void)numActive;
VTKM_LOG_S(vtkm::cont::LogLevel::Perf, VTKM_LOG_S(vtkm::cont::LogLevel::Perf,
"MarkActiveCells from " << numActive << " active points."); "MarkActiveCells from " << numActive << " active points.");
MarkActiveCells dispatcher{ vtkm::worklet::MaskIndices{ mask } }; invoke(WorkletMarkActiveCells{},
dispatcher.Invoke(cells, refCells, activeCellBits, visitedCellBits, activePoints); vtkm::worklet::MaskIndices{ mask },
cells,
refCells,
activeCellBits,
visitedCellBits,
activePoints);
} }
vtkm::Id numActiveCells = vtkm::cont::Algorithm::BitFieldToUnorderedSet(activeCellBits, mask); vtkm::Id numActiveCells = vtkm::cont::Algorithm::BitFieldToUnorderedSet(activeCellBits, mask);
@ -438,8 +439,11 @@ public:
// 5) Correct normals for active cells. // 5) Correct normals for active cells.
{ {
ProcessCellNormals dispatcher{ vtkm::worklet::MaskIndices{ mask } }; invoke(WorkletProcessCellNormals{},
dispatcher.Invoke(refCells, cellNormals, visitedCells); vtkm::worklet::MaskIndices{ mask },
refCells,
cellNormals,
visitedCells);
} }
} }
} }

@ -9,7 +9,9 @@
##============================================================================ ##============================================================================
set(headers set(headers
DecayHelpers.h
DispatcherBase.h DispatcherBase.h
MaskBase.h
ScatterBase.h ScatterBase.h
TriangulateTables.h TriangulateTables.h
WorkletBase.h WorkletBase.h

@ -0,0 +1,30 @@
//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//============================================================================
#ifndef vtk_m_worklet_internal_DecayHelpers_h
#define vtk_m_worklet_internal_DecayHelpers_h
#include <type_traits>
namespace vtkm
{
namespace worklet
{
namespace internal
{
template <typename T>
using remove_pointer_and_decay = typename std::remove_pointer<typename std::decay<T>::type>::type;
template <typename T>
using remove_cvref = typename std::remove_cv<typename std::remove_reference<T>::type>::type;
}
}
}
#endif

@ -29,6 +29,7 @@
#include <vtkm/internal/brigand.hpp> #include <vtkm/internal/brigand.hpp>
#include <vtkm/worklet/internal/DecayHelpers.h>
#include <vtkm/worklet/internal/WorkletBase.h> #include <vtkm/worklet/internal/WorkletBase.h>
#include <sstream> #include <sstream>
@ -168,12 +169,6 @@ struct ReportValueOnError<Value, true> : std::true_type
{ {
}; };
template <typename T>
using remove_pointer_and_decay = typename std::remove_pointer<typename std::decay<T>::type>::type;
template <typename T>
using remove_cvref = typename std::remove_cv<typename std::remove_reference<T>::type>::type;
// Is designed as a brigand fold operation. // Is designed as a brigand fold operation.
template <typename Type, typename State> template <typename Type, typename State>
struct DetermineIfHasDynamicParameter struct DetermineIfHasDynamicParameter
@ -510,8 +505,7 @@ private:
template <typename... Args> template <typename... Args>
VTKM_CONT void StartInvoke(Args&&... args) const VTKM_CONT void StartInvoke(Args&&... args) const
{ {
using ParameterInterface = using ParameterInterface = vtkm::internal::FunctionInterface<void(remove_cvref<Args>...)>;
vtkm::internal::FunctionInterface<void(detail::remove_cvref<Args>...)>;
VTKM_STATIC_ASSERT_MSG(ParameterInterface::ARITY == NUM_INVOKE_PARAMS, VTKM_STATIC_ASSERT_MSG(ParameterInterface::ARITY == NUM_INVOKE_PARAMS,
"Dispatcher Invoke called with wrong number of arguments."); "Dispatcher Invoke called with wrong number of arguments.");
@ -556,8 +550,7 @@ private:
template <typename... Args> template <typename... Args>
VTKM_CONT void StartInvokeDynamic(std::false_type, Args&&... args) const VTKM_CONT void StartInvokeDynamic(std::false_type, Args&&... args) const
{ {
using ParameterInterface = using ParameterInterface = vtkm::internal::FunctionInterface<void(remove_cvref<Args>...)>;
vtkm::internal::FunctionInterface<void(detail::remove_cvref<Args>...)>;
//Nothing requires a conversion from dynamic to static types, so //Nothing requires a conversion from dynamic to static types, so
//next we need to verify that each argument's type is correct. If not //next we need to verify that each argument's type is correct. If not
@ -578,7 +571,7 @@ private:
static_assert(isAllValid::value == expectedLen::value, static_assert(isAllValid::value == expectedLen::value,
"All arguments failed the TypeCheck pass"); "All arguments failed the TypeCheck pass");
auto fi = vtkm::internal::make_FunctionInterface<void, detail::remove_cvref<Args>...>(args...); auto fi = vtkm::internal::make_FunctionInterface<void, remove_cvref<Args>...>(args...);
auto ivc = vtkm::internal::Invocation<ParameterInterface, auto ivc = vtkm::internal::Invocation<ParameterInterface,
ControlInterface, ControlInterface,
ExecutionInterface, ExecutionInterface,

@ -0,0 +1,37 @@
//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//============================================================================
#ifndef vtk_m_worklet_internal_MaskBase_h
#define vtk_m_worklet_internal_MaskBase_h
#include <vtkm/internal/ExportMacros.h>
#include <vtkm/worklet/internal/DecayHelpers.h>
namespace vtkm
{
namespace worklet
{
namespace internal
{
/// Base class for all mask classes.
///
/// This allows VTK-m to determine when a parameter
/// is a mask type instead of a worklet parameter.
///
struct VTKM_ALWAYS_EXPORT MaskBase
{
};
template <typename T>
using is_mask = std::is_base_of<MaskBase, remove_cvref<T>>;
}
}
}
#endif

@ -11,6 +11,7 @@
#define vtk_m_worklet_internal_ScatterBase_h #define vtk_m_worklet_internal_ScatterBase_h
#include <vtkm/internal/ExportMacros.h> #include <vtkm/internal/ExportMacros.h>
#include <vtkm/worklet/internal/DecayHelpers.h>
namespace vtkm namespace vtkm
{ {
@ -18,7 +19,6 @@ namespace worklet
{ {
namespace internal namespace internal
{ {
/// Base class for all scatter classes. /// Base class for all scatter classes.
/// ///
/// This allows VTK-m to determine when a parameter /// This allows VTK-m to determine when a parameter
@ -27,6 +27,9 @@ namespace internal
struct VTKM_ALWAYS_EXPORT ScatterBase struct VTKM_ALWAYS_EXPORT ScatterBase
{ {
}; };
template <typename T>
using is_scatter = std::is_base_of<ScatterBase, remove_cvref<T>>;
} }
} }
} }