add CUDA implementation of ScanInclusiveByKey using Thrust library

This commit is contained in:
Li-Ta Lo 2017-04-14 11:25:25 -06:00
parent da8a2315ce
commit e77f9fac6a
3 changed files with 138 additions and 35 deletions

@ -664,6 +664,59 @@ private:
}
template<typename KeysPortal, typename ValuesPortal, typename OutputPortal>
VTKM_CONT static
typename ValuesPortal::ValueType ScanInclusiveByKeyPortal(const KeysPortal &keys,
const ValuesPortal &values,
const OutputPortal &output)
{
using KeyType = typename KeysPortal::ValueType;
typedef typename OutputPortal::ValueType ValueType;
return ScanInclusiveByKeyPortal(keys, values, output,
::thrust::equal_to<KeyType>(),
::thrust::plus<ValueType>());
}
template<typename KeysPortal, typename ValuesPortal, typename OutputPortal,
typename BinaryPredicate, typename AssociativeOperator>
VTKM_CONT static
typename ValuesPortal::ValueType ScanInclusiveByKeyPortal(const KeysPortal &keys,
const ValuesPortal &values,
const OutputPortal &output,
BinaryPredicate binary_predicate,
AssociativeOperator binary_operator)
{
typedef typename KeysPortal::ValueType KeyType;
vtkm::exec::cuda::internal::WrappedBinaryOperator<KeyType,
BinaryPredicate> bpred(binary_predicate);
typedef typename OutputPortal::ValueType ValueType;
vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType,
AssociativeOperator> bop(binary_operator);
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType
IteratorType;
try
{
IteratorType end = ::thrust::inclusive_scan_by_key(thrust::cuda::par,
IteratorBegin(keys),
IteratorEnd(keys),
IteratorBegin(values),
IteratorBegin(output),
bpred,
bop);
return *(end-1);
}
catch(...)
{
throwAsVTKmException();
return typename ValuesPortal::ValueType();
}
//return the value at the last index in the array, as that is the sum
}
template<class ValuesPortal>
VTKM_CONT static void SortPortal(const ValuesPortal &values)
{
@ -1119,6 +1172,57 @@ public:
binary_functor);
}
template<typename T, typename U, typename KIn, typename VIn, typename VOut>
VTKM_CONT static T ScanInclusiveByKey(
const vtkm::cont::ArrayHandle<T, KIn>& keys,
const vtkm::cont::ArrayHandle<U, VIn>& values,
vtkm::cont::ArrayHandle<U, VOut>& output)
{
const vtkm::Id numberOfValues = keys.GetNumberOfValues();
if (numberOfValues <= 0)
{
output.PrepareForOutput(0, DeviceAdapterTag());
return vtkm::TypeTraits<T>::ZeroInitialization();
}
//We need call PrepareForInput on the input argument before invoking a
//function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place
//use case breaks.
keys.PrepareForInput(DeviceAdapterTag());
values.PrepareForInput(DeviceAdapterTag());
return ScanInclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
values.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
}
template<typename T, typename U, typename KIn, typename VIn, typename VOut,
typename BinaryFunctor>
VTKM_CONT static T ScanInclusiveByKey(
const vtkm::cont::ArrayHandle<T, KIn>& keys,
const vtkm::cont::ArrayHandle<U, VIn>& values,
vtkm::cont::ArrayHandle<U, VOut>& output,
BinaryFunctor binary_functor)
{
const vtkm::Id numberOfValues = keys.GetNumberOfValues();
if (numberOfValues <= 0)
{
output.PrepareForOutput(0, DeviceAdapterTag());
return vtkm::TypeTraits<T>::ZeroInitialization();
}
//We need call PrepareForInput on the input argument before invoking a
//function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place
//use case breaks.
keys.PrepareForInput(DeviceAdapterTag());
values.PrepareForInput(DeviceAdapterTag());
return ScanInclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
values.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
binary_functor);
}
// Because of some funny code conversions in nvcc, kernels for devices have to
// be public.
#ifndef VTKM_CUDA

