add both generic and Thrust ScanExclusiveByKey

This commit is contained in:
Li-Ta Lo 2017-04-17 15:03:49 -06:00
parent e77f9fac6a
commit 7023266585
4 changed files with 223 additions and 40 deletions

@ -717,6 +717,62 @@ private:
}
// Exclusive scan-by-key over portals with the default configuration: a
// zero-initialized starting value, key equality to decide which neighbouring
// keys belong together, and addition to combine values. All of the work is
// delegated to the fully parameterized overload; its result is returned as is.
template<typename KeysPortal, typename ValuesPortal, typename OutputPortal>
VTKM_CONT static
typename ValuesPortal::ValueType ScanExclusiveByKeyPortal(const KeysPortal &keys,
                                                          const ValuesPortal &values,
                                                          const OutputPortal &output)
{
  using KeyType = typename KeysPortal::ValueType;
  using ValueType = typename OutputPortal::ValueType;

  // Default functors handed on to the general overload.
  ::thrust::equal_to<KeyType> keysMatch;
  ::thrust::plus<ValueType> addValues;

  return ScanExclusiveByKeyPortal(keys,
                                  values,
                                  output,
                                  vtkm::TypeTraits<ValueType>::ZeroInitialization(),
                                  keysMatch,
                                  addValues);
}
// Exclusive scan-by-key over portals with a caller-supplied starting value,
// key predicate and combining operator.
//
// The predicate and the operator are each wrapped in a WrappedBinaryOperator
// (typed on the key and output value types respectively) and passed, together
// with the portal iterators and initValue, to ::thrust::exclusive_scan_by_key
// using the thrust::cuda::par execution policy.
//
// Returns the element located just before the end iterator that Thrust hands
// back, i.e. the value written to the final output position. Any exception
// raised inside the try block is routed through throwAsVTKmException(); if
// that call returns, a default-constructed value is returned instead.
//
// NOTE(review): dereferencing (end - 1) assumes the key range is not empty;
// the ScanExclusiveByKey callers visible nearby return early for empty
// arrays -- confirm for any other caller.
template<typename KeysPortal, typename ValuesPortal, typename OutputPortal, typename T,
         typename BinaryPredicate, typename AssociativeOperator>
