Reduce the number of typedef statements in DeviceAdapters

By using the auto keyword and decltype, we can reduce the number of
complex typedefs that exist when writing device adapter algorithms.
The goal is to make it easier for developers to see the actual
algorithms being implemented by reducing the amount of template
'noise'.
This commit is contained in:
Robert Maynard 2017-08-10 16:13:55 -04:00
parent e095831199
commit 89f439999a
4 changed files with 195 additions and 382 deletions

@ -216,18 +216,16 @@ private:
OutputPortal output, OutputPortal output,
UnaryPredicate unary_predicate) UnaryPredicate unary_predicate)
{ {
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType IteratorType; using ValueType = typename StencilPortal::ValueType;
IteratorType outputBegin = IteratorBegin(output); auto outputBegin = IteratorBegin(output);
typedef typename StencilPortal::ValueType ValueType;
vtkm::exec::cuda::internal::WrappedUnaryPredicate<ValueType, UnaryPredicate> up( vtkm::exec::cuda::internal::WrappedUnaryPredicate<ValueType, UnaryPredicate> up(
unary_predicate); unary_predicate);
try try
{ {
IteratorType newLast = ::thrust::copy_if( auto newLast = ::thrust::copy_if(
thrust::cuda::par, valuesBegin, valuesEnd, IteratorBegin(stencil), outputBegin, up); thrust::cuda::par, valuesBegin, valuesEnd, IteratorBegin(stencil), outputBegin, up);
return static_cast<vtkm::Id>(::thrust::distance(outputBegin, newLast)); return static_cast<vtkm::Id>(::thrust::distance(outputBegin, newLast));
} }
@ -273,7 +271,7 @@ private:
const ValuesPortal& values, const ValuesPortal& values,
const OutputPortal& output) const OutputPortal& output)
{ {
typedef typename ValuesPortal::ValueType ValueType; using ValueType = typename ValuesPortal::ValueType;
LowerBoundsPortal(input, values, output, ::thrust::less<ValueType>()); LowerBoundsPortal(input, values, output, ::thrust::less<ValueType>());
} }
@ -281,7 +279,7 @@ private:
VTKM_CONT static void LowerBoundsPortal(const InputPortal& input, VTKM_CONT static void LowerBoundsPortal(const InputPortal& input,
const OutputPortal& values_output) const OutputPortal& values_output)
{ {
typedef typename InputPortal::ValueType ValueType; using ValueType = typename InputPortal::ValueType;
LowerBoundsPortal(input, values_output, values_output, ::thrust::less<ValueType>()); LowerBoundsPortal(input, values_output, values_output, ::thrust::less<ValueType>());
} }
@ -291,7 +289,7 @@ private:
const OutputPortal& output, const OutputPortal& output,
BinaryCompare binary_compare) BinaryCompare binary_compare)
{ {
typedef typename InputPortal::ValueType ValueType; using ValueType = typename InputPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryPredicate<ValueType, BinaryCompare> bop( vtkm::exec::cuda::internal::WrappedBinaryPredicate<ValueType, BinaryCompare> bop(
binary_compare); binary_compare);
@ -388,17 +386,14 @@ private:
const ValueOutputPortal& values_output, const ValueOutputPortal& values_output,
BinaryFunctor binary_functor) BinaryFunctor binary_functor)
{ {
typedef typename detail::IteratorTraits<KeysOutputPortal>::IteratorType KeysIteratorType; auto keys_out_begin = IteratorBegin(keys_output);
typedef typename detail::IteratorTraits<ValueOutputPortal>::IteratorType ValuesIteratorType; auto values_out_begin = IteratorBegin(values_output);
KeysIteratorType keys_out_begin = IteratorBegin(keys_output); ::thrust::pair<decltype(keys_out_begin), decltype(values_out_begin)> result_iterators;
ValuesIteratorType values_out_begin = IteratorBegin(values_output);
::thrust::pair<KeysIteratorType, ValuesIteratorType> result_iterators;
::thrust::equal_to<typename KeysPortal::ValueType> binaryPredicate; ::thrust::equal_to<typename KeysPortal::ValueType> binaryPredicate;
typedef typename ValuesPortal::ValueType ValueType; using ValueType = typename ValuesPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, BinaryFunctor> bop(binary_functor); vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, BinaryFunctor> bop(binary_functor);
try try
@ -424,7 +419,7 @@ private:
VTKM_CONT static typename InputPortal::ValueType ScanExclusivePortal(const InputPortal& input, VTKM_CONT static typename InputPortal::ValueType ScanExclusivePortal(const InputPortal& input,
const OutputPortal& output) const OutputPortal& output)
{ {
typedef typename OutputPortal::ValueType ValueType; using ValueType = typename OutputPortal::ValueType;
return ScanExclusivePortal(input, return ScanExclusivePortal(input,
output, output,
@ -441,7 +436,7 @@ private:
{ {
// Use iterator to get value so that thrust device_ptr has chance to handle // Use iterator to get value so that thrust device_ptr has chance to handle
// data on device. // data on device.
typedef typename OutputPortal::ValueType ValueType; using ValueType = typename OutputPortal::ValueType;
//we have size three so that we can store the origin end value, the //we have size three so that we can store the origin end value, the
//new end value, and the sum of those two //new end value, and the sum of those two
@ -456,8 +451,7 @@ private:
vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, BinaryFunctor> bop(binaryOp); vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, BinaryFunctor> bop(binaryOp);
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType IteratorType; auto end = ::thrust::exclusive_scan(thrust::cuda::par,
IteratorType end = ::thrust::exclusive_scan(thrust::cuda::par,
IteratorBegin(input), IteratorBegin(input),
IteratorEnd(input), IteratorEnd(input),
IteratorBegin(output), IteratorBegin(output),
@ -483,7 +477,7 @@ private:
VTKM_CONT static typename InputPortal::ValueType ScanInclusivePortal(const InputPortal& input, VTKM_CONT static typename InputPortal::ValueType ScanInclusivePortal(const InputPortal& input,
const OutputPortal& output) const OutputPortal& output)
{ {
typedef typename OutputPortal::ValueType ValueType; using ValueType = typename OutputPortal::ValueType;
return ScanInclusivePortal(input, output, ::thrust::plus<ValueType>()); return ScanInclusivePortal(input, output, ::thrust::plus<ValueType>());
} }
@ -492,14 +486,12 @@ private:
const OutputPortal& output, const OutputPortal& output,
BinaryFunctor binary_functor) BinaryFunctor binary_functor)
{ {
typedef typename OutputPortal::ValueType ValueType; using ValueType = typename OutputPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, BinaryFunctor> bop(binary_functor); vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, BinaryFunctor> bop(binary_functor);
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType IteratorType;
try try
{ {
IteratorType end = ::thrust::inclusive_scan( auto end = ::thrust::inclusive_scan(
thrust::cuda::par, IteratorBegin(input), IteratorEnd(input), IteratorBegin(output), bop); thrust::cuda::par, IteratorBegin(input), IteratorEnd(input), IteratorBegin(output), bop);
return *(end - 1); return *(end - 1);
} }
@ -518,7 +510,7 @@ private:
const OutputPortal& output) const OutputPortal& output)
{ {
using KeyType = typename KeysPortal::ValueType; using KeyType = typename KeysPortal::ValueType;
typedef typename OutputPortal::ValueType ValueType; using ValueType = typename OutputPortal::ValueType;
ScanInclusiveByKeyPortal( ScanInclusiveByKeyPortal(
keys, values, output, ::thrust::equal_to<KeyType>(), ::thrust::plus<ValueType>()); keys, values, output, ::thrust::equal_to<KeyType>(), ::thrust::plus<ValueType>());
} }
@ -534,14 +526,13 @@ private:
BinaryPredicate binary_predicate, BinaryPredicate binary_predicate,
AssociativeOperator binary_operator) AssociativeOperator binary_operator)
{ {
typedef typename KeysPortal::ValueType KeyType; using KeyType = typename KeysPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryOperator<KeyType, BinaryPredicate> bpred( vtkm::exec::cuda::internal::WrappedBinaryOperator<KeyType, BinaryPredicate> bpred(
binary_predicate); binary_predicate);
typedef typename OutputPortal::ValueType ValueType; using ValueType = typename OutputPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, AssociativeOperator> bop( vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, AssociativeOperator> bop(
binary_operator); binary_operator);
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType IteratorType;
try try
{ {
::thrust::inclusive_scan_by_key(thrust::cuda::par, ::thrust::inclusive_scan_by_key(thrust::cuda::par,
@ -564,7 +555,7 @@ private:
const OutputPortal& output) const OutputPortal& output)
{ {
using KeyType = typename KeysPortal::ValueType; using KeyType = typename KeysPortal::ValueType;
typedef typename OutputPortal::ValueType ValueType; using ValueType = typename OutputPortal::ValueType;
ScanExclusiveByKeyPortal(keys, ScanExclusiveByKeyPortal(keys,
values, values,
output, output,
@ -586,14 +577,12 @@ private:
BinaryPredicate binary_predicate, BinaryPredicate binary_predicate,
AssociativeOperator binary_operator) AssociativeOperator binary_operator)
{ {
typedef typename KeysPortal::ValueType KeyType; using KeyType = typename KeysPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryOperator<KeyType, BinaryPredicate> bpred( vtkm::exec::cuda::internal::WrappedBinaryOperator<KeyType, BinaryPredicate> bpred(
binary_predicate); binary_predicate);
typedef typename OutputPortal::ValueType ValueType; using ValueType = typename OutputPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, AssociativeOperator> bop( vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, AssociativeOperator> bop(
binary_operator); binary_operator);
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType IteratorType;
try try
{ {
::thrust::exclusive_scan_by_key(thrust::cuda::par, ::thrust::exclusive_scan_by_key(thrust::cuda::par,
@ -614,14 +603,14 @@ private:
template <class ValuesPortal> template <class ValuesPortal>
VTKM_CONT static void SortPortal(const ValuesPortal& values) VTKM_CONT static void SortPortal(const ValuesPortal& values)
{ {
typedef typename ValuesPortal::ValueType ValueType; using ValueType = typename ValuesPortal::ValueType;
SortPortal(values, ::thrust::less<ValueType>()); SortPortal(values, ::thrust::less<ValueType>());
} }
template <class ValuesPortal, class BinaryCompare> template <class ValuesPortal, class BinaryCompare>
VTKM_CONT static void SortPortal(const ValuesPortal& values, BinaryCompare binary_compare) VTKM_CONT static void SortPortal(const ValuesPortal& values, BinaryCompare binary_compare)
{ {
typedef typename ValuesPortal::ValueType ValueType; using ValueType = typename ValuesPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryPredicate<ValueType, BinaryCompare> bop( vtkm::exec::cuda::internal::WrappedBinaryPredicate<ValueType, BinaryCompare> bop(
binary_compare); binary_compare);
try try
@ -646,7 +635,7 @@ private:
const ValuesPortal& values, const ValuesPortal& values,
BinaryCompare binary_compare) BinaryCompare binary_compare)
{ {
typedef typename KeysPortal::ValueType ValueType; using ValueType = typename KeysPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryPredicate<ValueType, BinaryCompare> bop( vtkm::exec::cuda::internal::WrappedBinaryPredicate<ValueType, BinaryCompare> bop(
binary_compare); binary_compare);
try try
@ -663,11 +652,10 @@ private:
template <class ValuesPortal> template <class ValuesPortal>
VTKM_CONT static vtkm::Id UniquePortal(const ValuesPortal values) VTKM_CONT static vtkm::Id UniquePortal(const ValuesPortal values)
{ {
typedef typename detail::IteratorTraits<ValuesPortal>::IteratorType IteratorType;
try try
{ {
IteratorType begin = IteratorBegin(values); auto begin = IteratorBegin(values);
IteratorType newLast = ::thrust::unique(thrust::cuda::par, begin, IteratorEnd(values)); auto newLast = ::thrust::unique(thrust::cuda::par, begin, IteratorEnd(values));
return static_cast<vtkm::Id>(::thrust::distance(begin, newLast)); return static_cast<vtkm::Id>(::thrust::distance(begin, newLast));
} }
catch (...) catch (...)
@ -680,15 +668,13 @@ private:
template <class ValuesPortal, class BinaryCompare> template <class ValuesPortal, class BinaryCompare>
VTKM_CONT static vtkm::Id UniquePortal(const ValuesPortal values, BinaryCompare binary_compare) VTKM_CONT static vtkm::Id UniquePortal(const ValuesPortal values, BinaryCompare binary_compare)
{ {
typedef typename detail::IteratorTraits<ValuesPortal>::IteratorType IteratorType; using ValueType = typename ValuesPortal::ValueType;
typedef typename ValuesPortal::ValueType ValueType;
vtkm::exec::cuda::internal::WrappedBinaryPredicate<ValueType, BinaryCompare> bop( vtkm::exec::cuda::internal::WrappedBinaryPredicate<ValueType, BinaryCompare> bop(
binary_compare); binary_compare);
try try
{ {
IteratorType begin = IteratorBegin(values); auto begin = IteratorBegin(values);
IteratorType newLast = ::thrust::unique(thrust::cuda::par, begin, IteratorEnd(values), bop); auto newLast = ::thrust::unique(thrust::cuda::par, begin, IteratorEnd(values), bop);
return static_cast<vtkm::Id>(::thrust::distance(begin, newLast)); return static_cast<vtkm::Id>(::thrust::distance(begin, newLast));
} }
catch (...) catch (...)
@ -949,8 +935,8 @@ public:
//function. The order of execution of parameters of a function is undefined //function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place //so we need to make sure input is called before output, or else in-place
//use case breaks. //use case breaks.
input.PrepareForInput(DeviceAdapterTag()); auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
return ScanExclusivePortal(input.PrepareForInput(DeviceAdapterTag()), return ScanExclusivePortal(inputPortal,
output.PrepareForOutput(numberOfValues, DeviceAdapterTag())); output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
} }
@ -971,8 +957,8 @@ public:
//function. The order of execution of parameters of a function is undefined //function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place //so we need to make sure input is called before output, or else in-place
//use case breaks. //use case breaks.
input.PrepareForInput(DeviceAdapterTag()); auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
return ScanExclusivePortal(input.PrepareForInput(DeviceAdapterTag()), return ScanExclusivePortal(inputPortal,
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()), output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
binary_functor, binary_functor,
initialValue); initialValue);
@ -993,8 +979,8 @@ public:
//function. The order of execution of parameters of a function is undefined //function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place //so we need to make sure input is called before output, or else in-place
//use case breaks. //use case breaks.
input.PrepareForInput(DeviceAdapterTag()); auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
return ScanInclusivePortal(input.PrepareForInput(DeviceAdapterTag()), return ScanInclusivePortal(inputPortal,
output.PrepareForOutput(numberOfValues, DeviceAdapterTag())); output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
} }
@ -1014,10 +1000,9 @@ public:
//function. The order of execution of parameters of a function is undefined //function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place //so we need to make sure input is called before output, or else in-place
//use case breaks. //use case breaks.
input.PrepareForInput(DeviceAdapterTag()); auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
return ScanInclusivePortal(input.PrepareForInput(DeviceAdapterTag()), return ScanInclusivePortal(
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()), inputPortal, output.PrepareForOutput(numberOfValues, DeviceAdapterTag()), binary_functor);
binary_functor);
} }
template <typename T, typename U, typename KIn, typename VIn, typename VOut> template <typename T, typename U, typename KIn, typename VIn, typename VOut>
@ -1035,11 +1020,10 @@ public:
//function. The order of execution of parameters of a function is undefined //function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place //so we need to make sure input is called before output, or else in-place
//use case breaks. //use case breaks.
keys.PrepareForInput(DeviceAdapterTag()); auto keysPortal = keys.PrepareForInput(DeviceAdapterTag());
values.PrepareForInput(DeviceAdapterTag()); auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
ScanInclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()), ScanInclusiveByKeyPortal(
values.PrepareForInput(DeviceAdapterTag()), keysPortal, valuesPortal, output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
} }
template <typename T, template <typename T,
@ -1063,10 +1047,10 @@ public:
//function. The order of execution of parameters of a function is undefined //function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place //so we need to make sure input is called before output, or else in-place
//use case breaks. //use case breaks.
keys.PrepareForInput(DeviceAdapterTag()); auto keysPortal = keys.PrepareForInput(DeviceAdapterTag());
values.PrepareForInput(DeviceAdapterTag()); auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
ScanInclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()), ScanInclusiveByKeyPortal(keysPortal,
values.PrepareForInput(DeviceAdapterTag()), valuesPortal,
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()), output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
::thrust::equal_to<T>(), ::thrust::equal_to<T>(),
binary_functor); binary_functor);
@ -1088,10 +1072,10 @@ public:
//function. The order of execution of parameters of a function is undefined //function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place //so we need to make sure input is called before output, or else in-place
//use case breaks. //use case breaks.
keys.PrepareForInput(DeviceAdapterTag()); auto keysPortal = keys.PrepareForInput(DeviceAdapterTag());
values.PrepareForInput(DeviceAdapterTag()); auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
ScanExnclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()), ScanExnclusiveByKeyPortal(keysPortal,
values.PrepareForInput(DeviceAdapterTag()), valuesPortal,
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()), output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
vtkm::TypeTraits<T>::ZeroInitialization(), vtkm::TypeTraits<T>::ZeroInitialization(),
vtkm::Add()); vtkm::Add());
@ -1120,10 +1104,10 @@ public:
//function. The order of execution of parameters of a function is undefined //function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place //so we need to make sure input is called before output, or else in-place
//use case breaks. //use case breaks.
keys.PrepareForInput(DeviceAdapterTag()); auto keysPortal = keys.PrepareForInput(DeviceAdapterTag());
values.PrepareForInput(DeviceAdapterTag()); auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
ScanExclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()), ScanExclusiveByKeyPortal(keysPortal,
values.PrepareForInput(DeviceAdapterTag()), valuesPortal,
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()), output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
initialValue, initialValue,
::thrust::equal_to<T>(), ::thrust::equal_to<T>(),

@ -102,16 +102,14 @@ private:
template <typename T, class CIn> template <typename T, class CIn>
VTKM_CONT static T GetExecutionValue(const vtkm::cont::ArrayHandle<T, CIn>& input, vtkm::Id index) VTKM_CONT static T GetExecutionValue(const vtkm::cont::ArrayHandle<T, CIn>& input, vtkm::Id index)
{ {
using InputArrayType = vtkm::cont::ArrayHandle<T, CIn>;
using OutputArrayType = vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic>; using OutputArrayType = vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic>;
OutputArrayType output; OutputArrayType output;
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
auto outputPortal = output.PrepareForOutput(1, DeviceAdapterTag());
CopyKernel<typename InputArrayType::template ExecutionTypes<DeviceAdapterTag>::PortalConst, CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(
typename OutputArrayType::template ExecutionTypes<DeviceAdapterTag>::Portal> inputPortal, outputPortal, index);
kernel(input.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(1, DeviceAdapterTag()),
index);
DerivedAlgorithm::Schedule(kernel, 1); DerivedAlgorithm::Schedule(kernel, 1);
@ -125,14 +123,11 @@ public:
VTKM_CONT static void Copy(const vtkm::cont::ArrayHandle<T, CIn>& input, VTKM_CONT static void Copy(const vtkm::cont::ArrayHandle<T, CIn>& input,
vtkm::cont::ArrayHandle<U, COut>& output) vtkm::cont::ArrayHandle<U, COut>& output)
{ {
using CopyKernelType = CopyKernel<
typename vtkm::cont::ArrayHandle<T,
CIn>::template ExecutionTypes<DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<U, COut>::template ExecutionTypes<DeviceAdapterTag>::Portal>;
const vtkm::Id inSize = input.GetNumberOfValues(); const vtkm::Id inSize = input.GetNumberOfValues();
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
auto outputPortal = output.PrepareForOutput(inSize, DeviceAdapterTag());
CopyKernelType kernel(input.PrepareForInput(DeviceAdapterTag()), CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(inputPortal, outputPortal);
output.PrepareForOutput(inSize, DeviceAdapterTag()));
DerivedAlgorithm::Schedule(kernel, inSize); DerivedAlgorithm::Schedule(kernel, inSize);
} }
@ -150,35 +145,23 @@ public:
using IndexArrayType = vtkm::cont::ArrayHandle<vtkm::Id, vtkm::cont::StorageTagBasic>; using IndexArrayType = vtkm::cont::ArrayHandle<vtkm::Id, vtkm::cont::StorageTagBasic>;
IndexArrayType indices; IndexArrayType indices;
using StencilPortalType = auto stencilPortal = stencil.PrepareForInput(DeviceAdapterTag());
typename vtkm::cont::ArrayHandle<U, CStencil>::template ExecutionTypes< auto indexPortal = indices.PrepareForOutput(arrayLength, DeviceAdapterTag());
DeviceAdapterTag>::PortalConst;
StencilPortalType stencilPortal = stencil.PrepareForInput(DeviceAdapterTag());
using IndexPortalType = StencilToIndexFlagKernel<decltype(stencilPortal), decltype(indexPortal), UnaryPredicate>
typename IndexArrayType::template ExecutionTypes<DeviceAdapterTag>::Portal; indexKernel(stencilPortal, indexPortal, unary_predicate);
IndexPortalType indexPortal = indices.PrepareForOutput(arrayLength, DeviceAdapterTag());
StencilToIndexFlagKernel<StencilPortalType, IndexPortalType, UnaryPredicate> indexKernel(
stencilPortal, indexPortal, unary_predicate);
DerivedAlgorithm::Schedule(indexKernel, arrayLength); DerivedAlgorithm::Schedule(indexKernel, arrayLength);
vtkm::Id outArrayLength = DerivedAlgorithm::ScanExclusive(indices, indices); vtkm::Id outArrayLength = DerivedAlgorithm::ScanExclusive(indices, indices);
using InputPortalType = auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
typename vtkm::cont::ArrayHandle<T, auto outputPortal = output.PrepareForOutput(outArrayLength, DeviceAdapterTag());
CIn>::template ExecutionTypes<DeviceAdapterTag>::PortalConst;
InputPortalType inputPortal = input.PrepareForInput(DeviceAdapterTag());
using OutputPortalType = CopyIfKernel<decltype(inputPortal),
typename vtkm::cont::ArrayHandle<T, COut>::template ExecutionTypes<DeviceAdapterTag>::Portal; decltype(stencilPortal),
OutputPortalType outputPortal = output.PrepareForOutput(outArrayLength, DeviceAdapterTag()); decltype(indexPortal),
decltype(outputPortal),
CopyIfKernel<InputPortalType,
StencilPortalType,
IndexPortalType,
OutputPortalType,
UnaryPredicate> UnaryPredicate>
copyKernel(inputPortal, stencilPortal, indexPortal, outputPortal, unary_predicate); copyKernel(inputPortal, stencilPortal, indexPortal, outputPortal, unary_predicate);
DerivedAlgorithm::Schedule(copyKernel, arrayLength); DerivedAlgorithm::Schedule(copyKernel, arrayLength);
@ -202,11 +185,6 @@ public:
vtkm::cont::ArrayHandle<U, COut>& output, vtkm::cont::ArrayHandle<U, COut>& output,
vtkm::Id outputIndex = 0) vtkm::Id outputIndex = 0)
{ {
using CopyKernel = CopyKernel<
typename vtkm::cont::ArrayHandle<T,
CIn>::template ExecutionTypes<DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<U, COut>::template ExecutionTypes<DeviceAdapterTag>::Portal>;
const vtkm::Id inSize = input.GetNumberOfValues(); const vtkm::Id inSize = input.GetNumberOfValues();
if (inputStartIndex < 0 || numberOfElementsToCopy < 0 || outputIndex < 0 || if (inputStartIndex < 0 || numberOfElementsToCopy < 0 || outputIndex < 0 ||
inputStartIndex >= inSize) inputStartIndex >= inSize)
@ -238,10 +216,11 @@ public:
} }
} }
CopyKernel kernel(input.PrepareForInput(DeviceAdapterTag()), auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
output.PrepareForInPlace(DeviceAdapterTag()), auto outputPortal = output.PrepareForInPlace(DeviceAdapterTag());
inputStartIndex,
outputIndex); CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(
inputPortal, outputPortal, inputStartIndex, outputIndex);
DerivedAlgorithm::Schedule(kernel, numberOfElementsToCopy); DerivedAlgorithm::Schedule(kernel, numberOfElementsToCopy);
return true; return true;
} }
@ -255,15 +234,12 @@ public:
{ {
vtkm::Id arraySize = values.GetNumberOfValues(); vtkm::Id arraySize = values.GetNumberOfValues();
LowerBoundsKernel<typename vtkm::cont::ArrayHandle<T, CIn>::template ExecutionTypes< auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
DeviceAdapterTag>::PortalConst, auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
typename vtkm::cont::ArrayHandle<T, CVal>::template ExecutionTypes< auto outputPortal = output.PrepareForOutput(arraySize, DeviceAdapterTag());
DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<vtkm::Id, COut>::template ExecutionTypes< LowerBoundsKernel<decltype(inputPortal), decltype(valuesPortal), decltype(outputPortal)> kernel(
DeviceAdapterTag>::Portal> inputPortal, valuesPortal, outputPortal);
kernel(input.PrepareForInput(DeviceAdapterTag()),
values.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(arraySize, DeviceAdapterTag()));
DerivedAlgorithm::Schedule(kernel, arraySize); DerivedAlgorithm::Schedule(kernel, arraySize);
} }
@ -276,18 +252,15 @@ public:
{ {
vtkm::Id arraySize = values.GetNumberOfValues(); vtkm::Id arraySize = values.GetNumberOfValues();
LowerBoundsComparisonKernel< auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
typename vtkm::cont::ArrayHandle<T, auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
CIn>::template ExecutionTypes<DeviceAdapterTag>::PortalConst, auto outputPortal = output.PrepareForOutput(arraySize, DeviceAdapterTag());
typename vtkm::cont::ArrayHandle<T, CVal>::template ExecutionTypes<
DeviceAdapterTag>::PortalConst, LowerBoundsComparisonKernel<decltype(inputPortal),
typename vtkm::cont::ArrayHandle<vtkm::Id, decltype(valuesPortal),
COut>::template ExecutionTypes<DeviceAdapterTag>::Portal, decltype(outputPortal),
BinaryCompare> BinaryCompare>
kernel(input.PrepareForInput(DeviceAdapterTag()), kernel(inputPortal, valuesPortal, outputPortal, binary_compare);
values.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(arraySize, DeviceAdapterTag()),
binary_compare);
DerivedAlgorithm::Schedule(kernel, arraySize); DerivedAlgorithm::Schedule(kernel, arraySize);
} }
@ -322,23 +295,15 @@ public:
// //
//Now that we have an implicit array that is 1/16 the length of full array //Now that we have an implicit array that is 1/16 the length of full array
//we can use scan inclusive to compute the final sum //we can use scan inclusive to compute the final sum
using InputPortalType = auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
typename vtkm::cont::ArrayHandle<T, ReduceKernel<decltype(inputPortal), U, BinaryFunctor> kernel(
CIn>::template ExecutionTypes<DeviceAdapterTag>::PortalConst; inputPortal, initialValue, binary_functor);
using ReduceKernelType = ReduceKernel<InputPortalType, U, BinaryFunctor>;
using ReduceHandleType = vtkm::cont::ArrayHandleImplicit<ReduceKernelType>;
using TempArrayType = vtkm::cont::ArrayHandle<U, vtkm::cont::StorageTagBasic>;
ReduceKernelType kernel(
input.PrepareForInput(DeviceAdapterTag()), initialValue, binary_functor);
vtkm::Id length = (input.GetNumberOfValues() / 16); vtkm::Id length = (input.GetNumberOfValues() / 16);
length += (input.GetNumberOfValues() % 16 == 0) ? 0 : 1; length += (input.GetNumberOfValues() % 16 == 0) ? 0 : 1;
ReduceHandleType reduced = vtkm::cont::make_ArrayHandleImplicit(kernel, length); auto reduced = vtkm::cont::make_ArrayHandleImplicit(kernel, length);
TempArrayType inclusiveScanStorage; vtkm::cont::ArrayHandle<U, vtkm::cont::StorageTagBasic> inclusiveScanStorage;
const U scanResult = const U scanResult =
DerivedAlgorithm::ScanInclusive(reduced, inclusiveScanStorage, binary_functor); DerivedAlgorithm::ScanInclusive(reduced, inclusiveScanStorage, binary_functor);
return scanResult; return scanResult;
@ -372,8 +337,7 @@ public:
if (block == numBlocks - 1) if (block == numBlocks - 1)
numberOfInstances = fullSize - blockSize * block; numberOfInstances = fullSize - blockSize * block;
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, CIn>> streamIn = vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, CIn>> streamIn(
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, CIn>>(
input, block, blockSize, numberOfInstances); input, block, blockSize, numberOfInstances);
if (block == 0) if (block == 0)
@ -415,17 +379,10 @@ public:
vtkm::cont::ArrayHandle<ReduceKeySeriesStates> keystate; vtkm::cont::ArrayHandle<ReduceKeySeriesStates> keystate;
{ {
using InputPortalType = typename vtkm::cont::ArrayHandle<T, KIn>::template ExecutionTypes< auto inputPortal = keys.PrepareForInput(DeviceAdapterTag());
DeviceAdapterTag>::PortalConst; auto keyStatePortal = keystate.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
using KeyStatePortalType = typename vtkm::cont::ArrayHandle< inputPortal, keyStatePortal);
ReduceKeySeriesStates>::template ExecutionTypes<DeviceAdapterTag>::Portal;
InputPortalType inputPortal = keys.PrepareForInput(DeviceAdapterTag());
KeyStatePortalType keyStatePortal =
keystate.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
ReduceStencilGeneration<InputPortalType, KeyStatePortalType> kernel(inputPortal,
keyStatePortal);
DerivedAlgorithm::Schedule(kernel, numberOfKeys); DerivedAlgorithm::Schedule(kernel, numberOfKeys);
} }
@ -436,17 +393,11 @@ public:
// the value summed currently, the second being 0 or 1, with 1 being used // the value summed currently, the second being 0 or 1, with 1 being used
// when this is a value of a key we need to write ( END or START_AND_END) // when this is a value of a key we need to write ( END or START_AND_END)
{ {
using ValueInHandleType = vtkm::cont::ArrayHandle<U, VIn>; vtkm::cont::ArrayHandle<ReduceKeySeriesStates> stencil;
using ValueOutHandleType = vtkm::cont::ArrayHandle<U, VOut>; vtkm::cont::ArrayHandle<U, VOut> reducedValues;
using StencilHandleType = vtkm::cont::ArrayHandle<ReduceKeySeriesStates>;
using ZipInHandleType = vtkm::cont::ArrayHandleZip<ValueInHandleType, StencilHandleType>;
using ZipOutHandleType = vtkm::cont::ArrayHandleZip<ValueOutHandleType, StencilHandleType>;
StencilHandleType stencil; auto scanInput = vtkm::cont::make_ArrayHandleZip(values, keystate);
ValueOutHandleType reducedValues; auto scanOutput = vtkm::cont::make_ArrayHandleZip(reducedValues, stencil);
ZipInHandleType scanInput(values, keystate);
ZipOutHandleType scanOutput(reducedValues, stencil);
DerivedAlgorithm::ScanInclusive( DerivedAlgorithm::ScanInclusive(
scanInput, scanOutput, ReduceByKeyAdd<BinaryFunctor>(binary_functor)); scanInput, scanOutput, ReduceByKeyAdd<BinaryFunctor>(binary_functor));
@ -474,28 +425,20 @@ public:
BinaryFunctor binaryFunctor, BinaryFunctor binaryFunctor,
const T& initialValue) const T& initialValue)
{ {
using TempArrayType = vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic>;
using OutputArrayType = vtkm::cont::ArrayHandle<T, COut>;
using SrcPortalType =
typename TempArrayType::template ExecutionTypes<DeviceAdapterTag>::PortalConst;
using DestPortalType =
typename OutputArrayType::template ExecutionTypes<DeviceAdapterTag>::Portal;
vtkm::Id numValues = input.GetNumberOfValues(); vtkm::Id numValues = input.GetNumberOfValues();
if (numValues <= 0) if (numValues <= 0)
{ {
return initialValue; return initialValue;
} }
TempArrayType inclusiveScan; vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic> inclusiveScan;
T result = DerivedAlgorithm::ScanInclusive(input, inclusiveScan, binaryFunctor); T result = DerivedAlgorithm::ScanInclusive(input, inclusiveScan, binaryFunctor);
InclusiveToExclusiveKernel<SrcPortalType, DestPortalType, BinaryFunctor> inclusiveToExclusive( auto inputPortal = inclusiveScan.PrepareForInput(DeviceAdapterTag());
inclusiveScan.PrepareForInput(DeviceAdapterTag()), auto outputPortal = output.PrepareForOutput(numValues, DeviceAdapterTag());
output.PrepareForOutput(numValues, DeviceAdapterTag()),
binaryFunctor, InclusiveToExclusiveKernel<decltype(inputPortal), decltype(outputPortal), BinaryFunctor>
initialValue); inclusiveToExclusive(inputPortal, outputPortal, binaryFunctor, initialValue);
DerivedAlgorithm::Schedule(inclusiveToExclusive, numValues); DerivedAlgorithm::Schedule(inclusiveToExclusive, numValues);
@ -542,39 +485,22 @@ public:
vtkm::cont::ArrayHandle<ReduceKeySeriesStates> keystate; vtkm::cont::ArrayHandle<ReduceKeySeriesStates> keystate;
{ {
using InputPortalType = typename vtkm::cont::ArrayHandle<T, KIn>::template ExecutionTypes< auto inputPortal = keys.PrepareForInput(DeviceAdapterTag());
DeviceAdapterTag>::PortalConst; auto keyStatePortal = keystate.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
using KeyStatePortalType = typename vtkm::cont::ArrayHandle< inputPortal, keyStatePortal);
ReduceKeySeriesStates>::template ExecutionTypes<DeviceAdapterTag>::Portal;
InputPortalType inputPortal = keys.PrepareForInput(DeviceAdapterTag());
KeyStatePortalType keyStatePortal =
keystate.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
ReduceStencilGeneration<InputPortalType, KeyStatePortalType> kernel(inputPortal,
keyStatePortal);
DerivedAlgorithm::Schedule(kernel, numberOfKeys); DerivedAlgorithm::Schedule(kernel, numberOfKeys);
} }
// 2. Shift input and initialize elements at head flags position to initValue // 2. Shift input and initialize elements at head flags position to initValue
using TempArrayType = typename vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic>; vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic> temp;
using TempPortalType =
typename vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic>::template ExecutionTypes<
DeviceAdapterTag>::Portal;
TempArrayType temp;
{ {
using InputPortalType = typename vtkm::cont::ArrayHandle<T, KIn>::template ExecutionTypes< auto inputPortal = values.PrepareForInput(DeviceAdapterTag());
DeviceAdapterTag>::PortalConst; auto keyStatePortal = keystate.PrepareForInput(DeviceAdapterTag());
auto tempPortal = temp.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
using KeyStatePortalType = typename vtkm::cont::ArrayHandle< ShiftCopyAndInit<U, decltype(inputPortal), decltype(keyStatePortal), decltype(tempPortal)>
ReduceKeySeriesStates>::template ExecutionTypes<DeviceAdapterTag>::PortalConst; kernel(inputPortal, keyStatePortal, tempPortal, initialValue);
InputPortalType inputPortal = values.PrepareForInput(DeviceAdapterTag());
KeyStatePortalType keyStatePortal = keystate.PrepareForInput(DeviceAdapterTag());
TempPortalType tempPortal = temp.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
ShiftCopyAndInit<U, InputPortalType, KeyStatePortalType, TempPortalType> kernel(
inputPortal, keyStatePortal, tempPortal, initialValue);
DerivedAlgorithm::Schedule(kernel, numberOfKeys); DerivedAlgorithm::Schedule(kernel, numberOfKeys);
} }
// 3. Perform a ScanInclusiveByKey // 3. Perform a ScanInclusiveByKey
@ -620,12 +546,10 @@ public:
if (block == numBlocks - 1) if (block == numBlocks - 1)
numberOfInstances = fullSize - blockSize * block; numberOfInstances = fullSize - blockSize * block;
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, CIn>> streamIn = vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, CIn>> streamIn(
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, CIn>>(
input, block, blockSize, numberOfInstances); input, block, blockSize, numberOfInstances);
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, COut>> streamOut = vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, COut>> streamOut(
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, COut>>(
output, block, blockSize, numberOfInstances); output, block, blockSize, numberOfInstances);
if (block == 0) if (block == 0)
@ -659,11 +583,6 @@ public:
vtkm::cont::ArrayHandle<T, COut>& output, vtkm::cont::ArrayHandle<T, COut>& output,
BinaryFunctor binary_functor) BinaryFunctor binary_functor)
{ {
using PortalType =
typename vtkm::cont::ArrayHandle<T, COut>::template ExecutionTypes<DeviceAdapterTag>::Portal;
using ScanKernelType = ScanKernel<PortalType, BinaryFunctor>;
DerivedAlgorithm::Copy(input, output); DerivedAlgorithm::Copy(input, output);
vtkm::Id numValues = output.GetNumberOfValues(); vtkm::Id numValues = output.GetNumberOfValues();
@ -672,7 +591,9 @@ public:
return vtkm::TypeTraits<T>::ZeroInitialization(); return vtkm::TypeTraits<T>::ZeroInitialization();
} }
PortalType portal = output.PrepareForInPlace(DeviceAdapterTag()); auto portal = output.PrepareForInPlace(DeviceAdapterTag());
using ScanKernelType = ScanKernel<decltype(portal), BinaryFunctor>;
vtkm::Id stride; vtkm::Id stride;
for (stride = 2; stride - 1 < numValues; stride *= 2) for (stride = 2; stride - 1 < numValues; stride *= 2)
@ -720,17 +641,10 @@ public:
vtkm::cont::ArrayHandle<ReduceKeySeriesStates> keystate; vtkm::cont::ArrayHandle<ReduceKeySeriesStates> keystate;
{ {
using InputPortalType = typename vtkm::cont::ArrayHandle<T, KIn>::template ExecutionTypes< auto inputPortal = keys.PrepareForInput(DeviceAdapterTag());
DeviceAdapterTag>::PortalConst; auto keyStatePortal = keystate.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
using KeyStatePortalType = typename vtkm::cont::ArrayHandle< inputPortal, keyStatePortal);
ReduceKeySeriesStates>::template ExecutionTypes<DeviceAdapterTag>::Portal;
InputPortalType inputPortal = keys.PrepareForInput(DeviceAdapterTag());
KeyStatePortalType keyStatePortal =
keystate.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
ReduceStencilGeneration<InputPortalType, KeyStatePortalType> kernel(inputPortal,
keyStatePortal);
DerivedAlgorithm::Schedule(kernel, numberOfKeys); DerivedAlgorithm::Schedule(kernel, numberOfKeys);
} }
@ -741,18 +655,10 @@ public:
// the value summed currently, the second being 0 or 1, with 1 being used // the value summed currently, the second being 0 or 1, with 1 being used
// when this is a value of a key we need to write ( END or START_AND_END) // when this is a value of a key we need to write ( END or START_AND_END)
{ {
using ValueInHandleType = vtkm::cont::ArrayHandle<U, VIn>; vtkm::cont::ArrayHandle<U, VOut> reducedValues;
using ValueOutHandleType = vtkm::cont::ArrayHandle<U, VOut>; vtkm::cont::ArrayHandle<ReduceKeySeriesStates> stencil;
using StencilHandleType = vtkm::cont::ArrayHandle<ReduceKeySeriesStates>; auto scanInput = vtkm::cont::make_ArrayHandleZip(values, keystate);
using ZipInHandleType = vtkm::cont::ArrayHandleZip<ValueInHandleType, StencilHandleType>; auto scanOutput = vtkm::cont::make_ArrayHandleZip(reducedValues, stencil);
using ZipOutHandleType = vtkm::cont::ArrayHandleZip<ValueOutHandleType, StencilHandleType>;
StencilHandleType stencil;
ValueOutHandleType reducedValues;
ZipInHandleType scanInput(values, keystate);
ZipOutHandleType scanOutput(reducedValues, stencil);
DerivedAlgorithm::ScanInclusive( DerivedAlgorithm::ScanInclusive(
scanInput, scanOutput, ReduceByKeyAdd<BinaryFunctor>(binary_functor)); scanInput, scanOutput, ReduceByKeyAdd<BinaryFunctor>(binary_functor));
@ -768,17 +674,11 @@ public:
VTKM_CONT static void Sort(vtkm::cont::ArrayHandle<T, Storage>& values, VTKM_CONT static void Sort(vtkm::cont::ArrayHandle<T, Storage>& values,
BinaryCompare binary_compare) BinaryCompare binary_compare)
{ {
using ArrayType = typename vtkm::cont::ArrayHandle<T, Storage>;
using PortalType = typename ArrayType::template ExecutionTypes<DeviceAdapterTag>::Portal;
vtkm::Id numValues = values.GetNumberOfValues(); vtkm::Id numValues = values.GetNumberOfValues();
if (numValues < 2) if (numValues < 2)
{ {
return; return;
} }
PortalType portal = values.PrepareForInPlace(DeviceAdapterTag());
vtkm::Id numThreads = 1; vtkm::Id numThreads = 1;
while (numThreads < numValues) while (numThreads < numValues)
{ {
@ -786,8 +686,9 @@ public:
} }
numThreads /= 2; numThreads /= 2;
using MergeKernel = BitonicSortMergeKernel<PortalType, BinaryCompare>; auto portal = values.PrepareForInPlace(DeviceAdapterTag());
using CrossoverKernel = BitonicSortCrossoverKernel<PortalType, BinaryCompare>; using MergeKernel = BitonicSortMergeKernel<decltype(portal), BinaryCompare>;
using CrossoverKernel = BitonicSortCrossoverKernel<decltype(portal), BinaryCompare>;
for (vtkm::Id crossoverSize = 1; crossoverSize < numValues; crossoverSize *= 2) for (vtkm::Id crossoverSize = 1; crossoverSize < numValues; crossoverSize *= 2)
{ {
@ -816,12 +717,7 @@ public:
//combine the keys and values into a ZipArrayHandle //combine the keys and values into a ZipArrayHandle
//we than need to specify a custom compare function wrapper //we than need to specify a custom compare function wrapper
//that only checks for key side of the pair, using a custom compare functor. //that only checks for key side of the pair, using a custom compare functor.
using KeyType = vtkm::cont::ArrayHandle<T, StorageT>; auto zipHandle = vtkm::cont::make_ArrayHandleZip(keys, values);
;
using ValueType = vtkm::cont::ArrayHandle<U, StorageU>;
using ZipHandleType = vtkm::cont::ArrayHandleZip<KeyType, ValueType>;
ZipHandleType zipHandle = vtkm::cont::make_ArrayHandleZip(keys, values);
DerivedAlgorithm::Sort(zipHandle, internal::KeyCompare<T, U>()); DerivedAlgorithm::Sort(zipHandle, internal::KeyCompare<T, U>());
} }
@ -834,12 +730,7 @@ public:
//we than need to specify a custom compare function wrapper //we than need to specify a custom compare function wrapper
//that only checks for key side of the pair, using the custom compare //that only checks for key side of the pair, using the custom compare
//functor that the user passed in //functor that the user passed in
using KeyType = vtkm::cont::ArrayHandle<T, StorageT>; auto zipHandle = vtkm::cont::make_ArrayHandleZip(keys, values);
;
using ValueType = vtkm::cont::ArrayHandle<U, StorageU>;
using ZipHandleType = vtkm::cont::ArrayHandleZip<KeyType, ValueType>;
ZipHandleType zipHandle = vtkm::cont::make_ArrayHandleZip(keys, values);
DerivedAlgorithm::Sort(zipHandle, internal::KeyCompare<T, U, BinaryCompare>(binary_compare)); DerivedAlgorithm::Sort(zipHandle, internal::KeyCompare<T, U, BinaryCompare>(binary_compare));
} }
@ -861,15 +752,10 @@ public:
using WrappedBOpType = internal::WrappedBinaryOperator<bool, BinaryCompare>; using WrappedBOpType = internal::WrappedBinaryOperator<bool, BinaryCompare>;
WrappedBOpType wrappedCompare(binary_compare); WrappedBOpType wrappedCompare(binary_compare);
ClassifyUniqueComparisonKernel< auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
typename vtkm::cont::ArrayHandle<T, Storage>::template ExecutionTypes< auto stencilPortal = stencilArray.PrepareForOutput(inputSize, DeviceAdapterTag());
DeviceAdapterTag>::PortalConst, ClassifyUniqueComparisonKernel<decltype(valuesPortal), decltype(stencilPortal), WrappedBOpType>
typename vtkm::cont::ArrayHandle<vtkm::Id, vtkm::cont::StorageTagBasic>:: classifyKernel(valuesPortal, stencilPortal, wrappedCompare);
template ExecutionTypes<DeviceAdapterTag>::Portal,
WrappedBOpType>
classifyKernel(values.PrepareForInput(DeviceAdapterTag()),
stencilArray.PrepareForOutput(inputSize, DeviceAdapterTag()),
wrappedCompare);
DerivedAlgorithm::Schedule(classifyKernel, inputSize); DerivedAlgorithm::Schedule(classifyKernel, inputSize);
@ -890,16 +776,12 @@ public:
{ {
vtkm::Id arraySize = values.GetNumberOfValues(); vtkm::Id arraySize = values.GetNumberOfValues();
UpperBoundsKernel<typename vtkm::cont::ArrayHandle<T, CIn>::template ExecutionTypes< auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
DeviceAdapterTag>::PortalConst, auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
typename vtkm::cont::ArrayHandle<T, CVal>::template ExecutionTypes< auto outputPortal = output.PrepareForOutput(arraySize, DeviceAdapterTag());
DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<vtkm::Id, COut>::template ExecutionTypes<
DeviceAdapterTag>::Portal>
kernel(input.PrepareForInput(DeviceAdapterTag()),
values.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(arraySize, DeviceAdapterTag()));
UpperBoundsKernel<decltype(inputPortal), decltype(valuesPortal), decltype(outputPortal)> kernel(
inputPortal, valuesPortal, outputPortal);
DerivedAlgorithm::Schedule(kernel, arraySize); DerivedAlgorithm::Schedule(kernel, arraySize);
} }
@ -911,18 +793,15 @@ public:
{ {
vtkm::Id arraySize = values.GetNumberOfValues(); vtkm::Id arraySize = values.GetNumberOfValues();
UpperBoundsKernelComparisonKernel< auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
typename vtkm::cont::ArrayHandle<T, auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
CIn>::template ExecutionTypes<DeviceAdapterTag>::PortalConst, auto outputPortal = output.PrepareForOutput(arraySize, DeviceAdapterTag());
typename vtkm::cont::ArrayHandle<T, CVal>::template ExecutionTypes<
DeviceAdapterTag>::PortalConst, UpperBoundsKernelComparisonKernel<decltype(inputPortal),
typename vtkm::cont::ArrayHandle<vtkm::Id, decltype(valuesPortal),
COut>::template ExecutionTypes<DeviceAdapterTag>::Portal, decltype(outputPortal),
BinaryCompare> BinaryCompare>
kernel(input.PrepareForInput(DeviceAdapterTag()), kernel(inputPortal, valuesPortal, outputPortal, binary_compare);
values.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(arraySize, DeviceAdapterTag()),
binary_compare);
DerivedAlgorithm::Schedule(kernel, arraySize); DerivedAlgorithm::Schedule(kernel, arraySize);
} }

@ -396,7 +396,7 @@ struct LowerBoundsKernel
using InputIteratorsType = vtkm::cont::ArrayPortalToIterators<InputPortalType>; using InputIteratorsType = vtkm::cont::ArrayPortalToIterators<InputPortalType>;
InputIteratorsType inputIterators(this->InputPortal); InputIteratorsType inputIterators(this->InputPortal);
typename InputIteratorsType::IteratorType resultPos = std::lower_bound( auto resultPos = std::lower_bound(
inputIterators.GetBegin(), inputIterators.GetEnd(), this->ValuesPortal.Get(index)); inputIterators.GetBegin(), inputIterators.GetEnd(), this->ValuesPortal.Get(index));
vtkm::Id resultIndex = vtkm::Id resultIndex =
@ -444,8 +444,7 @@ struct LowerBoundsComparisonKernel
using InputIteratorsType = vtkm::cont::ArrayPortalToIterators<InputPortalType>; using InputIteratorsType = vtkm::cont::ArrayPortalToIterators<InputPortalType>;
InputIteratorsType inputIterators(this->InputPortal); InputIteratorsType inputIterators(this->InputPortal);
typename InputIteratorsType::IteratorType resultPos = auto resultPos = std::lower_bound(inputIterators.GetBegin(),
std::lower_bound(inputIterators.GetBegin(),
inputIterators.GetEnd(), inputIterators.GetEnd(),
this->ValuesPortal.Get(index), this->ValuesPortal.Get(index),
this->CompareFunctor); this->CompareFunctor);
@ -748,7 +747,7 @@ struct UpperBoundsKernel
using InputIteratorsType = vtkm::cont::ArrayPortalToIterators<InputPortalType>; using InputIteratorsType = vtkm::cont::ArrayPortalToIterators<InputPortalType>;
InputIteratorsType inputIterators(this->InputPortal); InputIteratorsType inputIterators(this->InputPortal);
typename InputIteratorsType::IteratorType resultPos = std::upper_bound( auto resultPos = std::upper_bound(
inputIterators.GetBegin(), inputIterators.GetEnd(), this->ValuesPortal.Get(index)); inputIterators.GetBegin(), inputIterators.GetEnd(), this->ValuesPortal.Get(index));
vtkm::Id resultIndex = vtkm::Id resultIndex =
@ -796,8 +795,7 @@ struct UpperBoundsKernelComparisonKernel
using InputIteratorsType = vtkm::cont::ArrayPortalToIterators<InputPortalType>; using InputIteratorsType = vtkm::cont::ArrayPortalToIterators<InputPortalType>;
InputIteratorsType inputIterators(this->InputPortal); InputIteratorsType inputIterators(this->InputPortal);
typename InputIteratorsType::IteratorType resultPos = auto resultPos = std::upper_bound(inputIterators.GetBegin(),
std::upper_bound(inputIterators.GetBegin(),
inputIterators.GetEnd(), inputIterators.GetEnd(),
this->ValuesPortal.Get(index), this->ValuesPortal.Get(index),
this->CompareFunctor); this->CompareFunctor);

@ -61,11 +61,8 @@ public:
U initialValue, U initialValue,
BinaryFunctor binary_functor) BinaryFunctor binary_functor)
{ {
typedef typename vtkm::cont::ArrayHandle<T, CIn>::template ExecutionTypes<Device>::PortalConst
PortalIn;
internal::WrappedBinaryOperator<U, BinaryFunctor> wrappedOp(binary_functor); internal::WrappedBinaryOperator<U, BinaryFunctor> wrappedOp(binary_functor);
PortalIn inputPortal = input.PrepareForInput(Device()); auto inputPortal = input.PrepareForInput(Device());
return std::accumulate(vtkm::cont::ArrayPortalToIteratorBegin(inputPortal), return std::accumulate(vtkm::cont::ArrayPortalToIteratorBegin(inputPortal),
vtkm::cont::ArrayPortalToIteratorEnd(inputPortal), vtkm::cont::ArrayPortalToIteratorEnd(inputPortal),
initialValue, initialValue,
@ -85,22 +82,12 @@ public:
vtkm::cont::ArrayHandle<U, VOut>& values_output, vtkm::cont::ArrayHandle<U, VOut>& values_output,
BinaryFunctor binary_functor) BinaryFunctor binary_functor)
{ {
typedef typename vtkm::cont::ArrayHandle<T, KIn>::template ExecutionTypes<Device>::PortalConst auto keysPortalIn = keys.PrepareForInput(Device());
PortalKIn; auto valuesPortalIn = values.PrepareForInput(Device());
typedef typename vtkm::cont::ArrayHandle<U, VIn>::template ExecutionTypes<Device>::PortalConst
PortalVIn;
typedef
typename vtkm::cont::ArrayHandle<T, KOut>::template ExecutionTypes<Device>::Portal PortalKOut;
typedef
typename vtkm::cont::ArrayHandle<U, VOut>::template ExecutionTypes<Device>::Portal PortalVOut;
PortalKIn keysPortalIn = keys.PrepareForInput(Device());
PortalVIn valuesPortalIn = values.PrepareForInput(Device());
const vtkm::Id numberOfKeys = keys.GetNumberOfValues(); const vtkm::Id numberOfKeys = keys.GetNumberOfValues();
PortalKOut keysPortalOut = keys_output.PrepareForOutput(numberOfKeys, Device()); auto keysPortalOut = keys_output.PrepareForOutput(numberOfKeys, Device());
PortalVOut valuesPortalOut = values_output.PrepareForOutput(numberOfKeys, Device()); auto valuesPortalOut = values_output.PrepareForOutput(numberOfKeys, Device());
vtkm::Id writePos = 0; vtkm::Id writePos = 0;
vtkm::Id readPos = 0; vtkm::Id readPos = 0;
@ -141,15 +128,10 @@ public:
VTKM_CONT static T ScanInclusive(const vtkm::cont::ArrayHandle<T, CIn>& input, VTKM_CONT static T ScanInclusive(const vtkm::cont::ArrayHandle<T, CIn>& input,
vtkm::cont::ArrayHandle<T, COut>& output) vtkm::cont::ArrayHandle<T, COut>& output)
{ {
typedef
typename vtkm::cont::ArrayHandle<T, COut>::template ExecutionTypes<Device>::Portal PortalOut;
typedef typename vtkm::cont::ArrayHandle<T, CIn>::template ExecutionTypes<Device>::PortalConst
PortalIn;
vtkm::Id numberOfValues = input.GetNumberOfValues(); vtkm::Id numberOfValues = input.GetNumberOfValues();
PortalIn inputPortal = input.PrepareForInput(Device()); auto inputPortal = input.PrepareForInput(Device());
PortalOut outputPortal = output.PrepareForOutput(numberOfValues, Device()); auto outputPortal = output.PrepareForOutput(numberOfValues, Device());
if (numberOfValues <= 0) if (numberOfValues <= 0)
{ {
@ -169,17 +151,12 @@ public:
vtkm::cont::ArrayHandle<T, COut>& output, vtkm::cont::ArrayHandle<T, COut>& output,
BinaryFunctor binary_functor) BinaryFunctor binary_functor)
{ {
typedef
typename vtkm::cont::ArrayHandle<T, COut>::template ExecutionTypes<Device>::Portal PortalOut;
typedef typename vtkm::cont::ArrayHandle<T, CIn>::template ExecutionTypes<Device>::PortalConst
PortalIn;
internal::WrappedBinaryOperator<T, BinaryFunctor> wrappedBinaryOp(binary_functor); internal::WrappedBinaryOperator<T, BinaryFunctor> wrappedBinaryOp(binary_functor);
vtkm::Id numberOfValues = input.GetNumberOfValues(); vtkm::Id numberOfValues = input.GetNumberOfValues();
PortalIn inputPortal = input.PrepareForInput(Device()); auto inputPortal = input.PrepareForInput(Device());
PortalOut outputPortal = output.PrepareForOutput(numberOfValues, Device()); auto outputPortal = output.PrepareForOutput(numberOfValues, Device());
if (numberOfValues <= 0) if (numberOfValues <= 0)
{ {
@ -201,17 +178,12 @@ public:
BinaryFunctor binaryFunctor, BinaryFunctor binaryFunctor,
const T& initialValue) const T& initialValue)
{ {
typedef
typename vtkm::cont::ArrayHandle<T, COut>::template ExecutionTypes<Device>::Portal PortalOut;
typedef typename vtkm::cont::ArrayHandle<T, CIn>::template ExecutionTypes<Device>::PortalConst
PortalIn;
internal::WrappedBinaryOperator<T, BinaryFunctor> wrappedBinaryOp(binaryFunctor); internal::WrappedBinaryOperator<T, BinaryFunctor> wrappedBinaryOp(binaryFunctor);
vtkm::Id numberOfValues = input.GetNumberOfValues(); vtkm::Id numberOfValues = input.GetNumberOfValues();
PortalIn inputPortal = input.PrepareForInput(Device()); auto inputPortal = input.PrepareForInput(Device());
PortalOut outputPortal = output.PrepareForOutput(numberOfValues, Device()); auto outputPortal = output.PrepareForOutput(numberOfValues, Device());
if (numberOfValues <= 0) if (numberOfValues <= 0)
{ {
@ -277,21 +249,12 @@ private:
vtkm::cont::ArrayHandle<I, StorageI>& index, vtkm::cont::ArrayHandle<I, StorageI>& index,
vtkm::cont::ArrayHandle<Vout, StorageVout>& values_out) vtkm::cont::ArrayHandle<Vout, StorageVout>& values_out)
{ {
typedef typename vtkm::cont::ArrayHandle<Vin, StorageVin>::template ExecutionTypes<
Device>::PortalConst PortalVIn;
typedef
typename vtkm::cont::ArrayHandle<I, StorageI>::template ExecutionTypes<Device>::PortalConst
PortalI;
typedef
typename vtkm::cont::ArrayHandle<Vout, StorageVout>::template ExecutionTypes<Device>::Portal
PortalVout;
const vtkm::Id n = values.GetNumberOfValues(); const vtkm::Id n = values.GetNumberOfValues();
VTKM_ASSERT(n == index.GetNumberOfValues()); VTKM_ASSERT(n == index.GetNumberOfValues());
PortalVIn valuesPortal = values.PrepareForInput(Device()); auto valuesPortal = values.PrepareForInput(Device());
PortalI indexPortal = index.PrepareForInput(Device()); auto indexPortal = index.PrepareForInput(Device());
PortalVout valuesOutPortal = values_out.PrepareForOutput(n, Device()); auto valuesOutPortal = values_out.PrepareForOutput(n, Device());
for (vtkm::Id i = 0; i < n; i++) for (vtkm::Id i = 0; i < n; i++)
{ {
@ -310,12 +273,7 @@ private:
//we than need to specify a custom compare function wrapper //we than need to specify a custom compare function wrapper
//that only checks for key side of the pair, using the custom compare //that only checks for key side of the pair, using the custom compare
//functor that the user passed in //functor that the user passed in
using KeyType = vtkm::cont::ArrayHandle<T, StorageT>; auto zipHandle = vtkm::cont::make_ArrayHandleZip(keys, values);
;
using ValueType = vtkm::cont::ArrayHandle<U, StorageU>;
typedef vtkm::cont::ArrayHandleZip<KeyType, ValueType> ZipHandleType;
ZipHandleType zipHandle = vtkm::cont::make_ArrayHandleZip(keys, values);
Sort(zipHandle, internal::KeyCompare<T, U, BinaryCompare>(binary_compare)); Sort(zipHandle, internal::KeyCompare<T, U, BinaryCompare>(binary_compare));
} }
@ -337,11 +295,8 @@ public:
{ {
/// More efficient sort: /// More efficient sort:
/// Move value indexes when sorting and reorder the value array at last /// Move value indexes when sorting and reorder the value array at last
using ValueType = vtkm::cont::ArrayHandle<U, StorageU>; vtkm::cont::ArrayHandle<vtkm::Id> indexArray;
using IndexType = vtkm::cont::ArrayHandle<vtkm::Id>; vtkm::cont::ArrayHandle<U, StorageU> valuesScattered;
IndexType indexArray;
ValueType valuesScattered;
Copy(ArrayHandleIndex(keys.GetNumberOfValues()), indexArray); Copy(ArrayHandleIndex(keys.GetNumberOfValues()), indexArray);
SortByKeyDirect(keys, indexArray, wrappedCompare); SortByKeyDirect(keys, indexArray, wrappedCompare);
@ -364,11 +319,8 @@ public:
VTKM_CONT static void Sort(vtkm::cont::ArrayHandle<T, Storage>& values, VTKM_CONT static void Sort(vtkm::cont::ArrayHandle<T, Storage>& values,
BinaryCompare binary_compare) BinaryCompare binary_compare)
{ {
typedef typename vtkm::cont::ArrayHandle<T, Storage>::template ExecutionTypes<Device>::Portal auto arrayPortal = values.PrepareForInPlace(Device());
PortalType; vtkm::cont::ArrayPortalToIterators<decltype(arrayPortal)> iterators(arrayPortal);
PortalType arrayPortal = values.PrepareForInPlace(Device());
vtkm::cont::ArrayPortalToIterators<PortalType> iterators(arrayPortal);
internal::WrappedBinaryOperator<bool, BinaryCompare> wrappedCompare(binary_compare); internal::WrappedBinaryOperator<bool, BinaryCompare> wrappedCompare(binary_compare);
std::sort(iterators.GetBegin(), iterators.GetEnd(), wrappedCompare); std::sort(iterators.GetBegin(), iterators.GetEnd(), wrappedCompare);