@ -541,7 +541,8 @@ public:
const vtkm::cont::ArrayHandle<T,CIn>& input,
vtkm::cont::ArrayHandle<T,COut>& output)
{
return ScanExclusive(input, output, vtkm::Sum(),
// TODO: add DerivedAlgorithm?
return DerivedAlgorithm::ScanExclusive(input, output, vtkm::Sum(),
vtkm::TypeTraits<T>::ZeroInitialization());
}
@ -754,7 +755,6 @@ public:
scanOutput,
ReduceByKeyAdd<BinaryFunctor>(
binary_functor));
std::cout << scanOutput.GetNumberOfValues() << std::endl;
//at this point we are done with keystate, so free the memory
keystate.ReleaseResources();
DerivedAlgorithm::Copy(reducedValues, values_output);

@ -1324,44 +1324,43 @@ private:
IdArrayHandle valuesOut;
Algorithm::ScanInclusiveByKey(keys, values, valuesOut);
std::cout << valuesOut.GetNumberOfValues() << std::endl;
VTKM_TEST_ASSERT(valuesOut.GetNumberOfValues() == expectedLength,
"Got wrong number of output values");
for (auto i= 0; i < valuesOut.GetNumberOfValues(); i++) {
std::cout << valuesOut.GetPortalConstControl().Get(i) << " ";
for (auto i= 0; i < expectedLength; i++) {
const vtkm::Id v = valuesOut.GetPortalConstControl().Get(i);
VTKM_TEST_ASSERT(expectedValues[i] == v, "Incorrect scanned value");
}
std::cout << std::endl;
}
static VTKM_CONT void TestScanExclusiveByKey()
{
std::cout << "-------------------------------------------" << std::endl;
std::cout << "Testing Scan Inclusive By Key" << std::endl;
const vtkm::Id inputLength = 10;
vtkm::Id inputKeys[inputLength] = {0, 0, 0, 1, 1, 2, 3, 3, 3, 3};
vtkm::Id inputValues[inputLength] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
vtkm::Id init = 5;
const vtkm::Id expectedLength = 10;
vtkm::Id expectedValues[expectedLength] = {5, 6, 7, 5, 6, 5, 5, 6, 7, 8};
IdArrayHandle keys = vtkm::cont::make_ArrayHandle(inputKeys, inputLength);
IdArrayHandle values = vtkm::cont::make_ArrayHandle(inputValues, inputLength);
IdArrayHandle valuesOut;
Algorithm::ScanExclusiveByKey(keys, values, valuesOut, vtkm::Add(), init);
std::cout << valuesOut.GetNumberOfValues() << std::endl;
VTKM_TEST_ASSERT(valuesOut.GetNumberOfValues() == expectedLength,
"Got wrong number of output values");
for (auto i= 0; i < valuesOut.GetNumberOfValues(); i++) {
std::cout << valuesOut.GetPortalConstControl().Get(i) << " ";
}
std::cout << std::endl;
}
// static VTKM_CONT void TestScanExclusiveByKey()
// {
// std::cout << "-------------------------------------------" << std::endl;
// std::cout << "Testing Scan Exclusive By Key" << std::endl;
//
// const vtkm::Id inputLength = 10;
// vtkm::Id inputKeys[inputLength] = {0, 0, 0, 1, 1, 2, 3, 3, 3, 3};
// vtkm::Id inputValues[inputLength] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
// vtkm::Id init = 5;
//
// const vtkm::Id expectedLength = 10;
// vtkm::Id expectedValues[expectedLength] = {5, 6, 7, 5, 6, 5, 5, 6, 7, 8};
//
// IdArrayHandle keys = vtkm::cont::make_ArrayHandle(inputKeys, inputLength);
// IdArrayHandle values = vtkm::cont::make_ArrayHandle(inputValues, inputLength);
//
// IdArrayHandle valuesOut;
//
// Algorithm::ScanExclusiveByKey(keys, values, valuesOut, vtkm::Add(), init);
// std::cout << valuesOut.GetNumberOfValues() << std::endl;
// VTKM_TEST_ASSERT(valuesOut.GetNumberOfValues() == expectedLength,
// "Got wrong number of output values");
// for (auto i= 0; i < valuesOut.GetNumberOfValues(); i++) {
// std::cout << valuesOut.GetPortalConstControl().Get(i) << " ";
// }
// std::cout << std::endl;
//
// }
static VTKM_CONT void TestScanInclusive()
{
@ -1967,7 +1966,7 @@ private:
TestScanInclusiveByKey();
TestScanExclusiveByKey();
//TestScanExclusiveByKey();
TestSort();
TestSortWithComparisonObject();