VTKM_CONT static
typename ValuesPortal::ValueType ScanExclusiveByKeyPortal(const KeysPortal &keys,
                                                          const ValuesPortal &values,
                                                          const OutputPortal &output,
                                                          T initValue,
                                                          BinaryPredicate binary_predicate,
                                                          AssociativeOperator binary_operator)
{
  using KeyType = typename KeysPortal::ValueType;
  using ValueType = typename OutputPortal::ValueType;
  using OutputIterator = typename detail::IteratorTraits<OutputPortal>::IteratorType;
  using WrappedPredicate =
    vtkm::exec::cuda::internal::WrappedBinaryOperator<KeyType, BinaryPredicate>;
  using WrappedOperator =
    vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, AssociativeOperator>;

  WrappedPredicate keyPredicate(binary_predicate);
  WrappedOperator valueOperator(binary_operator);

  try
  {
    OutputIterator scanEnd = ::thrust::exclusive_scan_by_key(thrust::cuda::par,
                                                             IteratorBegin(keys),
                                                             IteratorEnd(keys),
                                                             IteratorBegin(values),
                                                             IteratorBegin(output),
                                                             initValue,
                                                             keyPredicate,
                                                             valueOperator);
    // Hand back the value stored at the last output position.
    return *(scanEnd - 1);
  }
  catch(...)
  {
    throwAsVTKmException();
  }

  // Only reached when the handler above returns instead of throwing.
  return typename ValuesPortal::ValueType();
}
template<class ValuesPortal>
VTKM_CONT static void SortPortal(const ValuesPortal &values)
{
@ -1217,12 +1273,66 @@ public:
//use case breaks.
keys.PrepareForInput(DeviceAdapterTag());
values.PrepareForInput(DeviceAdapterTag());
return ScanInclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
return ScanExclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
values.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
binary_functor);
}
// Exclusive scan-by-key using the default configuration of the portal-level
// implementation (zero-initialized starting value, key equality, addition).
//
// keys   - array of keys; its length determines how many values are scanned.
// values - array of values to scan.
// output - array receiving the scanned values; it is sized to the number of
//          keys, or to zero when there are no keys.
//
// Returns a zero-initialized T when keys is empty; otherwise returns whatever
// the three-argument ScanExclusiveByKeyPortal overload returns.
//
// NOTE(review): the declared return type is the key type T, while the portal
// overload returns the value type; the result is implicitly converted.
// Confirm that this return type is intended.
template<typename T, typename U, typename KIn, typename VIn, typename VOut>
VTKM_CONT static T ScanExclusiveByKey(
    const vtkm::cont::ArrayHandle<T, KIn>& keys,
    const vtkm::cont::ArrayHandle<U, VIn>& values,
    vtkm::cont::ArrayHandle<U, VOut>& output)
{
  const vtkm::Id numberOfValues = keys.GetNumberOfValues();
  if (numberOfValues <= 0)
  {
    output.PrepareForOutput(0, DeviceAdapterTag());
    return vtkm::TypeTraits<T>::ZeroInitialization();
  }

  // We need to call PrepareForInput on the input arguments before invoking
  // the function below. The order in which function arguments are evaluated
  // is unspecified, so the inputs are prepared here first to make sure that
  // happens before the output is prepared; otherwise the in-place use case
  // breaks.
  keys.PrepareForInput(DeviceAdapterTag());
  values.PrepareForInput(DeviceAdapterTag());

  // Use the three-argument portal overload, which supplies the default
  // starting value, key predicate and operator itself. The previous call
  // misspelled the function name and passed four arguments, which matches
  // neither of the portal overloads.
  return ScanExclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
                                  values.PrepareForInput(DeviceAdapterTag()),
                                  output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
}
// Exclusive scan-by-key with a caller-supplied combining functor and starting
// value. Keys are compared with ::thrust::equal_to<T>.
//
// keys           - array of keys; its length determines how many values are
//                  scanned.
// values         - array of values to scan.
// output         - array receiving the scanned values; sized to the number of
//                  keys, or to zero when there are no keys.
// binary_functor - operator used to combine values.
// initialValue   - starting value passed on to the portal-level scan.
//
// Returns a zero-initialized T when keys is empty; otherwise returns the
// result of the portal-level ScanExclusiveByKeyPortal call.
template<typename T, typename U, typename KIn, typename VIn, typename VOut,
         typename BinaryFunctor>
VTKM_CONT static T ScanExclusiveByKey(
    const vtkm::cont::ArrayHandle<T, KIn>& keys,
    const vtkm::cont::ArrayHandle<U, VIn>& values,
    vtkm::cont::ArrayHandle<U, VOut>& output,
    BinaryFunctor binary_functor,
    const U& initialValue)
{
  const vtkm::Id keyCount = keys.GetNumberOfValues();
  if (keyCount > 0)
  {
    // Prepare both inputs up front. Function arguments are evaluated in an
    // unspecified order, so doing this first guarantees the inputs are
    // prepared before the output is, which the in-place use case depends on.
    keys.PrepareForInput(DeviceAdapterTag());
    values.PrepareForInput(DeviceAdapterTag());

    return ScanExclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
                                    values.PrepareForInput(DeviceAdapterTag()),
                                    output.PrepareForOutput(keyCount, DeviceAdapterTag()),
                                    initialValue,
                                    ::thrust::equal_to<T>(),
                                    binary_functor);
  }

  // Nothing to scan: leave an empty output behind.
  output.PrepareForOutput(0, DeviceAdapterTag());
  return vtkm::TypeTraits<T>::ZeroInitialization();
}
// Because of some funny code conversions in nvcc, kernels for devices have to
// be public.
#ifndef VTKM_CUDA

