Fix MergeCombinedOtherStartIndexWorklet instead of using STL; clean up

This commit is contained in:
Gunther H. Weber 2021-03-18 16:13:32 -07:00
parent b1ace58809
commit 4403947746
3 changed files with 57 additions and 92 deletions

@ -701,50 +701,12 @@ inline void ContourTreeMesh<FieldType>::MergeWith(ContourTreeMesh<FieldType>& ot
<< ": " << timer.GetElapsedTime() << " seconds" << std::endl;
timer.Start();
// TODO VTKM -Version MergedCombinedOtherStartIndex. Replace 1r block with the 1s block. Need to check for Segfault in contourtree_mesh_inc_ns::MergeCombinedOtherStartIndexWorklet. This workloat also still uses a number of stl algorithms that should be replaced with VTKm code (which is porbably also why the worklet fails).
/* // 1s--start
contourtree_mesh_inc_ns::MergeCombinedOtherStartIndexWorklet<DeviceAdapter> mergeCombinedOtherStartIndexWorklet;
vtkm::worklet::DispatcherMapField< contourtree_mesh_inc_ns::MergeCombinedOtherStartIndexWorklet<DeviceAdapter>> mergeCombinedOtherStartIndexDispatcher(mergeCombinedOtherStartIndexWorklet);
this->Invoke(mergeCombinedOtherStartIndexWorklet,
combinedOtherStartIndex, // (input/output)
combinedNeighbours, // (input/output)
combinedFirstNeighbour // (input)
);
// 1s--end
*/
// TODO Fix the MergedCombinedOtherStartIndex worklet and remove //1r block below
// 1r--start
auto combinedOtherStartIndexPortal = combinedOtherStartIndex.WritePortal();
auto combinedFirstNeighbourPortal = combinedFirstNeighbour.ReadPortal();
auto combinedNeighboursPortal = combinedNeighbours.WritePortal();
std::vector<vtkm::Id> tempCombinedNeighours((std::size_t)combinedNeighbours.GetNumberOfValues());
for (vtkm::Id vtx = 0; vtx < combinedNeighbours.GetNumberOfValues(); ++vtx)
{
tempCombinedNeighours[(std::size_t)vtx] = combinedNeighboursPortal.Get(vtx);
}
for (vtkm::Id vtx = 0; vtx < combinedFirstNeighbour.GetNumberOfValues(); ++vtx)
{
if (combinedOtherStartIndexPortal.Get(vtx)) // Needs merge
{
auto neighboursBegin = tempCombinedNeighours.begin() + combinedFirstNeighbourPortal.Get(vtx);
auto neighboursEnd = (vtx < combinedFirstNeighbour.GetNumberOfValues() - 1)
? tempCombinedNeighours.begin() + combinedFirstNeighbourPortal.Get(vtx + 1)
: tempCombinedNeighours.end();
std::inplace_merge(
neighboursBegin, neighboursBegin + combinedOtherStartIndexPortal.Get(vtx), neighboursEnd);
auto it = std::unique(neighboursBegin, neighboursEnd);
combinedOtherStartIndexPortal.Set(vtx, static_cast<vtkm::Id>(neighboursEnd - it));
while (it != neighboursEnd)
*(it++) = NO_SUCH_ELEMENT;
}
}
// copy the values back to VTKm
for (vtkm::Id vtx = 0; vtx < combinedNeighbours.GetNumberOfValues(); ++vtx)
{
combinedNeighboursPortal.Set(vtx, tempCombinedNeighours[(std::size_t)vtx]);
}
// 1r--end
contourtree_mesh_inc_ns::MergeCombinedOtherStartIndexWorklet mergeCombinedOtherStartIndexWorklet;
this->Invoke(mergeCombinedOtherStartIndexWorklet,
combinedOtherStartIndex, // (input/output)
combinedNeighbours, // (input/output)
combinedFirstNeighbour // (input)
);
timingsStream << " " << std::setw(38) << std::left << "Merge CombinedOtherStartIndex"
<< ": " << timer.GetElapsedTime() << " seconds" << std::endl;

@ -81,7 +81,6 @@ namespace contourtree_augmented
namespace mesh_dem_contourtree_mesh_inc
{
template <typename DeviceAdapter>
class MergeCombinedOtherStartIndexWorklet : public vtkm::worklet::WorkletMapField
{
public:
@ -104,23 +103,57 @@ public:
const InOutFieldPortalType combinedNeighboursPortal,
const InFieldPortalType combinedFirstNeighbourPortal) const
{
// TODO Replace this to not use stl algorithms inside the worklet
if (combinedOtherStartIndexPortal.Get(vtx)) // Needs merge
{
vtkm::cont::ArrayPortalToIterators<InOutFieldPortalType> combinedNeighboursIterators(
combinedNeighboursPortal);
auto neighboursBegin =
combinedNeighboursIterators.GetBegin() + combinedFirstNeighbourPortal.Get(vtx);
auto neighboursEnd = (vtx < combinedFirstNeighbourPortal.GetNumberOfValues() - 1)
? combinedNeighboursIterators.GetBegin() + combinedFirstNeighbourPortal.Get(vtx + 1)
: combinedNeighboursIterators.GetEnd();
std::inplace_merge(
neighboursBegin, neighboursBegin + combinedOtherStartIndexPortal.Get(vtx), neighboursEnd);
auto it = std::unique(neighboursBegin, neighboursEnd);
combinedOtherStartIndexPortal.Set(vtx, neighboursEnd - it);
while (it != neighboursEnd)
auto myNeighboursBeginIdx = combinedFirstNeighbourPortal.Get(vtx);
auto otherNeighboursBeginIdx = myNeighboursBeginIdx + combinedOtherStartIndexPortal.Get(vtx);
auto neighboursEndIdx = (vtx < combinedFirstNeighbourPortal.GetNumberOfValues() - 1)
? combinedFirstNeighbourPortal.Get(vtx + 1) - 1
: combinedNeighboursPortal.GetNumberOfValues() - 1;
// Merge two sorted neighbours lists from myNeighboursBeginIdx through
// otherNeighboursBeginIdx - 1 and from otherNeighboursBeginIdx - 1 to
// neighboursEndIdx
auto arr0_end = otherNeighboursBeginIdx - 1;
auto curr = neighboursEndIdx;
while (curr >= otherNeighboursBeginIdx)
{
*(it++) = NO_SUCH_ELEMENT;
auto x = combinedNeighboursPortal.Get(curr);
if (x < combinedNeighboursPortal.Get(arr0_end))
{
combinedNeighboursPortal.Set(curr, combinedNeighboursPortal.Get(arr0_end));
auto pos = arr0_end - 1;
while (pos >= myNeighboursBeginIdx && (combinedNeighboursPortal.Get(pos) > x))
{
combinedNeighboursPortal.Set(pos + 1, combinedNeighboursPortal.Get(pos));
--pos;
}
combinedNeighboursPortal.Set(pos + 1, x);
}
--curr;
}
// Remove duplicates
vtkm::Id prevNeighbour = combinedNeighboursPortal.Get(myNeighboursBeginIdx);
vtkm::Id currPos = myNeighboursBeginIdx + 1;
for (vtkm::Id i = myNeighboursBeginIdx + 1; i <= neighboursEndIdx; ++i)
{
auto currNeighbour = combinedNeighboursPortal.Get(i);
if (currNeighbour != prevNeighbour)
{
combinedNeighboursPortal.Set(currPos++, currNeighbour);
prevNeighbour = currNeighbour;
}
}
// Record number of elements in neighbour list for subsequent compression
combinedOtherStartIndexPortal.Set(vtx, neighboursEndIdx + 1 - currPos);
// Fill remainder with NO_SUCH_ELEMENT so that it can be easily discarded
while (currPos != neighboursEndIdx + 1)
{
combinedNeighboursPortal.Set(currPos++, (vtkm::Id)NO_SUCH_ELEMENT);
}
}
@ -136,38 +169,9 @@ public:
combinedOtherStartIndex[vtx] = neighboursEnd - it;
while (it != neighboursEnd) *(it++) = NO_SUCH_ELEMENT;
}
}*/
/* Attempt at porting the code without using STL
if (combinedOtherStartIndexPortal.Get(vtx))
{
vtkm::Id combinedNeighboursBeginIndex = combinedFirstNeighbourPortal.Get(vtx);
vtkm::Id combinedNeighboursEndIndex = (vtx < combinedFirstNeighbourPortal.GetNumberOfValues() - 1) ? combinedFirstNeighbourPortal.Get(vtx+1) : combinedNeighboursPortal.GetNumberOfValues() -1;
vtkm::Id numSelectedVals = combinedNeighboursEndIndex- combinedNeighboursBeginIndex + 1;
vtkm::cont::ArrayHandleCounting <vtkm::Id > selectSubRangeIndex (combinedNeighboursBeginIndex, 1, numSelectedVals);
vtkm::cont::ArrayHandlePermutation<vtkm::cont::ArrayHandleCounting <vtkm::Id >, IdArrayType> selectSubRangeArrayHandle(
selectSubRangeIndex, // index array to select the range of values
combinedNeighboursPortal // value array to select from. // TODO this won't work because this is an ArrayPortal not an ArrayHandle
);
vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapter>::Sort(selectSubRangeArrayHandle);
vtkm::Id numUniqueVals = 1;
for(vtkm::Id i=combinedNeighboursBeginIndex; i<=combinedNeighboursEndIndex; i++){
if (combinedNeighboursPortal.Get(i) == combinedNeighboursPortal.Get(i-1))
{
combinedNeighboursPortal.Set(i, (vtkm::Id) NO_SUCH_ELEMENT);
}
else
{
numUniqueVals += 1;
}
}
combinedOtherStartIndexPortal.Set(vtx, combinedNeighboursEndIndex - (combinedNeighboursBeginIndex + numUniqueVals + 1));
}
*/
}
*/
}
}; // MergeCombinedOtherStartIndexWorklet

@ -662,8 +662,7 @@ public:
}
else
{
// This should just be vtkm::worklet::contourtree_augmented::NO_SUCH_ELEMENT, but there's the build error - identifier "vtkm::worklet::contourtree_augmented::NO_SUCH_ELEMENT" is undefined in device code
chainToBranchPortal.Set(supernode, std::numeric_limits<vtkm::Id>::min());
chainToBranchPortal.Set(supernode, (vtkm::Id)NO_SUCH_ELEMENT);
}
}
}; // ComputeMinMaxValues