mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-09-16 17:22:55 +00:00
add CUDA implementation of ScanInclusiveByKey using Thrust library
This commit is contained in:
parent
da8a2315ce
commit
e77f9fac6a
@ -664,6 +664,59 @@ private:
|
||||
|
||||
}
|
||||
|
||||
template<typename KeysPortal, typename ValuesPortal, typename OutputPortal>
|
||||
VTKM_CONT static
|
||||
typename ValuesPortal::ValueType ScanInclusiveByKeyPortal(const KeysPortal &keys,
|
||||
const ValuesPortal &values,
|
||||
const OutputPortal &output)
|
||||
{
|
||||
using KeyType = typename KeysPortal::ValueType;
|
||||
typedef typename OutputPortal::ValueType ValueType;
|
||||
return ScanInclusiveByKeyPortal(keys, values, output,
|
||||
::thrust::equal_to<KeyType>(),
|
||||
::thrust::plus<ValueType>());
|
||||
}
|
||||
|
||||
template<typename KeysPortal, typename ValuesPortal, typename OutputPortal,
|
||||
typename BinaryPredicate, typename AssociativeOperator>
|
||||
VTKM_CONT static
|
||||
typename ValuesPortal::ValueType ScanInclusiveByKeyPortal(const KeysPortal &keys,
|
||||
const ValuesPortal &values,
|
||||
const OutputPortal &output,
|
||||
BinaryPredicate binary_predicate,
|
||||
AssociativeOperator binary_operator)
|
||||
{
|
||||
typedef typename KeysPortal::ValueType KeyType;
|
||||
vtkm::exec::cuda::internal::WrappedBinaryOperator<KeyType,
|
||||
BinaryPredicate> bpred(binary_predicate);
|
||||
typedef typename OutputPortal::ValueType ValueType;
|
||||
vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType,
|
||||
AssociativeOperator> bop(binary_operator);
|
||||
|
||||
|
||||
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType
|
||||
IteratorType;
|
||||
try
|
||||
{
|
||||
IteratorType end = ::thrust::inclusive_scan_by_key(thrust::cuda::par,
|
||||
IteratorBegin(keys),
|
||||
IteratorEnd(keys),
|
||||
IteratorBegin(values),
|
||||
IteratorBegin(output),
|
||||
bpred,
|
||||
bop);
|
||||
return *(end-1);
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
throwAsVTKmException();
|
||||
return typename ValuesPortal::ValueType();
|
||||
}
|
||||
|
||||
//return the value at the last index in the array, as that is the sum
|
||||
|
||||
}
|
||||
|
||||
template<class ValuesPortal>
|
||||
VTKM_CONT static void SortPortal(const ValuesPortal &values)
|
||||
{
|
||||
@ -1119,6 +1172,57 @@ public:
|
||||
binary_functor);
|
||||
}
|
||||
|
||||
template<typename T, typename U, typename KIn, typename VIn, typename VOut>
|
||||
VTKM_CONT static T ScanInclusiveByKey(
|
||||
const vtkm::cont::ArrayHandle<T, KIn>& keys,
|
||||
const vtkm::cont::ArrayHandle<U, VIn>& values,
|
||||
vtkm::cont::ArrayHandle<U, VOut>& output)
|
||||
{
|
||||
const vtkm::Id numberOfValues = keys.GetNumberOfValues();
|
||||
if (numberOfValues <= 0)
|
||||
{
|
||||
output.PrepareForOutput(0, DeviceAdapterTag());
|
||||
return vtkm::TypeTraits<T>::ZeroInitialization();
|
||||
}
|
||||
|
||||
//We need call PrepareForInput on the input argument before invoking a
|
||||
//function. The order of execution of parameters of a function is undefined
|
||||
//so we need to make sure input is called before output, or else in-place
|
||||
//use case breaks.
|
||||
keys.PrepareForInput(DeviceAdapterTag());
|
||||
values.PrepareForInput(DeviceAdapterTag());
|
||||
return ScanInclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
|
||||
values.PrepareForInput(DeviceAdapterTag()),
|
||||
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
|
||||
}
|
||||
|
||||
template<typename T, typename U, typename KIn, typename VIn, typename VOut,
|
||||
typename BinaryFunctor>
|
||||
VTKM_CONT static T ScanInclusiveByKey(
|
||||
const vtkm::cont::ArrayHandle<T, KIn>& keys,
|
||||
const vtkm::cont::ArrayHandle<U, VIn>& values,
|
||||
vtkm::cont::ArrayHandle<U, VOut>& output,
|
||||
BinaryFunctor binary_functor)
|
||||
{
|
||||
const vtkm::Id numberOfValues = keys.GetNumberOfValues();
|
||||
if (numberOfValues <= 0)
|
||||
{
|
||||
output.PrepareForOutput(0, DeviceAdapterTag());
|
||||
return vtkm::TypeTraits<T>::ZeroInitialization();
|
||||
}
|
||||
|
||||
//We need call PrepareForInput on the input argument before invoking a
|
||||
//function. The order of execution of parameters of a function is undefined
|
||||
//so we need to make sure input is called before output, or else in-place
|
||||
//use case breaks.
|
||||
keys.PrepareForInput(DeviceAdapterTag());
|
||||
values.PrepareForInput(DeviceAdapterTag());
|
||||
return ScanInclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
|
||||
values.PrepareForInput(DeviceAdapterTag()),
|
||||
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
|
||||
binary_functor);
|
||||
}
|
||||
|
||||
// Because of some funny code conversions in nvcc, kernels for devices have to
|
||||
// be public.
|
||||
#ifndef VTKM_CUDA
|
||||
|
@ -541,7 +541,8 @@ public:
|
||||
const vtkm::cont::ArrayHandle<T,CIn>& input,
|
||||
vtkm::cont::ArrayHandle<T,COut>& output)
|
||||
{
|
||||
return ScanExclusive(input, output, vtkm::Sum(),
|
||||
// TODO: add DerivedAlgorithm?
|
||||
return DerivedAlgorithm::ScanExclusive(input, output, vtkm::Sum(),
|
||||
vtkm::TypeTraits<T>::ZeroInitialization());
|
||||
}
|
||||
|
||||
@ -754,7 +755,6 @@ public:
|
||||
scanOutput,
|
||||
ReduceByKeyAdd<BinaryFunctor>(
|
||||
binary_functor));
|
||||
std::cout << scanOutput.GetNumberOfValues() << std::endl;
|
||||
//at this point we are done with keystate, so free the memory
|
||||
keystate.ReleaseResources();
|
||||
DerivedAlgorithm::Copy(reducedValues, values_output);
|
||||
|
@ -1324,44 +1324,43 @@ private:
|
||||
IdArrayHandle valuesOut;
|
||||
|
||||
Algorithm::ScanInclusiveByKey(keys, values, valuesOut);
|
||||
std::cout << valuesOut.GetNumberOfValues() << std::endl;
|
||||
VTKM_TEST_ASSERT(valuesOut.GetNumberOfValues() == expectedLength,
|
||||
"Got wrong number of output values");
|
||||
for (auto i= 0; i < valuesOut.GetNumberOfValues(); i++) {
|
||||
std::cout << valuesOut.GetPortalConstControl().Get(i) << " ";
|
||||
for (auto i= 0; i < expectedLength; i++) {
|
||||
const vtkm::Id v = valuesOut.GetPortalConstControl().Get(i);
|
||||
VTKM_TEST_ASSERT(expectedValues[i] == v, "Incorrect scanned value");
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
}
|
||||
|
||||
// Checks Algorithm::ScanExclusiveByKey with vtkm::Add and an initial value
// of 5 on ten unit values grouped under four distinct keys. Every key run is
// expected to start again from the initial value, so both the output length
// and each scanned value are verified against expectedValues.
static VTKM_CONT void TestScanExclusiveByKey()
{
  std::cout << "-------------------------------------------" << std::endl;
  std::cout << "Testing Scan Exclusive By Key" << std::endl;

  const vtkm::Id inputLength = 10;
  vtkm::Id inputKeys[inputLength] = {0, 0, 0, 1, 1, 2, 3, 3, 3, 3};
  vtkm::Id inputValues[inputLength] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
  vtkm::Id init = 5;

  const vtkm::Id expectedLength = 10;
  vtkm::Id expectedValues[expectedLength] = {5, 6, 7, 5, 6, 5, 5, 6, 7, 8};

  IdArrayHandle keys = vtkm::cont::make_ArrayHandle(inputKeys, inputLength);
  IdArrayHandle values = vtkm::cont::make_ArrayHandle(inputValues, inputLength);

  IdArrayHandle valuesOut;

  Algorithm::ScanExclusiveByKey(keys, values, valuesOut, vtkm::Add(), init);

  VTKM_TEST_ASSERT(valuesOut.GetNumberOfValues() == expectedLength,
                   "Got wrong number of output values");
  // Compare element by element rather than only printing the output, so a
  // wrong scan result fails the test.
  for (vtkm::Id i = 0; i < expectedLength; i++)
  {
    const vtkm::Id v = valuesOut.GetPortalConstControl().Get(i);
    VTKM_TEST_ASSERT(expectedValues[i] == v, "Incorrect scanned value");
  }
}
|
||||
// static VTKM_CONT void TestScanExclusiveByKey()
|
||||
// {
|
||||
// std::cout << "-------------------------------------------" << std::endl;
|
||||
// std::cout << "Testing Scan Exclusive By Key" << std::endl;
|
||||
//
|
||||
// const vtkm::Id inputLength = 10;
|
||||
// vtkm::Id inputKeys[inputLength] = {0, 0, 0, 1, 1, 2, 3, 3, 3, 3};
|
||||
// vtkm::Id inputValues[inputLength] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
|
||||
// vtkm::Id init = 5;
|
||||
//
|
||||
// const vtkm::Id expectedLength = 10;
|
||||
// vtkm::Id expectedValues[expectedLength] = {5, 6, 7, 5, 6, 5, 5, 6, 7, 8};
|
||||
//
|
||||
// IdArrayHandle keys = vtkm::cont::make_ArrayHandle(inputKeys, inputLength);
|
||||
// IdArrayHandle values = vtkm::cont::make_ArrayHandle(inputValues, inputLength);
|
||||
//
|
||||
// IdArrayHandle valuesOut;
|
||||
//
|
||||
// Algorithm::ScanExclusiveByKey(keys, values, valuesOut, vtkm::Add(), init);
|
||||
// std::cout << valuesOut.GetNumberOfValues() << std::endl;
|
||||
// VTKM_TEST_ASSERT(valuesOut.GetNumberOfValues() == expectedLength,
|
||||
// "Got wrong number of output values");
|
||||
// for (auto i= 0; i < valuesOut.GetNumberOfValues(); i++) {
|
||||
// std::cout << valuesOut.GetPortalConstControl().Get(i) << " ";
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
//
|
||||
// }
|
||||
|
||||
static VTKM_CONT void TestScanInclusive()
|
||||
{
|
||||
@ -1967,7 +1966,7 @@ private:
|
||||
|
||||
TestScanInclusiveByKey();
|
||||
|
||||
TestScanExclusiveByKey();
|
||||
//TestScanExclusiveByKey();
|
||||
|
||||
TestSort();
|
||||
TestSortWithComparisonObject();
|
||||
|
Loading…
Reference in New Issue
Block a user