@ -551,20 +551,62 @@ public:
// Scan Exclusive By Key
// Generic exclusive scan-by-key, assembled from three steps that are visible
// below: (1) schedule ReduceStencilGeneration over the keys to fill a
// per-key state array, (2) schedule ShiftCopyAndInit to build a temporary
// array from the values, the key states and initialValue, and (3) run
// ScanInclusiveByKey on the keys and that temporary array, writing to output.
//
// NOTE(review): this span shows two return types (T and void), two
// initialValue declarations (const T& and const U&), and a direct
// ScanInclusiveByKey call ahead of the three-step body. That looks like the
// before and after lines of a change shown together -- confirm which lines
// are live before editing the code.
template<typename T, typename U, typename KIn, typename VIn, typename VOut,
class BinaryFunctor>
VTKM_CONT static T ScanExclusiveByKey(
VTKM_CONT static void ScanExclusiveByKey(
const vtkm::cont::ArrayHandle<T, KIn>& keys,
const vtkm::cont::ArrayHandle<U, VIn>& values,
vtkm::cont::ArrayHandle<U ,VOut>& output,
BinaryFunctor binaryFunctor,
const T& initialValue)
const U& initialValue)
{
// TODO: add DerivedAlgorithm?
ScanInclusiveByKey(keys, values, output, binaryFunctor);
// 0. TODO: special case for 1 element input?
vtkm::Id numberOfKeys = keys.GetNumberOfValues();
// 1. Create head flags
// We need to determine, based on the keys, what the key state is for
// each key. The states are start, middle, end of a series and the special
// state start and end of a series.
vtkm::cont::ArrayHandle< ReduceKeySeriesStates > keystate;
{
typedef typename vtkm::cont::ArrayHandle<T,KIn>::template ExecutionTypes<DeviceAdapterTag>
::PortalConst InputPortalType;
typedef typename vtkm::cont::ArrayHandle< ReduceKeySeriesStates >::template ExecutionTypes<DeviceAdapterTag>
::Portal KeyStatePortalType;
InputPortalType inputPortal = keys.PrepareForInput(DeviceAdapterTag());
KeyStatePortalType keyStatePortal = keystate.PrepareForOutput(numberOfKeys,
DeviceAdapterTag());
// One kernel invocation per key fills in keystate.
ReduceStencilGeneration<InputPortalType, KeyStatePortalType> kernel(inputPortal, keyStatePortal);
DerivedAlgorithm::Schedule(kernel, numberOfKeys);
}
// 2. Shift input and initialize elements at head flags position to initValue
// NOTE(review): the temporary array is declared with element type T (the key
// type), although it is filled from `values`, whose element type is U --
// presumably this should be U; confirm.
typedef typename vtkm::cont::ArrayHandle<T,vtkm::cont::StorageTagBasic> TempArrayType;
typedef typename vtkm::cont::ArrayHandle<T,vtkm::cont::StorageTagBasic>::template ExecutionTypes<DeviceAdapterTag>::Portal TempPortalType;
TempArrayType temp;
{
// NOTE(review): InputPortalType is derived from ArrayHandle<T,KIn> here but
// is initialized from values.PrepareForInput(), and `values` is an
// ArrayHandle<U,VIn> -- looks like it only lines up when T == U and
// KIn == VIn; confirm.
typedef typename vtkm::cont::ArrayHandle<T,KIn>::template ExecutionTypes<DeviceAdapterTag>
::PortalConst InputPortalType;
typedef typename vtkm::cont::ArrayHandle< ReduceKeySeriesStates >::template ExecutionTypes<DeviceAdapterTag>
::PortalConst KeyStatePortalType;
InputPortalType inputPortal = values.PrepareForInput(DeviceAdapterTag());
KeyStatePortalType keyStatePortal = keystate.PrepareForInput(DeviceAdapterTag());
TempPortalType tempPortal = temp.PrepareForOutput(numberOfKeys,
DeviceAdapterTag());
// One kernel invocation per key fills in temp.
ShiftCopyAndInit<U, InputPortalType, KeyStatePortalType, TempPortalType>
kernel(inputPortal, keyStatePortal, tempPortal, initialValue);
DerivedAlgorithm::Schedule(kernel, numberOfKeys);
}
// 3. Perform a ScanInclusiveByKey on the shifted copy
DerivedAlgorithm::ScanInclusiveByKey(keys, temp, output, binaryFunctor);
}
template<typename T, typename U, class KIn, typename VIn, typename VOut>
VTKM_CONT static T ScanExclusiveByKey(
VTKM_CONT static void ScanExclusiveByKey(
const vtkm::cont::ArrayHandle<T, KIn>& keys,
const vtkm::cont::ArrayHandle<U, VIn>& values,
vtkm::cont::ArrayHandle<U, VOut>& output)

@ -298,6 +298,34 @@ struct ReduceByKeyUnaryStencilOp
}
};
// Functor that fills one output element per index: when the key state at
// that index has fStart set, the stored initial value is written; otherwise
// the input element from the preceding index (index - 1) is copied.
//
// NOTE(review): reading Input.Get(index - 1) presumably relies on index 0
// always having fStart set -- confirm against the kernel that produces the
// key states.
// NOTE(review): operator() carries no VTKM_EXEC marker, unlike
// InclusiveToExclusiveByKeyKernel in the same file -- confirm whether that
// is intended.
template <typename T, typename InputPortalType,
          typename KeyStatePortalType, typename OutputPortalType>
