Do not overthink ComputeNumberOfBlocksPerAxis

This commit is contained in:
Gunther H. Weber 2022-09-23 18:00:03 -07:00
parent e2968ca244
commit 49dd19408f

@ -154,48 +154,32 @@ private:
std::vector<std::string> mCLOptions;
};
inline vtkm::IdComponent FindSplitAxis(vtkm::Id3 globalSize)
{
vtkm::IdComponent splitAxis = 0;
for (vtkm::IdComponent d = 1; d < 3; ++d)
{
if (globalSize[d] > globalSize[splitAxis])
{
splitAxis = d;
}
}
return splitAxis;
}
inline vtkm::Id3 ComputeNumberOfBlocksPerAxis(vtkm::Id3 globalSize, vtkm::Id numberOfBlocks)
{
// Split numberOfBlocks into a power of two and a remainder
vtkm::Id powerOfTwoPortion = 1;
while (numberOfBlocks % 2 == 0)
{
powerOfTwoPortion *= 2;
numberOfBlocks /= 2;
}
vtkm::Id currNumberOfBlocks = numberOfBlocks;
vtkm::Id3 blocksPerAxis{ 1, 1, 1 };
if (numberOfBlocks > 1)
while (currNumberOfBlocks > 1)
{
// Split the longest axis according to remainder
vtkm::IdComponent splitAxis = FindSplitAxis(globalSize);
blocksPerAxis[splitAxis] = numberOfBlocks;
globalSize[splitAxis] /= numberOfBlocks;
vtkm::IdComponent splitAxis = 0;
for (vtkm::IdComponent d = 1; d < 3; ++d)
{
if (globalSize[d] > globalSize[splitAxis])
{
splitAxis = d;
}
}
if (currNumberOfBlocks % 2 == 0)
{
blocksPerAxis[splitAxis] *= 2;
globalSize[splitAxis] /= 2;
currNumberOfBlocks /= 2;
}
else
{
blocksPerAxis[splitAxis] *= currNumberOfBlocks;
break;
}
}
// Now perform splits for the power of two remainder of numberOfBlocks
while (powerOfTwoPortion > 1)
{
vtkm::IdComponent splitAxis = FindSplitAxis(globalSize);
VTKM_ASSERT(globalSize[splitAxis] > 1);
blocksPerAxis[splitAxis] *= 2;
globalSize[splitAxis] /= 2;
powerOfTwoPortion /= 2;
}
return blocksPerAxis;
}
@ -229,8 +213,8 @@ inline vtkm::cont::DataSet CreateSubDataSet(const vtkm::cont::DataSet& ds,
const std::string& fieldName)
{
vtkm::Id3 globalSize;
ds.GetCellSet().CastAndCall(vtkm::worklet::contourtree_augmented::GetPointDimensions(),
globalSize);
ds.GetCellSet().CastAndCallForTypes<VTKM_DEFAULT_CELL_SET_LIST_STRUCTURED>(
vtkm::worklet::contourtree_augmented::GetPointDimensions(), globalSize);
const vtkm::Id nOutValues = blockSize[0] * blockSize[1] * blockSize[2];
const auto inDataArrayHandle = ds.GetPointField(fieldName).GetData();