mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-09-16 17:22:55 +00:00
Worklet{MapTopology,PointNeighbor} custom sg/mask
The change affects the method GetThreadIndices for both WorkletMapTopology and WorkletPointNeighborhood. Before an scatter or mask which was not ScatterIdentity or MaskNone was not allowed and it was enforced at compilation time. Signed-off-by: Vicente Adolfo Bolea Sanchez <vicente.bolea@kitware.com>
This commit is contained in:
parent
fd9c21c0d4
commit
0e90c22e70
@ -53,6 +53,7 @@ public:
|
||||
//this->CellShape = connectivity.GetCellShape(index);
|
||||
this->GlobalThreadIndexOffset = globalThreadIndexOffset;
|
||||
}
|
||||
|
||||
VTKM_SUPPRESS_EXEC_WARNINGS
|
||||
VTKM_EXEC
|
||||
ThreadIndicesTopologyMap(const vtkm::Id3& threadIndex3D,
|
||||
@ -60,9 +61,7 @@ public:
|
||||
const ConnectivityType& connectivity,
|
||||
vtkm::Id globalThreadIndexOffset = 0)
|
||||
{
|
||||
// We currently only support multidimensional indices on one-to-one input-
|
||||
// to-output mappings. (We don't have a use case otherwise.)
|
||||
// That is why we treat teh threadIndex as also the inputIndex and outputIndex
|
||||
// This constructor handles multidimensional indices on one-to-one input-to-output
|
||||
auto logicalIndex = detail::Deflate(threadIndex3D, LogicalIndexType());
|
||||
|
||||
this->ThreadIndex = threadIndex1D;
|
||||
@ -75,6 +74,29 @@ public:
|
||||
this->GlobalThreadIndexOffset = globalThreadIndexOffset;
|
||||
}
|
||||
|
||||
VTKM_SUPPRESS_EXEC_WARNINGS
|
||||
VTKM_EXEC
|
||||
ThreadIndicesTopologyMap(const vtkm::Id3& threadIndex3D,
|
||||
vtkm::Id threadIndex1D,
|
||||
vtkm::Id inputIndex,
|
||||
vtkm::IdComponent visitIndex,
|
||||
vtkm::Id outputIndex,
|
||||
const ConnectivityType& connectivity,
|
||||
vtkm::Id globalThreadIndexOffset = 0)
|
||||
{
|
||||
// This constructor handles multidimensional indices on many-to-many input-to-output
|
||||
auto logicalIndex = detail::Deflate(threadIndex3D, LogicalIndexType());
|
||||
|
||||
this->ThreadIndex = threadIndex1D;
|
||||
this->InputIndex = inputIndex;
|
||||
this->OutputIndex = outputIndex;
|
||||
this->VisitIndex = visitIndex;
|
||||
this->LogicalIndex = logicalIndex;
|
||||
this->IndicesIncident = connectivity.GetIndices(logicalIndex);
|
||||
//this->CellShape = connectivity.GetCellShape(index);
|
||||
this->GlobalThreadIndexOffset = globalThreadIndexOffset;
|
||||
}
|
||||
|
||||
/// \brief The logical index into the input domain.
|
||||
///
|
||||
/// This is similar to \c GetIndex3D except the Vec size matches the actual
|
||||
@ -213,6 +235,28 @@ public:
|
||||
//this->CellShape = connectivity.GetCellShape(index);
|
||||
this->GlobalThreadIndexOffset = globalThreadIndexOffset;
|
||||
}
|
||||
|
||||
ThreadIndicesTopologyMap(const vtkm::Id3& threadIndex3D,
|
||||
vtkm::Id threadIndex1D,
|
||||
vtkm::Id inputIndex,
|
||||
vtkm::IdComponent visitIndex,
|
||||
vtkm::Id outputIndex,
|
||||
const ConnectivityType& connectivity,
|
||||
vtkm::Id globalThreadIndexOffset = 0)
|
||||
{
|
||||
|
||||
const LogicalIndexType logicalIndex = detail::Deflate(threadIndex3D, LogicalIndexType());
|
||||
|
||||
this->ThreadIndex = threadIndex1D;
|
||||
this->InputIndex = inputIndex;
|
||||
this->OutputIndex = outputIndex;
|
||||
this->VisitIndex = visitIndex;
|
||||
this->LogicalIndex = logicalIndex;
|
||||
this->IndicesIncident = connectivity.GetIndices(logicalIndex);
|
||||
//this->CellShape = connectivity.GetCellShape(index);
|
||||
this->GlobalThreadIndexOffset = globalThreadIndexOffset;
|
||||
}
|
||||
|
||||
/// \brief The logical index into the input domain.
|
||||
///
|
||||
/// This is similar to \c GetIndex3D except the Vec size matches the actual
|
||||
|
@ -72,6 +72,26 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
template <vtkm::IdComponent Dimension>
|
||||
VTKM_EXEC ThreadIndicesPointNeighborhood(
|
||||
const vtkm::Id3& threadIndex3D,
|
||||
vtkm::Id threadIndex1D,
|
||||
vtkm::Id inputIndex,
|
||||
vtkm::IdComponent visitIndex,
|
||||
vtkm::Id outputIndex,
|
||||
const vtkm::exec::ConnectivityStructured<vtkm::TopologyElementTagPoint,
|
||||
vtkm::TopologyElementTagCell,
|
||||
Dimension>& connectivity,
|
||||
vtkm::Id globalThreadIndexOffset = 0)
|
||||
: State(threadIndex3D, detail::To3D(connectivity.GetPointDimensions()))
|
||||
, ThreadIndex(threadIndex1D)
|
||||
, InputIndex(inputIndex)
|
||||
, OutputIndex(outputIndex)
|
||||
, VisitIndex(visitIndex)
|
||||
, GlobalThreadIndexOffset(globalThreadIndexOffset)
|
||||
{
|
||||
}
|
||||
|
||||
template <vtkm::IdComponent Dimension>
|
||||
VTKM_EXEC ThreadIndicesPointNeighborhood(
|
||||
vtkm::Id threadIndex,
|
||||
|
@ -179,9 +179,7 @@ public:
|
||||
const ConnectivityType& connectivity,
|
||||
const vtkm::Id globalThreadIndexOffset = 0)
|
||||
{
|
||||
// We currently only support multidimensional indices on one-to-one input-
|
||||
// to-output mappings. (We don't have a use case otherwise.)
|
||||
// That is why we treat teh threadIndex as also the inputIndex and outputIndex
|
||||
// This constructor handles multidimensional indices on one-to-one input-to-output
|
||||
auto logicalIndex = detail::Deflate(threadIndex3D, LogicalIndexType());
|
||||
this->ThreadIndex = threadIndex1D;
|
||||
this->InputIndex = threadIndex1D;
|
||||
@ -193,6 +191,26 @@ public:
|
||||
this->GlobalThreadIndexOffset = globalThreadIndexOffset;
|
||||
}
|
||||
|
||||
VTKM_EXEC ThreadIndicesTopologyMap(const vtkm::Id3& threadIndex3D,
|
||||
vtkm::Id threadIndex1D,
|
||||
vtkm::Id inIndex,
|
||||
vtkm::IdComponent visitIndex,
|
||||
vtkm::Id outIndex,
|
||||
const ConnectivityType& connectivity,
|
||||
const vtkm::Id globalThreadIndexOffset = 0)
|
||||
{
|
||||
// This constructor handles multidimensional indices on many-to-many input-to-output
|
||||
auto logicalIndex = detail::Deflate(threadIndex3D, LogicalIndexType());
|
||||
this->ThreadIndex = threadIndex1D;
|
||||
this->InputIndex = inIndex;
|
||||
this->VisitIndex = visitIndex;
|
||||
this->OutputIndex = outIndex;
|
||||
this->LogicalIndex = logicalIndex;
|
||||
this->IndicesIncident = connectivity.GetIndices(logicalIndex);
|
||||
this->CellShape = connectivity.GetCellShape(threadIndex1D);
|
||||
this->GlobalThreadIndexOffset = globalThreadIndexOffset;
|
||||
}
|
||||
|
||||
/// \brief The index of the thread or work invocation.
|
||||
///
|
||||
/// This index refers to which instance of the worklet is being invoked. Every invocation of the
|
||||
|
@ -179,13 +179,31 @@ public:
|
||||
globalThreadIndexOffset);
|
||||
}
|
||||
|
||||
/// In the remaining methods and `constexpr` we determine at compilation time
|
||||
/// which method definition will be actually used for GetThreadIndices.
|
||||
///
|
||||
/// We want to avoid further function calls when we use WorkletMapTopology in which
|
||||
/// ScatterType is set as ScatterIdentity and MaskType as MaskNone.
|
||||
/// Otherwise, we call the default method defined at the bottom of this class.
|
||||
private:
|
||||
static constexpr bool IsScatterIdentity =
|
||||
std::is_same<ScatterType, vtkm::worklet::ScatterIdentity>::value;
|
||||
static constexpr bool IsMaskNone = std::is_same<MaskType, vtkm::worklet::MaskNone>::value;
|
||||
|
||||
template <bool Cond, typename ReturnType>
|
||||
using EnableFnWhen = typename std::enable_if<Cond, ReturnType>::type;
|
||||
|
||||
public:
|
||||
/// Optimized for ScatterIdentity and MaskNone
|
||||
VTKM_SUPPRESS_EXEC_WARNINGS
|
||||
template <typename OutToInArrayType,
|
||||
typename VisitArrayType,
|
||||
typename ThreadToOutArrayType,
|
||||
typename InputDomainType>
|
||||
VTKM_EXEC vtkm::exec::arg::ThreadIndicesTopologyMap<InputDomainType> GetThreadIndices(
|
||||
vtkm::Id threadIndex1D,
|
||||
typename InputDomainType,
|
||||
bool S = IsScatterIdentity,
|
||||
bool M = IsMaskNone>
|
||||
VTKM_EXEC EnableFnWhen<S && M, vtkm::exec::arg::ThreadIndicesTopologyMap<InputDomainType>>
|
||||
GetThreadIndices(vtkm::Id threadIndex1D,
|
||||
const vtkm::Id3& threadIndex3D,
|
||||
const OutToInArrayType& vtkmNotUsed(outToIn),
|
||||
const VisitArrayType& vtkmNotUsed(visit),
|
||||
@ -193,16 +211,36 @@ public:
|
||||
const InputDomainType& connectivity,
|
||||
vtkm::Id globalThreadIndexOffset = 0) const
|
||||
{
|
||||
using ScatterCheck = std::is_same<ScatterType, vtkm::worklet::ScatterIdentity>;
|
||||
VTKM_STATIC_ASSERT_MSG(ScatterCheck::value,
|
||||
"Scheduling on 3D topologies only works with default ScatterIdentity.");
|
||||
using MaskCheck = std::is_same<MaskType, vtkm::worklet::MaskNone>;
|
||||
VTKM_STATIC_ASSERT_MSG(MaskCheck::value,
|
||||
"Scheduling on 3D topologies only works with default MaskNone.");
|
||||
|
||||
return vtkm::exec::arg::ThreadIndicesTopologyMap<InputDomainType>(
|
||||
threadIndex3D, threadIndex1D, connectivity, globalThreadIndexOffset);
|
||||
}
|
||||
|
||||
/// Default version
|
||||
VTKM_SUPPRESS_EXEC_WARNINGS
|
||||
template <typename OutToInArrayType,
|
||||
typename VisitArrayType,
|
||||
typename ThreadToOutArrayType,
|
||||
typename InputDomainType,
|
||||
bool S = IsScatterIdentity,
|
||||
bool M = IsMaskNone>
|
||||
VTKM_EXEC EnableFnWhen<!(S && M), vtkm::exec::arg::ThreadIndicesTopologyMap<InputDomainType>>
|
||||
GetThreadIndices(vtkm::Id threadIndex1D,
|
||||
const vtkm::Id3& threadIndex3D,
|
||||
const OutToInArrayType& outToIn,
|
||||
const VisitArrayType& visit,
|
||||
const ThreadToOutArrayType& threadToOut,
|
||||
const InputDomainType& connectivity,
|
||||
vtkm::Id globalThreadIndexOffset = 0) const
|
||||
{
|
||||
const vtkm::Id outIndex = threadToOut.Get(threadIndex1D);
|
||||
return vtkm::exec::arg::ThreadIndicesTopologyMap<InputDomainType>(threadIndex3D,
|
||||
threadIndex1D,
|
||||
outToIn.Get(outIndex),
|
||||
visit.Get(outIndex),
|
||||
outIndex,
|
||||
connectivity,
|
||||
globalThreadIndexOffset);
|
||||
}
|
||||
};
|
||||
|
||||
/// Base class for worklets that map from Points to Cells.
|
||||
|
@ -197,12 +197,30 @@ public:
|
||||
globalThreadIndexOffset);
|
||||
}
|
||||
|
||||
|
||||
/// In the remaining methods and `constexpr` we determine at compilation time
|
||||
/// which method definition will be actually used for GetThreadIndices.
|
||||
///
|
||||
/// We want to avoid further function calls when we use WorkletMapTopology in which
|
||||
/// ScatterType is set as ScatterIdentity and MaskType as MaskNone.
|
||||
/// Otherwise, we call the default method defined at the bottom of this class.
|
||||
private:
|
||||
static constexpr bool IsScatterIdentity =
|
||||
std::is_same<ScatterType, vtkm::worklet::ScatterIdentity>::value;
|
||||
static constexpr bool IsMaskNone = std::is_same<MaskType, vtkm::worklet::MaskNone>::value;
|
||||
|
||||
public:
|
||||
template <bool Cond, typename ReturnType>
|
||||
using EnableFnWhen = typename std::enable_if<Cond, ReturnType>::type;
|
||||
|
||||
VTKM_SUPPRESS_EXEC_WARNINGS
|
||||
template <typename OutToInArrayType,
|
||||
typename VisitArrayType,
|
||||
typename ThreadToOutArrayType,
|
||||
typename InputDomainType>
|
||||
VTKM_EXEC vtkm::exec::arg::ThreadIndicesPointNeighborhood GetThreadIndices(
|
||||
typename InputDomainType,
|
||||
bool S = IsScatterIdentity,
|
||||
bool M = IsMaskNone>
|
||||
VTKM_EXEC EnableFnWhen<S && M, vtkm::exec::arg::ThreadIndicesPointNeighborhood> GetThreadIndices(
|
||||
vtkm::Id threadIndex1D,
|
||||
const vtkm::Id3& threadIndex3D,
|
||||
const OutToInArrayType& vtkmNotUsed(outToIn),
|
||||
@ -211,16 +229,35 @@ public:
|
||||
const InputDomainType& connectivity,
|
||||
vtkm::Id globalThreadIndexOffset = 0) const
|
||||
{
|
||||
using ScatterCheck = std::is_same<ScatterType, vtkm::worklet::ScatterIdentity>;
|
||||
VTKM_STATIC_ASSERT_MSG(ScatterCheck::value,
|
||||
"Scheduling on 3D topologies only works with default ScatterIdentity.");
|
||||
using MaskCheck = std::is_same<MaskType, vtkm::worklet::MaskNone>;
|
||||
VTKM_STATIC_ASSERT_MSG(MaskCheck::value,
|
||||
"Scheduling on 3D topologies only works with default MaskNone.");
|
||||
|
||||
return vtkm::exec::arg::ThreadIndicesPointNeighborhood(
|
||||
threadIndex3D, threadIndex1D, connectivity, globalThreadIndexOffset);
|
||||
}
|
||||
|
||||
VTKM_SUPPRESS_EXEC_WARNINGS
|
||||
template <typename OutToInArrayType,
|
||||
typename VisitArrayType,
|
||||
typename ThreadToOutArrayType,
|
||||
typename InputDomainType,
|
||||
bool S = IsScatterIdentity,
|
||||
bool M = IsMaskNone>
|
||||
VTKM_EXEC EnableFnWhen<!(S && M), vtkm::exec::arg::ThreadIndicesPointNeighborhood>
|
||||
GetThreadIndices(vtkm::Id threadIndex1D,
|
||||
const vtkm::Id3& threadIndex3D,
|
||||
const OutToInArrayType& outToIn,
|
||||
const VisitArrayType& visit,
|
||||
const ThreadToOutArrayType& threadToOut,
|
||||
const InputDomainType& connectivity,
|
||||
vtkm::Id globalThreadIndexOffset = 0) const
|
||||
{
|
||||
const vtkm::Id outIndex = threadToOut.Get(threadIndex1D);
|
||||
return vtkm::exec::arg::ThreadIndicesPointNeighborhood(threadIndex3D,
|
||||
threadIndex1D,
|
||||
outToIn.Get(outIndex),
|
||||
visit.Get(outIndex),
|
||||
outIndex,
|
||||
connectivity,
|
||||
globalThreadIndexOffset);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user