struct ShiftCopyAndInit : vtkm::exec::FunctorBase
{
  InputPortalType Input;       // values to be shifted by one position
  KeyStatePortalType KeyState; // per-index key state, queried for fStart
  OutputPortalType Output;     // destination of the shifted values
  T initValue;                 // value written where fStart is set

  ShiftCopyAndInit(const InputPortalType& inputPortal,
                   const KeyStatePortalType &keyStatePortal,
                   OutputPortalType& outputPortal,
                   T startValue)
    : Input(inputPortal),
      KeyState(keyStatePortal),
      Output(outputPortal),
      initValue(startValue)
  {
  }

  void operator()(vtkm::Id index) const
  {
    if (!this->KeyState.Get(index).fStart)
    {
      // Not flagged as a start: take the value from the previous position.
      this->Output.Set(index, this->Input.Get(index - 1));
      return;
    }
    // Flagged as a start: begin from the initial value.
    this->Output.Set(index, this->initValue);
  }
};
template<class InputPortalType, class OutputPortalType>
struct CopyKernel
{
@ -808,23 +836,27 @@ struct InclusiveToExclusiveKernel : vtkm::exec::FunctorBase
}
};
template <typename InPortalType, typename OutPortalType, typename BinaryFunctor>
template <typename InPortalType, typename OutPortalType, typename StencilPortalType,
typename BinaryFunctor>
struct InclusiveToExclusiveByKeyKernel : vtkm::exec::FunctorBase
{
typedef typename InPortalType::ValueType ValueType;
InPortalType InPortal;
OutPortalType OutPortal;
StencilPortalType StencilPortal;
BinaryFunctor BinaryOperator;
ValueType InitialValue;
VTKM_CONT
InclusiveToExclusiveByKeyKernel(const InPortalType &inPortal,
const OutPortalType &outPortal,
BinaryFunctor &binaryOperator,
ValueType initialValue)
const OutPortalType &outPortal,
const StencilPortalType &stencilPortal,
BinaryFunctor &binaryOperator,
ValueType initialValue)
: InPortal(inPortal),
OutPortal(outPortal),
StencilPortal(stencilPortal),
BinaryOperator(binaryOperator),
InitialValue(initialValue)
{ }
@ -833,7 +865,7 @@ struct InclusiveToExclusiveByKeyKernel : vtkm::exec::FunctorBase
VTKM_EXEC
void operator()(vtkm::Id index) const
{
ValueType result = (index == 0) ? this->InitialValue :
ValueType result = (this->StencilPortal.Get(index).fState == true) ? this->InitialValue :
this->BinaryOperator(this->InitialValue, this->InPortal.Get(index - 1));
this->OutPortal.Set(index, result);
}

