mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-09-16 17:22:55 +00:00
add both generic and Thrust ScanExclusiveByKey
This commit is contained in:
parent
e77f9fac6a
commit
7023266585
@ -717,6 +717,62 @@ private:
|
||||
|
||||
}
|
||||
|
||||
template<typename KeysPortal, typename ValuesPortal, typename OutputPortal>
|
||||
VTKM_CONT static
|
||||
typename ValuesPortal::ValueType ScanExclusiveByKeyPortal(const KeysPortal &keys,
|
||||
const ValuesPortal &values,
|
||||
const OutputPortal &output)
|
||||
{
|
||||
using KeyType = typename KeysPortal::ValueType;
|
||||
typedef typename OutputPortal::ValueType ValueType;
|
||||
return ScanExclusiveByKeyPortal(keys, values, output,
|
||||
vtkm::TypeTraits<ValueType>::ZeroInitialization(),
|
||||
::thrust::equal_to<KeyType>(),
|
||||
::thrust::plus<ValueType>());
|
||||
}
|
||||
|
||||
template<typename KeysPortal, typename ValuesPortal, typename OutputPortal, typename T,
|
||||
typename BinaryPredicate, typename AssociativeOperator>
|
||||
VTKM_CONT static
|
||||
typename ValuesPortal::ValueType ScanExclusiveByKeyPortal(const KeysPortal &keys,
|
||||
const ValuesPortal &values,
|
||||
const OutputPortal &output,
|
||||
T initValue,
|
||||
BinaryPredicate binary_predicate,
|
||||
AssociativeOperator binary_operator)
|
||||
{
|
||||
typedef typename KeysPortal::ValueType KeyType;
|
||||
vtkm::exec::cuda::internal::WrappedBinaryOperator<KeyType,
|
||||
BinaryPredicate> bpred(binary_predicate);
|
||||
typedef typename OutputPortal::ValueType ValueType;
|
||||
vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType,
|
||||
AssociativeOperator> bop(binary_operator);
|
||||
|
||||
|
||||
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType
|
||||
IteratorType;
|
||||
try
|
||||
{
|
||||
IteratorType end = ::thrust::exclusive_scan_by_key(thrust::cuda::par,
|
||||
IteratorBegin(keys),
|
||||
IteratorEnd(keys),
|
||||
IteratorBegin(values),
|
||||
IteratorBegin(output),
|
||||
initValue,
|
||||
bpred,
|
||||
bop);
|
||||
return *(end-1);
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
throwAsVTKmException();
|
||||
return typename ValuesPortal::ValueType();
|
||||
}
|
||||
|
||||
//return the value at the last index in the array, as that is the sum
|
||||
|
||||
}
|
||||
|
||||
template<class ValuesPortal>
|
||||
VTKM_CONT static void SortPortal(const ValuesPortal &values)
|
||||
{
|
||||
@ -1217,12 +1273,66 @@ public:
|
||||
//use case breaks.
|
||||
keys.PrepareForInput(DeviceAdapterTag());
|
||||
values.PrepareForInput(DeviceAdapterTag());
|
||||
return ScanInclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
|
||||
return ScanExclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
|
||||
values.PrepareForInput(DeviceAdapterTag()),
|
||||
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
|
||||
binary_functor);
|
||||
}
|
||||
|
||||
template<typename T, typename U, typename KIn, typename VIn, typename VOut>
|
||||
VTKM_CONT static T ScanExclusiveByKey(
|
||||
const vtkm::cont::ArrayHandle<T, KIn>& keys,
|
||||
const vtkm::cont::ArrayHandle<U, VIn>& values,
|
||||
vtkm::cont::ArrayHandle<U, VOut>& output)
|
||||
{
|
||||
const vtkm::Id numberOfValues = keys.GetNumberOfValues();
|
||||
if (numberOfValues <= 0)
|
||||
{
|
||||
output.PrepareForOutput(0, DeviceAdapterTag());
|
||||
return vtkm::TypeTraits<T>::ZeroInitialization();
|
||||
}
|
||||
|
||||
//We need call PrepareForInput on the input argument before invoking a
|
||||
//function. The order of execution of parameters of a function is undefined
|
||||
//so we need to make sure input is called before output, or else in-place
|
||||
//use case breaks.
|
||||
keys.PrepareForInput(DeviceAdapterTag());
|
||||
values.PrepareForInput(DeviceAdapterTag());
|
||||
return ScanExnclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
|
||||
values.PrepareForInput(DeviceAdapterTag()),
|
||||
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
|
||||
vtkm::TypeTraits<T>::ZeroInitialization());
|
||||
}
|
||||
|
||||
template<typename T, typename U, typename KIn, typename VIn, typename VOut,
|
||||
typename BinaryFunctor>
|
||||
VTKM_CONT static T ScanExclusiveByKey(
|
||||
const vtkm::cont::ArrayHandle<T, KIn>& keys,
|
||||
const vtkm::cont::ArrayHandle<U, VIn>& values,
|
||||
vtkm::cont::ArrayHandle<U, VOut>& output,
|
||||
BinaryFunctor binary_functor,
|
||||
const U& initialValue)
|
||||
{
|
||||
const vtkm::Id numberOfValues = keys.GetNumberOfValues();
|
||||
if (numberOfValues <= 0)
|
||||
{
|
||||
output.PrepareForOutput(0, DeviceAdapterTag());
|
||||
return vtkm::TypeTraits<T>::ZeroInitialization();
|
||||
}
|
||||
|
||||
//We need call PrepareForInput on the input argument before invoking a
|
||||
//function. The order of execution of parameters of a function is undefined
|
||||
//so we need to make sure input is called before output, or else in-place
|
||||
//use case breaks.
|
||||
keys.PrepareForInput(DeviceAdapterTag());
|
||||
values.PrepareForInput(DeviceAdapterTag());
|
||||
return ScanExclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
|
||||
values.PrepareForInput(DeviceAdapterTag()),
|
||||
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
|
||||
initialValue,
|
||||
::thrust::equal_to<T>(),
|
||||
binary_functor);
|
||||
}
|
||||
// Because of some funny code conversions in nvcc, kernels for devices have to
|
||||
// be public.
|
||||
#ifndef VTKM_CUDA
|
||||
|
@ -551,20 +551,62 @@ public:
|
||||
// Scan Exclusive By Key
|
||||
template<typename T, typename U, typename KIn, typename VIn, typename VOut,
|
||||
class BinaryFunctor>
|
||||
VTKM_CONT static T ScanExclusiveByKey(
|
||||
VTKM_CONT static void ScanExclusiveByKey(
|
||||
const vtkm::cont::ArrayHandle<T, KIn>& keys,
|
||||
const vtkm::cont::ArrayHandle<U, VIn>& values,
|
||||
vtkm::cont::ArrayHandle<U ,VOut>& output,
|
||||
BinaryFunctor binaryFunctor,
|
||||
const T& initialValue)
|
||||
const U& initialValue)
|
||||
{
|
||||
// TODO: add DerivedAlgorithm?
|
||||
ScanInclusiveByKey(keys, values, output, binaryFunctor);
|
||||
// 0. TODO: special case for 1 element input?
|
||||
vtkm::Id numberOfKeys = keys.GetNumberOfValues();
|
||||
|
||||
// 1. Create head flags
|
||||
//we need to determine based on the keys what is the keystate for
|
||||
//each key. The states are start, middle, end of a series and the special
|
||||
//state start and end of a series
|
||||
vtkm::cont::ArrayHandle< ReduceKeySeriesStates > keystate;
|
||||
|
||||
{
|
||||
typedef typename vtkm::cont::ArrayHandle<T,KIn>::template ExecutionTypes<DeviceAdapterTag>
|
||||
::PortalConst InputPortalType;
|
||||
|
||||
typedef typename vtkm::cont::ArrayHandle< ReduceKeySeriesStates >::template ExecutionTypes<DeviceAdapterTag>
|
||||
::Portal KeyStatePortalType;
|
||||
|
||||
InputPortalType inputPortal = keys.PrepareForInput(DeviceAdapterTag());
|
||||
KeyStatePortalType keyStatePortal = keystate.PrepareForOutput(numberOfKeys,
|
||||
DeviceAdapterTag());
|
||||
ReduceStencilGeneration<InputPortalType, KeyStatePortalType> kernel(inputPortal, keyStatePortal);
|
||||
DerivedAlgorithm::Schedule(kernel, numberOfKeys);
|
||||
}
|
||||
|
||||
// 2. Shift input and initialize elements at head flags position to initValue
|
||||
typedef typename vtkm::cont::ArrayHandle<T,vtkm::cont::StorageTagBasic> TempArrayType;
|
||||
typedef typename vtkm::cont::ArrayHandle<T,vtkm::cont::StorageTagBasic>::template ExecutionTypes<DeviceAdapterTag>::Portal TempPortalType;
|
||||
TempArrayType temp;
|
||||
{
|
||||
typedef typename vtkm::cont::ArrayHandle<T,KIn>::template ExecutionTypes<DeviceAdapterTag>
|
||||
::PortalConst InputPortalType;
|
||||
|
||||
typedef typename vtkm::cont::ArrayHandle< ReduceKeySeriesStates >::template ExecutionTypes<DeviceAdapterTag>
|
||||
::PortalConst KeyStatePortalType;
|
||||
|
||||
InputPortalType inputPortal = values.PrepareForInput(DeviceAdapterTag());
|
||||
KeyStatePortalType keyStatePortal = keystate.PrepareForInput(DeviceAdapterTag());
|
||||
TempPortalType tempPortal = temp.PrepareForOutput(numberOfKeys,
|
||||
DeviceAdapterTag());
|
||||
|
||||
ShiftCopyAndInit<U, InputPortalType, KeyStatePortalType, TempPortalType>
|
||||
kernel(inputPortal, keyStatePortal, tempPortal, initialValue);
|
||||
DerivedAlgorithm::Schedule(kernel, numberOfKeys);
|
||||
}
|
||||
// 3. Perform an ScanInclusiveByKey
|
||||
DerivedAlgorithm::ScanInclusiveByKey(keys, temp, output, binaryFunctor);
|
||||
}
|
||||
|
||||
template<typename T, typename U, class KIn, typename VIn, typename VOut>
|
||||
VTKM_CONT static T ScanExclusiveByKey(
|
||||
VTKM_CONT static void ScanExclusiveByKey(
|
||||
const vtkm::cont::ArrayHandle<T, KIn>& keys,
|
||||
const vtkm::cont::ArrayHandle<U, VIn>& values,
|
||||
vtkm::cont::ArrayHandle<U, VOut>& output)
|
||||
|
@ -298,6 +298,34 @@ struct ReduceByKeyUnaryStencilOp
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <typename T, typename InputPortalType,
|
||||
typename KeyStatePortalType, typename OutputPortalType>
|
||||
struct ShiftCopyAndInit : vtkm::exec::FunctorBase
|
||||
{
|
||||
InputPortalType Input;
|
||||
KeyStatePortalType KeyState;
|
||||
OutputPortalType Output;
|
||||
T initValue;
|
||||
|
||||
ShiftCopyAndInit(const InputPortalType& _input,
|
||||
const KeyStatePortalType &kstate,
|
||||
OutputPortalType& _output,
|
||||
T _init) : Input(_input),
|
||||
KeyState(kstate),
|
||||
Output(_output),
|
||||
initValue(_init) {}
|
||||
|
||||
void operator()(vtkm::Id index) const
|
||||
{
|
||||
if (this->KeyState.Get(index).fStart) {
|
||||
Output.Set(index, initValue);
|
||||
} else {
|
||||
Output.Set(index, Input.Get(index-1));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<class InputPortalType, class OutputPortalType>
|
||||
struct CopyKernel
|
||||
{
|
||||
@ -808,23 +836,27 @@ struct InclusiveToExclusiveKernel : vtkm::exec::FunctorBase
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InPortalType, typename OutPortalType, typename BinaryFunctor>
|
||||
template <typename InPortalType, typename OutPortalType, typename StencilPortalType,
|
||||
typename BinaryFunctor>
|
||||
struct InclusiveToExclusiveByKeyKernel : vtkm::exec::FunctorBase
|
||||
{
|
||||
typedef typename InPortalType::ValueType ValueType;
|
||||
|
||||
InPortalType InPortal;
|
||||
OutPortalType OutPortal;
|
||||
StencilPortalType StencilPortal;
|
||||
BinaryFunctor BinaryOperator;
|
||||
ValueType InitialValue;
|
||||
|
||||
VTKM_CONT
|
||||
InclusiveToExclusiveByKeyKernel(const InPortalType &inPortal,
|
||||
const OutPortalType &outPortal,
|
||||
BinaryFunctor &binaryOperator,
|
||||
ValueType initialValue)
|
||||
const OutPortalType &outPortal,
|
||||
const StencilPortalType &stencilPortal,
|
||||
BinaryFunctor &binaryOperator,
|
||||
ValueType initialValue)
|
||||
: InPortal(inPortal),
|
||||
OutPortal(outPortal),
|
||||
StencilPortal(stencilPortal),
|
||||
BinaryOperator(binaryOperator),
|
||||
InitialValue(initialValue)
|
||||
{ }
|
||||
@ -833,7 +865,7 @@ struct InclusiveToExclusiveByKeyKernel : vtkm::exec::FunctorBase
|
||||
VTKM_EXEC
|
||||
void operator()(vtkm::Id index) const
|
||||
{
|
||||
ValueType result = (index == 0) ? this->InitialValue :
|
||||
ValueType result = (this->StencilPortal.Get(index).fState == true) ? this->InitialValue :
|
||||
this->BinaryOperator(this->InitialValue, this->InPortal.Get(index - 1));
|
||||
this->OutPortal.Set(index, result);
|
||||
}
|
||||
|
@ -1333,34 +1333,33 @@ private:
|
||||
|
||||
}
|
||||
|
||||
// static VTKM_CONT void TestScanExclusiveByKey()
|
||||
// {
|
||||
// std::cout << "-------------------------------------------" << std::endl;
|
||||
// std::cout << "Testing Scan Exclusive By Key" << std::endl;
|
||||
//
|
||||
// const vtkm::Id inputLength = 10;
|
||||
// vtkm::Id inputKeys[inputLength] = {0, 0, 0, 1, 1, 2, 3, 3, 3, 3};
|
||||
// vtkm::Id inputValues[inputLength] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
|
||||
// vtkm::Id init = 5;
|
||||
//
|
||||
// const vtkm::Id expectedLength = 10;
|
||||
// vtkm::Id expectedValues[expectedLength] = {5, 6, 7, 5, 6, 5, 5, 6, 7, 8};
|
||||
//
|
||||
// IdArrayHandle keys = vtkm::cont::make_ArrayHandle(inputKeys, inputLength);
|
||||
// IdArrayHandle values = vtkm::cont::make_ArrayHandle(inputValues, inputLength);
|
||||
//
|
||||
// IdArrayHandle valuesOut;
|
||||
//
|
||||
// Algorithm::ScanExclusiveByKey(keys, values, valuesOut, vtkm::Add(), init);
|
||||
// std::cout << valuesOut.GetNumberOfValues() << std::endl;
|
||||
// VTKM_TEST_ASSERT(valuesOut.GetNumberOfValues() == expectedLength,
|
||||
// "Got wrong number of output values");
|
||||
// for (auto i= 0; i < valuesOut.GetNumberOfValues(); i++) {
|
||||
// std::cout << valuesOut.GetPortalConstControl().Get(i) << " ";
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
//
|
||||
// }
|
||||
static VTKM_CONT void TestScanExclusiveByKey()
|
||||
{
|
||||
std::cout << "-------------------------------------------" << std::endl;
|
||||
std::cout << "Testing Scan Exclusive By Key" << std::endl;
|
||||
|
||||
const vtkm::Id inputLength = 10;
|
||||
vtkm::Id inputKeys[inputLength] = {0, 0, 0, 1, 1, 2, 3, 3, 3, 3};
|
||||
vtkm::Id inputValues[inputLength] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
|
||||
vtkm::Id init = 5;
|
||||
|
||||
const vtkm::Id expectedLength = 10;
|
||||
vtkm::Id expectedValues[expectedLength] = {5, 6, 7, 5, 6, 5, 5, 6, 7, 8};
|
||||
|
||||
IdArrayHandle keys = vtkm::cont::make_ArrayHandle(inputKeys, inputLength);
|
||||
IdArrayHandle values = vtkm::cont::make_ArrayHandle(inputValues, inputLength);
|
||||
|
||||
IdArrayHandle valuesOut;
|
||||
|
||||
Algorithm::ScanExclusiveByKey(keys, values, valuesOut, vtkm::Add(), init);
|
||||
|
||||
VTKM_TEST_ASSERT(valuesOut.GetNumberOfValues() == expectedLength,
|
||||
"Got wrong number of output values");
|
||||
for (auto i= 0; i < expectedLength; i++) {
|
||||
const vtkm::Id v = valuesOut.GetPortalConstControl().Get(i);
|
||||
VTKM_TEST_ASSERT(expectedValues[i] == v, "Incorrect scanned value");
|
||||
}
|
||||
}
|
||||
|
||||
static VTKM_CONT void TestScanInclusive()
|
||||
{
|
||||
@ -1966,7 +1965,7 @@ private:
|
||||
|
||||
TestScanInclusiveByKey();
|
||||
|
||||
//TestScanExclusiveByKey();
|
||||
TestScanExclusiveByKey();
|
||||
|
||||
TestSort();
|
||||
TestSortWithComparisonObject();
|
||||
|
Loading…
Reference in New Issue
Block a user