When iteratively computing output map in count scatter, compute visit indices

There are two ways to create the output-to-input map in a count scatter.
The first is to use a parallel find for every output index. The second,
which is used when there are many outputs, is to iterate over the input and
write out the reverse map. In the second case, it is trivial to also write
out the visit indices, so do that instead of performing additional searches.
This commit is contained in:
Kenneth Moreland 2015-10-30 09:33:10 -06:00
parent 45abbb5c75
commit 8ab2938b8c

@ -43,20 +43,26 @@ struct ReverseInputToOutputMapKernel : vtkm::exec::FunctorBase
typedef typename
vtkm::cont::ArrayHandle<vtkm::Id>::ExecutionTypes<Device>::PortalConst
InputMapType;
typedef typename
vtkm::cont::ArrayHandle<vtkm::Id>::ExecutionTypes<Device>::Portal
OutputMapType;
typedef typename
vtkm::cont::ArrayHandle<vtkm::Id>::ExecutionTypes<Device>::Portal
OutputMapType;
typedef typename
vtkm::cont::ArrayHandle<vtkm::IdComponent>::ExecutionTypes<Device>::Portal
VisitType;
InputMapType InputToOutputMap;
OutputMapType OutputToInputMap;
VisitType Visit;
vtkm::Id OutputSize;
// Constructs the kernel by copying its arguments into the matching members.
//
//   inputToOutputMap - const portal of vtkm::Id values, stored in
//                      InputToOutputMap (only read by the kernel).
//   outputToInputMap - writable portal of vtkm::Id values, stored in
//                      OutputToInputMap; the kernel writes the input index
//                      for each output index here.
//   visit            - writable portal of vtkm::IdComponent values, stored
//                      in Visit; the kernel writes the visit index for each
//                      output index here.
//   outputSize       - stored in OutputSize.
//                      NOTE(review): presumably the total number of output
//                      entries -- confirm against the caller.
VTKM_CONT_EXPORT
ReverseInputToOutputMapKernel(const InputMapType &inputToOutputMap,
const OutputMapType &outputToInputMap,
const VisitType &visit,
vtkm::Id outputSize)
: InputToOutputMap(inputToOutputMap),
OutputToInputMap(outputToInputMap),
Visit(visit),
OutputSize(outputSize)
{ }
@ -74,11 +80,14 @@ struct ReverseInputToOutputMapKernel : vtkm::exec::FunctorBase
}
vtkm::Id outputEndIndex = this->InputToOutputMap.Get(inputIndex);
vtkm::IdComponent visitIndex = 0;
for (vtkm::Id outputIndex = outputStartIndex;
outputIndex < outputEndIndex;
outputIndex++)
{
this->OutputToInputMap.Set(outputIndex, inputIndex);
this->Visit.Set(outputIndex, visitIndex);
visitIndex++;
}
}
};
@ -219,11 +228,6 @@ private:
this->BuildOutputToInputMapWithIterate(
outputSize, inputToOutputMap, Device());
}
// This builds the visit indices using a parallel find. A prefix sum by
// key could be more efficient, but that is not implemented in the device
// adapter at the time of this writing.
this->BuildVisitArrayWithFind(Device());
}
template<typename Device>
@ -236,6 +240,20 @@ private:
vtkm::cont::ArrayHandleIndex outputIndices(outputSize);
vtkm::cont::DeviceAdapterAlgorithm<Device>::UpperBounds(
inputToOutputMap, outputIndices, this->OutputToInputMap);
// Do not need this anymore.
inputToOutputMap.ReleaseResources();
vtkm::cont::ArrayHandle<vtkm::Id> startsOfGroups;
// This find gives the index of the start of a group.
vtkm::cont::DeviceAdapterAlgorithm<Device>::LowerBounds(
this->OutputToInputMap, this->OutputToInputMap, startsOfGroups);
detail::SubtractToVisitIndexKernel<Device>
kernel(startsOfGroups.PrepareForInput(Device()),
this->VisitArray.PrepareForOutput(outputSize, Device()));
vtkm::cont::DeviceAdapterAlgorithm<Device>::Schedule(kernel, outputSize);
}
template<typename Device>
@ -248,29 +266,12 @@ private:
detail::ReverseInputToOutputMapKernel<Device>
kernel(inputToOutputMap.PrepareForInput(Device()),
this->OutputToInputMap.PrepareForOutput(outputSize, Device()),
this->VisitArray.PrepareForOutput(outputSize, Device()),
outputSize);
vtkm::cont::DeviceAdapterAlgorithm<Device>::Schedule(
kernel, inputToOutputMap.GetNumberOfValues());
}
// Fills this->VisitArray from this->OutputToInputMap using a parallel search
// on the given device.
//
// The only argument is a device adapter tag; it is used purely to select the
// DeviceAdapterAlgorithm specialization and the execution environment for
// the array portals.
template<typename Device>
VTKM_CONT_EXPORT
void BuildVisitArrayWithFind(Device)
{
  typedef vtkm::cont::DeviceAdapterAlgorithm<Device> Algorithm;

  // Search the output-to-input map for each of its own entries. The lower
  // bound of an entry is the index at which that entry's group starts.
  vtkm::cont::ArrayHandle<vtkm::Id> groupStarts;
  Algorithm::LowerBounds(
        this->OutputToInputMap, this->OutputToInputMap, groupStarts);

  vtkm::Id numberOfOutputs = this->OutputToInputMap.GetNumberOfValues();

  // One kernel invocation per output entry.
  // NOTE(review): SubtractToVisitIndexKernel is defined elsewhere; presumably
  // it subtracts the group start from the output index to produce the visit
  // index -- confirm against its definition.
  detail::SubtractToVisitIndexKernel<Device>
      visitKernel(groupStarts.PrepareForInput(Device()),
                  this->VisitArray.PrepareForOutput(numberOfOutputs, Device()));
  Algorithm::Schedule(visitKernel, numberOfOutputs);
}
};
}