@ -1333,34 +1333,33 @@ private:
}
// static VTKM_CONT void TestScanExclusiveByKey()
// {
// std::cout << "-------------------------------------------" << std::endl;
// std::cout << "Testing Scan Exclusive By Key" << std::endl;
//
// const vtkm::Id inputLength = 10;
// vtkm::Id inputKeys[inputLength] = {0, 0, 0, 1, 1, 2, 3, 3, 3, 3};
// vtkm::Id inputValues[inputLength] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
// vtkm::Id init = 5;
//
// const vtkm::Id expectedLength = 10;
// vtkm::Id expectedValues[expectedLength] = {5, 6, 7, 5, 6, 5, 5, 6, 7, 8};
//
// IdArrayHandle keys = vtkm::cont::make_ArrayHandle(inputKeys, inputLength);
// IdArrayHandle values = vtkm::cont::make_ArrayHandle(inputValues, inputLength);
//
// IdArrayHandle valuesOut;
//
// Algorithm::ScanExclusiveByKey(keys, values, valuesOut, vtkm::Add(), init);
// std::cout << valuesOut.GetNumberOfValues() << std::endl;
// VTKM_TEST_ASSERT(valuesOut.GetNumberOfValues() == expectedLength,
// "Got wrong number of output values");
// for (auto i= 0; i < valuesOut.GetNumberOfValues(); i++) {
// std::cout << valuesOut.GetPortalConstControl().Get(i) << " ";
// }
// std::cout << std::endl;
//
// }
// Runs Algorithm::ScanExclusiveByKey with vtkm::Add and a starting value of 5
// over ten unit values grouped under four distinct keys, then checks both the
// length of the output and every scanned value against the expected table.
static VTKM_CONT void TestScanExclusiveByKey()
{
  std::cout << "-------------------------------------------" << std::endl;
  std::cout << "Testing Scan Exclusive By Key" << std::endl;

  // Inputs: key runs of length 3, 2, 1 and 4, each value equal to 1.
  const vtkm::Id inputLength = 10;
  vtkm::Id keyData[inputLength] = {0, 0, 0, 1, 1, 2, 3, 3, 3, 3};
  vtkm::Id valueData[inputLength] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
  vtkm::Id init = 5;

  // Every run is expected to start at init and grow by one per element.
  const vtkm::Id expectedLength = 10;
  vtkm::Id expectedValues[expectedLength] = {5, 6, 7, 5, 6, 5, 5, 6, 7, 8};

  IdArrayHandle keys = vtkm::cont::make_ArrayHandle(keyData, inputLength);
  IdArrayHandle values = vtkm::cont::make_ArrayHandle(valueData, inputLength);
  IdArrayHandle valuesOut;

  Algorithm::ScanExclusiveByKey(keys, values, valuesOut, vtkm::Add(), init);

  VTKM_TEST_ASSERT(valuesOut.GetNumberOfValues() == expectedLength,
                   "Got wrong number of output values");

  auto scanned = valuesOut.GetPortalConstControl();
  for (vtkm::Id index = 0; index < expectedLength; ++index)
  {
    const vtkm::Id actual = scanned.Get(index);
    VTKM_TEST_ASSERT(expectedValues[index] == actual, "Incorrect scanned value");
  }
}
static VTKM_CONT void TestScanInclusive()
{
@ -1966,7 +1965,7 @@ private:
TestScanInclusiveByKey();
//TestScanExclusiveByKey();
TestScanExclusiveByKey();
TestSort();
TestSortWithComparisonObject();