Reduce the amount of typedef statements in DeviceAdapters

By using the auto keyword and decltype we can reduce the number of
complex typedefs that exist when writing device adapter algorithms.
The goal being that it is easier for developers to see the actual
algorithms being implemented, by reducing the amount of template
'noise'.
This commit is contained in:
Robert Maynard 2017-08-10 16:13:55 -04:00
parent e095831199
commit 89f439999a
4 changed files with 195 additions and 382 deletions

@ -216,18 +216,16 @@ private:
OutputPortal output,
UnaryPredicate unary_predicate)
{
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType IteratorType;
using ValueType = typename StencilPortal::ValueType;
IteratorType outputBegin = IteratorBegin(output);
typedef typename StencilPortal::ValueType ValueType;
auto outputBegin = IteratorBegin(output);
vtkm::exec::cuda::internal::WrappedUnaryPredicate<ValueType, UnaryPredicate> up(
unary_predicate);
try
{
IteratorType newLast = ::thrust::copy_if(
auto newLast = ::thrust::copy_if(
thrust::cuda::par, valuesBegin, valuesEnd, IteratorBegin(stencil), outputBegin, up);
return static_cast<vtkm::Id>(::thrust::distance(outputBegin, newLast));
}
@ -273,7 +271,7 @@ private:
const ValuesPortal& values,
const OutputPortal& output)
{
typedef typename ValuesPortal::ValueType ValueType;
using ValueType = typename ValuesPortal::ValueType;
LowerBoundsPortal(input, values, output, ::thrust::less<ValueType>());
}
@ -281,7 +279,7 @@ private:
VTKM_CONT static void LowerBoundsPortal(const InputPortal& input,
const OutputPortal& values_output)
{
typedef typename InputPortal::ValueType ValueType;
using ValueType = typename InputPortal::ValueType;
LowerBoundsPortal(input, values_output, values_output, ::thrust::less<ValueType>());
}
@ -291,7 +289,7 @@ private:
const OutputPortal& output,
BinaryCompare binary_compare)
{
typedef typename InputPortal::ValueType ValueType;
using ValueType = typename InputPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryPredicate<ValueType, BinaryCompare> bop(
binary_compare);
@ -388,17 +386,14 @@ private:
const ValueOutputPortal& values_output,
BinaryFunctor binary_functor)
{
typedef typename detail::IteratorTraits<KeysOutputPortal>::IteratorType KeysIteratorType;
typedef typename detail::IteratorTraits<ValueOutputPortal>::IteratorType ValuesIteratorType;
auto keys_out_begin = IteratorBegin(keys_output);
auto values_out_begin = IteratorBegin(values_output);
KeysIteratorType keys_out_begin = IteratorBegin(keys_output);
ValuesIteratorType values_out_begin = IteratorBegin(values_output);
::thrust::pair<KeysIteratorType, ValuesIteratorType> result_iterators;
::thrust::pair<decltype(keys_out_begin), decltype(values_out_begin)> result_iterators;
::thrust::equal_to<typename KeysPortal::ValueType> binaryPredicate;
typedef typename ValuesPortal::ValueType ValueType;
using ValueType = typename ValuesPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, BinaryFunctor> bop(binary_functor);
try
@ -424,7 +419,7 @@ private:
VTKM_CONT static typename InputPortal::ValueType ScanExclusivePortal(const InputPortal& input,
const OutputPortal& output)
{
typedef typename OutputPortal::ValueType ValueType;
using ValueType = typename OutputPortal::ValueType;
return ScanExclusivePortal(input,
output,
@ -441,7 +436,7 @@ private:
{
// Use iterator to get value so that thrust device_ptr has chance to handle
// data on device.
typedef typename OutputPortal::ValueType ValueType;
using ValueType = typename OutputPortal::ValueType;
//we have size three so that we can store the origin end value, the
//new end value, and the sum of those two
@ -456,13 +451,12 @@ private:
vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, BinaryFunctor> bop(binaryOp);
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType IteratorType;
IteratorType end = ::thrust::exclusive_scan(thrust::cuda::par,
IteratorBegin(input),
IteratorEnd(input),
IteratorBegin(output),
initialValue,
bop);
auto end = ::thrust::exclusive_scan(thrust::cuda::par,
IteratorBegin(input),
IteratorEnd(input),
IteratorBegin(output),
initialValue,
bop);
//Store the new value for the end of the array. This is done because
//with items such as the transpose array it is unsafe to pass the
@ -483,7 +477,7 @@ private:
VTKM_CONT static typename InputPortal::ValueType ScanInclusivePortal(const InputPortal& input,
const OutputPortal& output)
{
typedef typename OutputPortal::ValueType ValueType;
using ValueType = typename OutputPortal::ValueType;
return ScanInclusivePortal(input, output, ::thrust::plus<ValueType>());
}
@ -492,14 +486,12 @@ private:
const OutputPortal& output,
BinaryFunctor binary_functor)
{
typedef typename OutputPortal::ValueType ValueType;
using ValueType = typename OutputPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, BinaryFunctor> bop(binary_functor);
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType IteratorType;
try
{
IteratorType end = ::thrust::inclusive_scan(
auto end = ::thrust::inclusive_scan(
thrust::cuda::par, IteratorBegin(input), IteratorEnd(input), IteratorBegin(output), bop);
return *(end - 1);
}
@ -518,7 +510,7 @@ private:
const OutputPortal& output)
{
using KeyType = typename KeysPortal::ValueType;
typedef typename OutputPortal::ValueType ValueType;
using ValueType = typename OutputPortal::ValueType;
ScanInclusiveByKeyPortal(
keys, values, output, ::thrust::equal_to<KeyType>(), ::thrust::plus<ValueType>());
}
@ -534,14 +526,13 @@ private:
BinaryPredicate binary_predicate,
AssociativeOperator binary_operator)
{
typedef typename KeysPortal::ValueType KeyType;
using KeyType = typename KeysPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryOperator<KeyType, BinaryPredicate> bpred(
binary_predicate);
typedef typename OutputPortal::ValueType ValueType;
using ValueType = typename OutputPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, AssociativeOperator> bop(
binary_operator);
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType IteratorType;
try
{
::thrust::inclusive_scan_by_key(thrust::cuda::par,
@ -564,7 +555,7 @@ private:
const OutputPortal& output)
{
using KeyType = typename KeysPortal::ValueType;
typedef typename OutputPortal::ValueType ValueType;
using ValueType = typename OutputPortal::ValueType;
ScanExclusiveByKeyPortal(keys,
values,
output,
@ -586,14 +577,12 @@ private:
BinaryPredicate binary_predicate,
AssociativeOperator binary_operator)
{
typedef typename KeysPortal::ValueType KeyType;
using KeyType = typename KeysPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryOperator<KeyType, BinaryPredicate> bpred(
binary_predicate);
typedef typename OutputPortal::ValueType ValueType;
using ValueType = typename OutputPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryOperator<ValueType, AssociativeOperator> bop(
binary_operator);
typedef typename detail::IteratorTraits<OutputPortal>::IteratorType IteratorType;
try
{
::thrust::exclusive_scan_by_key(thrust::cuda::par,
@ -614,14 +603,14 @@ private:
template <class ValuesPortal>
VTKM_CONT static void SortPortal(const ValuesPortal& values)
{
typedef typename ValuesPortal::ValueType ValueType;
using ValueType = typename ValuesPortal::ValueType;
SortPortal(values, ::thrust::less<ValueType>());
}
template <class ValuesPortal, class BinaryCompare>
VTKM_CONT static void SortPortal(const ValuesPortal& values, BinaryCompare binary_compare)
{
typedef typename ValuesPortal::ValueType ValueType;
using ValueType = typename ValuesPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryPredicate<ValueType, BinaryCompare> bop(
binary_compare);
try
@ -646,7 +635,7 @@ private:
const ValuesPortal& values,
BinaryCompare binary_compare)
{
typedef typename KeysPortal::ValueType ValueType;
using ValueType = typename KeysPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryPredicate<ValueType, BinaryCompare> bop(
binary_compare);
try
@ -663,11 +652,10 @@ private:
template <class ValuesPortal>
VTKM_CONT static vtkm::Id UniquePortal(const ValuesPortal values)
{
typedef typename detail::IteratorTraits<ValuesPortal>::IteratorType IteratorType;
try
{
IteratorType begin = IteratorBegin(values);
IteratorType newLast = ::thrust::unique(thrust::cuda::par, begin, IteratorEnd(values));
auto begin = IteratorBegin(values);
auto newLast = ::thrust::unique(thrust::cuda::par, begin, IteratorEnd(values));
return static_cast<vtkm::Id>(::thrust::distance(begin, newLast));
}
catch (...)
@ -680,15 +668,13 @@ private:
template <class ValuesPortal, class BinaryCompare>
VTKM_CONT static vtkm::Id UniquePortal(const ValuesPortal values, BinaryCompare binary_compare)
{
typedef typename detail::IteratorTraits<ValuesPortal>::IteratorType IteratorType;
typedef typename ValuesPortal::ValueType ValueType;
using ValueType = typename ValuesPortal::ValueType;
vtkm::exec::cuda::internal::WrappedBinaryPredicate<ValueType, BinaryCompare> bop(
binary_compare);
try
{
IteratorType begin = IteratorBegin(values);
IteratorType newLast = ::thrust::unique(thrust::cuda::par, begin, IteratorEnd(values), bop);
auto begin = IteratorBegin(values);
auto newLast = ::thrust::unique(thrust::cuda::par, begin, IteratorEnd(values), bop);
return static_cast<vtkm::Id>(::thrust::distance(begin, newLast));
}
catch (...)
@ -949,8 +935,8 @@ public:
//function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place
//use case breaks.
input.PrepareForInput(DeviceAdapterTag());
return ScanExclusivePortal(input.PrepareForInput(DeviceAdapterTag()),
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
return ScanExclusivePortal(inputPortal,
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
}
@ -971,8 +957,8 @@ public:
//function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place
//use case breaks.
input.PrepareForInput(DeviceAdapterTag());
return ScanExclusivePortal(input.PrepareForInput(DeviceAdapterTag()),
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
return ScanExclusivePortal(inputPortal,
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
binary_functor,
initialValue);
@ -993,8 +979,8 @@ public:
//function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place
//use case breaks.
input.PrepareForInput(DeviceAdapterTag());
return ScanInclusivePortal(input.PrepareForInput(DeviceAdapterTag()),
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
return ScanInclusivePortal(inputPortal,
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
}
@ -1014,10 +1000,9 @@ public:
//function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place
//use case breaks.
input.PrepareForInput(DeviceAdapterTag());
return ScanInclusivePortal(input.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
binary_functor);
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
return ScanInclusivePortal(
inputPortal, output.PrepareForOutput(numberOfValues, DeviceAdapterTag()), binary_functor);
}
template <typename T, typename U, typename KIn, typename VIn, typename VOut>
@ -1035,11 +1020,10 @@ public:
//function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place
//use case breaks.
keys.PrepareForInput(DeviceAdapterTag());
values.PrepareForInput(DeviceAdapterTag());
ScanInclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
values.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
auto keysPortal = keys.PrepareForInput(DeviceAdapterTag());
auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
ScanInclusiveByKeyPortal(
keysPortal, valuesPortal, output.PrepareForOutput(numberOfValues, DeviceAdapterTag()));
}
template <typename T,
@ -1063,10 +1047,10 @@ public:
//function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place
//use case breaks.
keys.PrepareForInput(DeviceAdapterTag());
values.PrepareForInput(DeviceAdapterTag());
ScanInclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
values.PrepareForInput(DeviceAdapterTag()),
auto keysPortal = keys.PrepareForInput(DeviceAdapterTag());
auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
ScanInclusiveByKeyPortal(keysPortal,
valuesPortal,
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
::thrust::equal_to<T>(),
binary_functor);
@ -1088,10 +1072,10 @@ public:
//function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place
//use case breaks.
keys.PrepareForInput(DeviceAdapterTag());
values.PrepareForInput(DeviceAdapterTag());
ScanExnclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
values.PrepareForInput(DeviceAdapterTag()),
auto keysPortal = keys.PrepareForInput(DeviceAdapterTag());
auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
ScanExnclusiveByKeyPortal(keysPortal,
valuesPortal,
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
vtkm::TypeTraits<T>::ZeroInitialization(),
vtkm::Add());
@ -1120,10 +1104,10 @@ public:
//function. The order of execution of parameters of a function is undefined
//so we need to make sure input is called before output, or else in-place
//use case breaks.
keys.PrepareForInput(DeviceAdapterTag());
values.PrepareForInput(DeviceAdapterTag());
ScanExclusiveByKeyPortal(keys.PrepareForInput(DeviceAdapterTag()),
values.PrepareForInput(DeviceAdapterTag()),
auto keysPortal = keys.PrepareForInput(DeviceAdapterTag());
auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
ScanExclusiveByKeyPortal(keysPortal,
valuesPortal,
output.PrepareForOutput(numberOfValues, DeviceAdapterTag()),
initialValue,
::thrust::equal_to<T>(),

@ -102,16 +102,14 @@ private:
template <typename T, class CIn>
VTKM_CONT static T GetExecutionValue(const vtkm::cont::ArrayHandle<T, CIn>& input, vtkm::Id index)
{
using InputArrayType = vtkm::cont::ArrayHandle<T, CIn>;
using OutputArrayType = vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic>;
OutputArrayType output;
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
auto outputPortal = output.PrepareForOutput(1, DeviceAdapterTag());
CopyKernel<typename InputArrayType::template ExecutionTypes<DeviceAdapterTag>::PortalConst,
typename OutputArrayType::template ExecutionTypes<DeviceAdapterTag>::Portal>
kernel(input.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(1, DeviceAdapterTag()),
index);
CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(
inputPortal, outputPortal, index);
DerivedAlgorithm::Schedule(kernel, 1);
@ -125,14 +123,11 @@ public:
VTKM_CONT static void Copy(const vtkm::cont::ArrayHandle<T, CIn>& input,
vtkm::cont::ArrayHandle<U, COut>& output)
{
using CopyKernelType = CopyKernel<
typename vtkm::cont::ArrayHandle<T,
CIn>::template ExecutionTypes<DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<U, COut>::template ExecutionTypes<DeviceAdapterTag>::Portal>;
const vtkm::Id inSize = input.GetNumberOfValues();
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
auto outputPortal = output.PrepareForOutput(inSize, DeviceAdapterTag());
CopyKernelType kernel(input.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(inSize, DeviceAdapterTag()));
CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(inputPortal, outputPortal);
DerivedAlgorithm::Schedule(kernel, inSize);
}
@ -150,35 +145,23 @@ public:
using IndexArrayType = vtkm::cont::ArrayHandle<vtkm::Id, vtkm::cont::StorageTagBasic>;
IndexArrayType indices;
using StencilPortalType =
typename vtkm::cont::ArrayHandle<U, CStencil>::template ExecutionTypes<
DeviceAdapterTag>::PortalConst;
StencilPortalType stencilPortal = stencil.PrepareForInput(DeviceAdapterTag());
auto stencilPortal = stencil.PrepareForInput(DeviceAdapterTag());
auto indexPortal = indices.PrepareForOutput(arrayLength, DeviceAdapterTag());
using IndexPortalType =
typename IndexArrayType::template ExecutionTypes<DeviceAdapterTag>::Portal;
IndexPortalType indexPortal = indices.PrepareForOutput(arrayLength, DeviceAdapterTag());
StencilToIndexFlagKernel<StencilPortalType, IndexPortalType, UnaryPredicate> indexKernel(
stencilPortal, indexPortal, unary_predicate);
StencilToIndexFlagKernel<decltype(stencilPortal), decltype(indexPortal), UnaryPredicate>
indexKernel(stencilPortal, indexPortal, unary_predicate);
DerivedAlgorithm::Schedule(indexKernel, arrayLength);
vtkm::Id outArrayLength = DerivedAlgorithm::ScanExclusive(indices, indices);
using InputPortalType =
typename vtkm::cont::ArrayHandle<T,
CIn>::template ExecutionTypes<DeviceAdapterTag>::PortalConst;
InputPortalType inputPortal = input.PrepareForInput(DeviceAdapterTag());
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
auto outputPortal = output.PrepareForOutput(outArrayLength, DeviceAdapterTag());
using OutputPortalType =
typename vtkm::cont::ArrayHandle<T, COut>::template ExecutionTypes<DeviceAdapterTag>::Portal;
OutputPortalType outputPortal = output.PrepareForOutput(outArrayLength, DeviceAdapterTag());
CopyIfKernel<InputPortalType,
StencilPortalType,
IndexPortalType,
OutputPortalType,
CopyIfKernel<decltype(inputPortal),
decltype(stencilPortal),
decltype(indexPortal),
decltype(outputPortal),
UnaryPredicate>
copyKernel(inputPortal, stencilPortal, indexPortal, outputPortal, unary_predicate);
DerivedAlgorithm::Schedule(copyKernel, arrayLength);
@ -202,11 +185,6 @@ public:
vtkm::cont::ArrayHandle<U, COut>& output,
vtkm::Id outputIndex = 0)
{
using CopyKernel = CopyKernel<
typename vtkm::cont::ArrayHandle<T,
CIn>::template ExecutionTypes<DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<U, COut>::template ExecutionTypes<DeviceAdapterTag>::Portal>;
const vtkm::Id inSize = input.GetNumberOfValues();
if (inputStartIndex < 0 || numberOfElementsToCopy < 0 || outputIndex < 0 ||
inputStartIndex >= inSize)
@ -238,10 +216,11 @@ public:
}
}
CopyKernel kernel(input.PrepareForInput(DeviceAdapterTag()),
output.PrepareForInPlace(DeviceAdapterTag()),
inputStartIndex,
outputIndex);
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
auto outputPortal = output.PrepareForInPlace(DeviceAdapterTag());
CopyKernel<decltype(inputPortal), decltype(outputPortal)> kernel(
inputPortal, outputPortal, inputStartIndex, outputIndex);
DerivedAlgorithm::Schedule(kernel, numberOfElementsToCopy);
return true;
}
@ -255,15 +234,12 @@ public:
{
vtkm::Id arraySize = values.GetNumberOfValues();
LowerBoundsKernel<typename vtkm::cont::ArrayHandle<T, CIn>::template ExecutionTypes<
DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<T, CVal>::template ExecutionTypes<
DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<vtkm::Id, COut>::template ExecutionTypes<
DeviceAdapterTag>::Portal>
kernel(input.PrepareForInput(DeviceAdapterTag()),
values.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(arraySize, DeviceAdapterTag()));
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
auto outputPortal = output.PrepareForOutput(arraySize, DeviceAdapterTag());
LowerBoundsKernel<decltype(inputPortal), decltype(valuesPortal), decltype(outputPortal)> kernel(
inputPortal, valuesPortal, outputPortal);
DerivedAlgorithm::Schedule(kernel, arraySize);
}
@ -276,18 +252,15 @@ public:
{
vtkm::Id arraySize = values.GetNumberOfValues();
LowerBoundsComparisonKernel<
typename vtkm::cont::ArrayHandle<T,
CIn>::template ExecutionTypes<DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<T, CVal>::template ExecutionTypes<
DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<vtkm::Id,
COut>::template ExecutionTypes<DeviceAdapterTag>::Portal,
BinaryCompare>
kernel(input.PrepareForInput(DeviceAdapterTag()),
values.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(arraySize, DeviceAdapterTag()),
binary_compare);
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
auto outputPortal = output.PrepareForOutput(arraySize, DeviceAdapterTag());
LowerBoundsComparisonKernel<decltype(inputPortal),
decltype(valuesPortal),
decltype(outputPortal),
BinaryCompare>
kernel(inputPortal, valuesPortal, outputPortal, binary_compare);
DerivedAlgorithm::Schedule(kernel, arraySize);
}
@ -322,23 +295,15 @@ public:
//
//Now that we have an implicit array that is 1/16 the length of full array
//we can use scan inclusive to compute the final sum
using InputPortalType =
typename vtkm::cont::ArrayHandle<T,
CIn>::template ExecutionTypes<DeviceAdapterTag>::PortalConst;
using ReduceKernelType = ReduceKernel<InputPortalType, U, BinaryFunctor>;
using ReduceHandleType = vtkm::cont::ArrayHandleImplicit<ReduceKernelType>;
using TempArrayType = vtkm::cont::ArrayHandle<U, vtkm::cont::StorageTagBasic>;
ReduceKernelType kernel(
input.PrepareForInput(DeviceAdapterTag()), initialValue, binary_functor);
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
ReduceKernel<decltype(inputPortal), U, BinaryFunctor> kernel(
inputPortal, initialValue, binary_functor);
vtkm::Id length = (input.GetNumberOfValues() / 16);
length += (input.GetNumberOfValues() % 16 == 0) ? 0 : 1;
ReduceHandleType reduced = vtkm::cont::make_ArrayHandleImplicit(kernel, length);
auto reduced = vtkm::cont::make_ArrayHandleImplicit(kernel, length);
TempArrayType inclusiveScanStorage;
vtkm::cont::ArrayHandle<U, vtkm::cont::StorageTagBasic> inclusiveScanStorage;
const U scanResult =
DerivedAlgorithm::ScanInclusive(reduced, inclusiveScanStorage, binary_functor);
return scanResult;
@ -372,9 +337,8 @@ public:
if (block == numBlocks - 1)
numberOfInstances = fullSize - blockSize * block;
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, CIn>> streamIn =
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, CIn>>(
input, block, blockSize, numberOfInstances);
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, CIn>> streamIn(
input, block, blockSize, numberOfInstances);
if (block == 0)
lastResult = DerivedAlgorithm::Reduce(streamIn, initialValue, binary_functor);
@ -415,17 +379,10 @@ public:
vtkm::cont::ArrayHandle<ReduceKeySeriesStates> keystate;
{
using InputPortalType = typename vtkm::cont::ArrayHandle<T, KIn>::template ExecutionTypes<
DeviceAdapterTag>::PortalConst;
using KeyStatePortalType = typename vtkm::cont::ArrayHandle<
ReduceKeySeriesStates>::template ExecutionTypes<DeviceAdapterTag>::Portal;
InputPortalType inputPortal = keys.PrepareForInput(DeviceAdapterTag());
KeyStatePortalType keyStatePortal =
keystate.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
ReduceStencilGeneration<InputPortalType, KeyStatePortalType> kernel(inputPortal,
keyStatePortal);
auto inputPortal = keys.PrepareForInput(DeviceAdapterTag());
auto keyStatePortal = keystate.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
inputPortal, keyStatePortal);
DerivedAlgorithm::Schedule(kernel, numberOfKeys);
}
@ -436,17 +393,11 @@ public:
// the value summed currently, the second being 0 or 1, with 1 being used
// when this is a value of a key we need to write ( END or START_AND_END)
{
using ValueInHandleType = vtkm::cont::ArrayHandle<U, VIn>;
using ValueOutHandleType = vtkm::cont::ArrayHandle<U, VOut>;
using StencilHandleType = vtkm::cont::ArrayHandle<ReduceKeySeriesStates>;
using ZipInHandleType = vtkm::cont::ArrayHandleZip<ValueInHandleType, StencilHandleType>;
using ZipOutHandleType = vtkm::cont::ArrayHandleZip<ValueOutHandleType, StencilHandleType>;
vtkm::cont::ArrayHandle<ReduceKeySeriesStates> stencil;
vtkm::cont::ArrayHandle<U, VOut> reducedValues;
StencilHandleType stencil;
ValueOutHandleType reducedValues;
ZipInHandleType scanInput(values, keystate);
ZipOutHandleType scanOutput(reducedValues, stencil);
auto scanInput = vtkm::cont::make_ArrayHandleZip(values, keystate);
auto scanOutput = vtkm::cont::make_ArrayHandleZip(reducedValues, stencil);
DerivedAlgorithm::ScanInclusive(
scanInput, scanOutput, ReduceByKeyAdd<BinaryFunctor>(binary_functor));
@ -474,28 +425,20 @@ public:
BinaryFunctor binaryFunctor,
const T& initialValue)
{
using TempArrayType = vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic>;
using OutputArrayType = vtkm::cont::ArrayHandle<T, COut>;
using SrcPortalType =
typename TempArrayType::template ExecutionTypes<DeviceAdapterTag>::PortalConst;
using DestPortalType =
typename OutputArrayType::template ExecutionTypes<DeviceAdapterTag>::Portal;
vtkm::Id numValues = input.GetNumberOfValues();
if (numValues <= 0)
{
return initialValue;
}
TempArrayType inclusiveScan;
vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic> inclusiveScan;
T result = DerivedAlgorithm::ScanInclusive(input, inclusiveScan, binaryFunctor);
InclusiveToExclusiveKernel<SrcPortalType, DestPortalType, BinaryFunctor> inclusiveToExclusive(
inclusiveScan.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(numValues, DeviceAdapterTag()),
binaryFunctor,
initialValue);
auto inputPortal = inclusiveScan.PrepareForInput(DeviceAdapterTag());
auto outputPortal = output.PrepareForOutput(numValues, DeviceAdapterTag());
InclusiveToExclusiveKernel<decltype(inputPortal), decltype(outputPortal), BinaryFunctor>
inclusiveToExclusive(inputPortal, outputPortal, binaryFunctor, initialValue);
DerivedAlgorithm::Schedule(inclusiveToExclusive, numValues);
@ -542,39 +485,22 @@ public:
vtkm::cont::ArrayHandle<ReduceKeySeriesStates> keystate;
{
using InputPortalType = typename vtkm::cont::ArrayHandle<T, KIn>::template ExecutionTypes<
DeviceAdapterTag>::PortalConst;
using KeyStatePortalType = typename vtkm::cont::ArrayHandle<
ReduceKeySeriesStates>::template ExecutionTypes<DeviceAdapterTag>::Portal;
InputPortalType inputPortal = keys.PrepareForInput(DeviceAdapterTag());
KeyStatePortalType keyStatePortal =
keystate.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
ReduceStencilGeneration<InputPortalType, KeyStatePortalType> kernel(inputPortal,
keyStatePortal);
auto inputPortal = keys.PrepareForInput(DeviceAdapterTag());
auto keyStatePortal = keystate.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
inputPortal, keyStatePortal);
DerivedAlgorithm::Schedule(kernel, numberOfKeys);
}
// 2. Shift input and initialize elements at head flags position to initValue
using TempArrayType = typename vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic>;
using TempPortalType =
typename vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic>::template ExecutionTypes<
DeviceAdapterTag>::Portal;
TempArrayType temp;
vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic> temp;
{
using InputPortalType = typename vtkm::cont::ArrayHandle<T, KIn>::template ExecutionTypes<
DeviceAdapterTag>::PortalConst;
auto inputPortal = values.PrepareForInput(DeviceAdapterTag());
auto keyStatePortal = keystate.PrepareForInput(DeviceAdapterTag());
auto tempPortal = temp.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
using KeyStatePortalType = typename vtkm::cont::ArrayHandle<
ReduceKeySeriesStates>::template ExecutionTypes<DeviceAdapterTag>::PortalConst;
InputPortalType inputPortal = values.PrepareForInput(DeviceAdapterTag());
KeyStatePortalType keyStatePortal = keystate.PrepareForInput(DeviceAdapterTag());
TempPortalType tempPortal = temp.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
ShiftCopyAndInit<U, InputPortalType, KeyStatePortalType, TempPortalType> kernel(
inputPortal, keyStatePortal, tempPortal, initialValue);
ShiftCopyAndInit<U, decltype(inputPortal), decltype(keyStatePortal), decltype(tempPortal)>
kernel(inputPortal, keyStatePortal, tempPortal, initialValue);
DerivedAlgorithm::Schedule(kernel, numberOfKeys);
}
// 3. Perform a ScanInclusiveByKey
@ -620,13 +546,11 @@ public:
if (block == numBlocks - 1)
numberOfInstances = fullSize - blockSize * block;
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, CIn>> streamIn =
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, CIn>>(
input, block, blockSize, numberOfInstances);
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, CIn>> streamIn(
input, block, blockSize, numberOfInstances);
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, COut>> streamOut =
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, COut>>(
output, block, blockSize, numberOfInstances);
vtkm::cont::ArrayHandleStreaming<vtkm::cont::ArrayHandle<T, COut>> streamOut(
output, block, blockSize, numberOfInstances);
if (block == 0)
{
@ -659,11 +583,6 @@ public:
vtkm::cont::ArrayHandle<T, COut>& output,
BinaryFunctor binary_functor)
{
using PortalType =
typename vtkm::cont::ArrayHandle<T, COut>::template ExecutionTypes<DeviceAdapterTag>::Portal;
using ScanKernelType = ScanKernel<PortalType, BinaryFunctor>;
DerivedAlgorithm::Copy(input, output);
vtkm::Id numValues = output.GetNumberOfValues();
@ -672,7 +591,9 @@ public:
return vtkm::TypeTraits<T>::ZeroInitialization();
}
PortalType portal = output.PrepareForInPlace(DeviceAdapterTag());
auto portal = output.PrepareForInPlace(DeviceAdapterTag());
using ScanKernelType = ScanKernel<decltype(portal), BinaryFunctor>;
vtkm::Id stride;
for (stride = 2; stride - 1 < numValues; stride *= 2)
@ -720,17 +641,10 @@ public:
vtkm::cont::ArrayHandle<ReduceKeySeriesStates> keystate;
{
using InputPortalType = typename vtkm::cont::ArrayHandle<T, KIn>::template ExecutionTypes<
DeviceAdapterTag>::PortalConst;
using KeyStatePortalType = typename vtkm::cont::ArrayHandle<
ReduceKeySeriesStates>::template ExecutionTypes<DeviceAdapterTag>::Portal;
InputPortalType inputPortal = keys.PrepareForInput(DeviceAdapterTag());
KeyStatePortalType keyStatePortal =
keystate.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
ReduceStencilGeneration<InputPortalType, KeyStatePortalType> kernel(inputPortal,
keyStatePortal);
auto inputPortal = keys.PrepareForInput(DeviceAdapterTag());
auto keyStatePortal = keystate.PrepareForOutput(numberOfKeys, DeviceAdapterTag());
ReduceStencilGeneration<decltype(inputPortal), decltype(keyStatePortal)> kernel(
inputPortal, keyStatePortal);
DerivedAlgorithm::Schedule(kernel, numberOfKeys);
}
@ -741,18 +655,10 @@ public:
// the value summed currently, the second being 0 or 1, with 1 being used
// when this is a value of a key we need to write ( END or START_AND_END)
{
using ValueInHandleType = vtkm::cont::ArrayHandle<U, VIn>;
using ValueOutHandleType = vtkm::cont::ArrayHandle<U, VOut>;
using StencilHandleType = vtkm::cont::ArrayHandle<ReduceKeySeriesStates>;
using ZipInHandleType = vtkm::cont::ArrayHandleZip<ValueInHandleType, StencilHandleType>;
using ZipOutHandleType = vtkm::cont::ArrayHandleZip<ValueOutHandleType, StencilHandleType>;
StencilHandleType stencil;
ValueOutHandleType reducedValues;
ZipInHandleType scanInput(values, keystate);
ZipOutHandleType scanOutput(reducedValues, stencil);
vtkm::cont::ArrayHandle<U, VOut> reducedValues;
vtkm::cont::ArrayHandle<ReduceKeySeriesStates> stencil;
auto scanInput = vtkm::cont::make_ArrayHandleZip(values, keystate);
auto scanOutput = vtkm::cont::make_ArrayHandleZip(reducedValues, stencil);
DerivedAlgorithm::ScanInclusive(
scanInput, scanOutput, ReduceByKeyAdd<BinaryFunctor>(binary_functor));
@ -768,17 +674,11 @@ public:
VTKM_CONT static void Sort(vtkm::cont::ArrayHandle<T, Storage>& values,
BinaryCompare binary_compare)
{
using ArrayType = typename vtkm::cont::ArrayHandle<T, Storage>;
using PortalType = typename ArrayType::template ExecutionTypes<DeviceAdapterTag>::Portal;
vtkm::Id numValues = values.GetNumberOfValues();
if (numValues < 2)
{
return;
}
PortalType portal = values.PrepareForInPlace(DeviceAdapterTag());
vtkm::Id numThreads = 1;
while (numThreads < numValues)
{
@ -786,8 +686,9 @@ public:
}
numThreads /= 2;
using MergeKernel = BitonicSortMergeKernel<PortalType, BinaryCompare>;
using CrossoverKernel = BitonicSortCrossoverKernel<PortalType, BinaryCompare>;
auto portal = values.PrepareForInPlace(DeviceAdapterTag());
using MergeKernel = BitonicSortMergeKernel<decltype(portal), BinaryCompare>;
using CrossoverKernel = BitonicSortCrossoverKernel<decltype(portal), BinaryCompare>;
for (vtkm::Id crossoverSize = 1; crossoverSize < numValues; crossoverSize *= 2)
{
@ -816,12 +717,7 @@ public:
//combine the keys and values into a ZipArrayHandle
//we than need to specify a custom compare function wrapper
//that only checks for key side of the pair, using a custom compare functor.
using KeyType = vtkm::cont::ArrayHandle<T, StorageT>;
;
using ValueType = vtkm::cont::ArrayHandle<U, StorageU>;
using ZipHandleType = vtkm::cont::ArrayHandleZip<KeyType, ValueType>;
ZipHandleType zipHandle = vtkm::cont::make_ArrayHandleZip(keys, values);
auto zipHandle = vtkm::cont::make_ArrayHandleZip(keys, values);
DerivedAlgorithm::Sort(zipHandle, internal::KeyCompare<T, U>());
}
@ -834,12 +730,7 @@ public:
//we than need to specify a custom compare function wrapper
//that only checks for key side of the pair, using the custom compare
//functor that the user passed in
using KeyType = vtkm::cont::ArrayHandle<T, StorageT>;
;
using ValueType = vtkm::cont::ArrayHandle<U, StorageU>;
using ZipHandleType = vtkm::cont::ArrayHandleZip<KeyType, ValueType>;
ZipHandleType zipHandle = vtkm::cont::make_ArrayHandleZip(keys, values);
auto zipHandle = vtkm::cont::make_ArrayHandleZip(keys, values);
DerivedAlgorithm::Sort(zipHandle, internal::KeyCompare<T, U, BinaryCompare>(binary_compare));
}
@ -861,15 +752,10 @@ public:
using WrappedBOpType = internal::WrappedBinaryOperator<bool, BinaryCompare>;
WrappedBOpType wrappedCompare(binary_compare);
ClassifyUniqueComparisonKernel<
typename vtkm::cont::ArrayHandle<T, Storage>::template ExecutionTypes<
DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<vtkm::Id, vtkm::cont::StorageTagBasic>::
template ExecutionTypes<DeviceAdapterTag>::Portal,
WrappedBOpType>
classifyKernel(values.PrepareForInput(DeviceAdapterTag()),
stencilArray.PrepareForOutput(inputSize, DeviceAdapterTag()),
wrappedCompare);
auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
auto stencilPortal = stencilArray.PrepareForOutput(inputSize, DeviceAdapterTag());
ClassifyUniqueComparisonKernel<decltype(valuesPortal), decltype(stencilPortal), WrappedBOpType>
classifyKernel(valuesPortal, stencilPortal, wrappedCompare);
DerivedAlgorithm::Schedule(classifyKernel, inputSize);
@ -890,16 +776,12 @@ public:
{
vtkm::Id arraySize = values.GetNumberOfValues();
UpperBoundsKernel<typename vtkm::cont::ArrayHandle<T, CIn>::template ExecutionTypes<
DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<T, CVal>::template ExecutionTypes<
DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<vtkm::Id, COut>::template ExecutionTypes<
DeviceAdapterTag>::Portal>
kernel(input.PrepareForInput(DeviceAdapterTag()),
values.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(arraySize, DeviceAdapterTag()));
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
auto outputPortal = output.PrepareForOutput(arraySize, DeviceAdapterTag());
UpperBoundsKernel<decltype(inputPortal), decltype(valuesPortal), decltype(outputPortal)> kernel(
inputPortal, valuesPortal, outputPortal);
DerivedAlgorithm::Schedule(kernel, arraySize);
}
@ -911,18 +793,15 @@ public:
{
vtkm::Id arraySize = values.GetNumberOfValues();
UpperBoundsKernelComparisonKernel<
typename vtkm::cont::ArrayHandle<T,
CIn>::template ExecutionTypes<DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<T, CVal>::template ExecutionTypes<
DeviceAdapterTag>::PortalConst,
typename vtkm::cont::ArrayHandle<vtkm::Id,
COut>::template ExecutionTypes<DeviceAdapterTag>::Portal,
BinaryCompare>
kernel(input.PrepareForInput(DeviceAdapterTag()),
values.PrepareForInput(DeviceAdapterTag()),
output.PrepareForOutput(arraySize, DeviceAdapterTag()),
binary_compare);
auto inputPortal = input.PrepareForInput(DeviceAdapterTag());
auto valuesPortal = values.PrepareForInput(DeviceAdapterTag());
auto outputPortal = output.PrepareForOutput(arraySize, DeviceAdapterTag());
UpperBoundsKernelComparisonKernel<decltype(inputPortal),
decltype(valuesPortal),
decltype(outputPortal),
BinaryCompare>
kernel(inputPortal, valuesPortal, outputPortal, binary_compare);
DerivedAlgorithm::Schedule(kernel, arraySize);
}

@ -396,7 +396,7 @@ struct LowerBoundsKernel
using InputIteratorsType = vtkm::cont::ArrayPortalToIterators<InputPortalType>;
InputIteratorsType inputIterators(this->InputPortal);
typename InputIteratorsType::IteratorType resultPos = std::lower_bound(
auto resultPos = std::lower_bound(
inputIterators.GetBegin(), inputIterators.GetEnd(), this->ValuesPortal.Get(index));
vtkm::Id resultIndex =
@ -444,11 +444,10 @@ struct LowerBoundsComparisonKernel
using InputIteratorsType = vtkm::cont::ArrayPortalToIterators<InputPortalType>;
InputIteratorsType inputIterators(this->InputPortal);
typename InputIteratorsType::IteratorType resultPos =
std::lower_bound(inputIterators.GetBegin(),
inputIterators.GetEnd(),
this->ValuesPortal.Get(index),
this->CompareFunctor);
auto resultPos = std::lower_bound(inputIterators.GetBegin(),
inputIterators.GetEnd(),
this->ValuesPortal.Get(index),
this->CompareFunctor);
vtkm::Id resultIndex =
static_cast<vtkm::Id>(std::distance(inputIterators.GetBegin(), resultPos));
@ -748,7 +747,7 @@ struct UpperBoundsKernel
using InputIteratorsType = vtkm::cont::ArrayPortalToIterators<InputPortalType>;
InputIteratorsType inputIterators(this->InputPortal);
typename InputIteratorsType::IteratorType resultPos = std::upper_bound(
auto resultPos = std::upper_bound(
inputIterators.GetBegin(), inputIterators.GetEnd(), this->ValuesPortal.Get(index));
vtkm::Id resultIndex =
@ -796,11 +795,10 @@ struct UpperBoundsKernelComparisonKernel
using InputIteratorsType = vtkm::cont::ArrayPortalToIterators<InputPortalType>;
InputIteratorsType inputIterators(this->InputPortal);
typename InputIteratorsType::IteratorType resultPos =
std::upper_bound(inputIterators.GetBegin(),
inputIterators.GetEnd(),
this->ValuesPortal.Get(index),
this->CompareFunctor);
auto resultPos = std::upper_bound(inputIterators.GetBegin(),
inputIterators.GetEnd(),
this->ValuesPortal.Get(index),
this->CompareFunctor);
vtkm::Id resultIndex =
static_cast<vtkm::Id>(std::distance(inputIterators.GetBegin(), resultPos));

@ -61,11 +61,8 @@ public:
U initialValue,
BinaryFunctor binary_functor)
{
typedef typename vtkm::cont::ArrayHandle<T, CIn>::template ExecutionTypes<Device>::PortalConst
PortalIn;
internal::WrappedBinaryOperator<U, BinaryFunctor> wrappedOp(binary_functor);
PortalIn inputPortal = input.PrepareForInput(Device());
auto inputPortal = input.PrepareForInput(Device());
return std::accumulate(vtkm::cont::ArrayPortalToIteratorBegin(inputPortal),
vtkm::cont::ArrayPortalToIteratorEnd(inputPortal),
initialValue,
@ -85,22 +82,12 @@ public:
vtkm::cont::ArrayHandle<U, VOut>& values_output,
BinaryFunctor binary_functor)
{
typedef typename vtkm::cont::ArrayHandle<T, KIn>::template ExecutionTypes<Device>::PortalConst
PortalKIn;
typedef typename vtkm::cont::ArrayHandle<U, VIn>::template ExecutionTypes<Device>::PortalConst
PortalVIn;
typedef
typename vtkm::cont::ArrayHandle<T, KOut>::template ExecutionTypes<Device>::Portal PortalKOut;
typedef
typename vtkm::cont::ArrayHandle<U, VOut>::template ExecutionTypes<Device>::Portal PortalVOut;
PortalKIn keysPortalIn = keys.PrepareForInput(Device());
PortalVIn valuesPortalIn = values.PrepareForInput(Device());
auto keysPortalIn = keys.PrepareForInput(Device());
auto valuesPortalIn = values.PrepareForInput(Device());
const vtkm::Id numberOfKeys = keys.GetNumberOfValues();
PortalKOut keysPortalOut = keys_output.PrepareForOutput(numberOfKeys, Device());
PortalVOut valuesPortalOut = values_output.PrepareForOutput(numberOfKeys, Device());
auto keysPortalOut = keys_output.PrepareForOutput(numberOfKeys, Device());
auto valuesPortalOut = values_output.PrepareForOutput(numberOfKeys, Device());
vtkm::Id writePos = 0;
vtkm::Id readPos = 0;
@ -141,15 +128,10 @@ public:
VTKM_CONT static T ScanInclusive(const vtkm::cont::ArrayHandle<T, CIn>& input,
vtkm::cont::ArrayHandle<T, COut>& output)
{
typedef
typename vtkm::cont::ArrayHandle<T, COut>::template ExecutionTypes<Device>::Portal PortalOut;
typedef typename vtkm::cont::ArrayHandle<T, CIn>::template ExecutionTypes<Device>::PortalConst
PortalIn;
vtkm::Id numberOfValues = input.GetNumberOfValues();
PortalIn inputPortal = input.PrepareForInput(Device());
PortalOut outputPortal = output.PrepareForOutput(numberOfValues, Device());
auto inputPortal = input.PrepareForInput(Device());
auto outputPortal = output.PrepareForOutput(numberOfValues, Device());
if (numberOfValues <= 0)
{
@ -169,17 +151,12 @@ public:
vtkm::cont::ArrayHandle<T, COut>& output,
BinaryFunctor binary_functor)
{
typedef
typename vtkm::cont::ArrayHandle<T, COut>::template ExecutionTypes<Device>::Portal PortalOut;
typedef typename vtkm::cont::ArrayHandle<T, CIn>::template ExecutionTypes<Device>::PortalConst
PortalIn;
internal::WrappedBinaryOperator<T, BinaryFunctor> wrappedBinaryOp(binary_functor);
vtkm::Id numberOfValues = input.GetNumberOfValues();
PortalIn inputPortal = input.PrepareForInput(Device());
PortalOut outputPortal = output.PrepareForOutput(numberOfValues, Device());
auto inputPortal = input.PrepareForInput(Device());
auto outputPortal = output.PrepareForOutput(numberOfValues, Device());
if (numberOfValues <= 0)
{
@ -201,17 +178,12 @@ public:
BinaryFunctor binaryFunctor,
const T& initialValue)
{
typedef
typename vtkm::cont::ArrayHandle<T, COut>::template ExecutionTypes<Device>::Portal PortalOut;
typedef typename vtkm::cont::ArrayHandle<T, CIn>::template ExecutionTypes<Device>::PortalConst
PortalIn;
internal::WrappedBinaryOperator<T, BinaryFunctor> wrappedBinaryOp(binaryFunctor);
vtkm::Id numberOfValues = input.GetNumberOfValues();
PortalIn inputPortal = input.PrepareForInput(Device());
PortalOut outputPortal = output.PrepareForOutput(numberOfValues, Device());
auto inputPortal = input.PrepareForInput(Device());
auto outputPortal = output.PrepareForOutput(numberOfValues, Device());
if (numberOfValues <= 0)
{
@ -277,21 +249,12 @@ private:
vtkm::cont::ArrayHandle<I, StorageI>& index,
vtkm::cont::ArrayHandle<Vout, StorageVout>& values_out)
{
typedef typename vtkm::cont::ArrayHandle<Vin, StorageVin>::template ExecutionTypes<
Device>::PortalConst PortalVIn;
typedef
typename vtkm::cont::ArrayHandle<I, StorageI>::template ExecutionTypes<Device>::PortalConst
PortalI;
typedef
typename vtkm::cont::ArrayHandle<Vout, StorageVout>::template ExecutionTypes<Device>::Portal
PortalVout;
const vtkm::Id n = values.GetNumberOfValues();
VTKM_ASSERT(n == index.GetNumberOfValues());
PortalVIn valuesPortal = values.PrepareForInput(Device());
PortalI indexPortal = index.PrepareForInput(Device());
PortalVout valuesOutPortal = values_out.PrepareForOutput(n, Device());
auto valuesPortal = values.PrepareForInput(Device());
auto indexPortal = index.PrepareForInput(Device());
auto valuesOutPortal = values_out.PrepareForOutput(n, Device());
for (vtkm::Id i = 0; i < n; i++)
{
@ -310,12 +273,7 @@ private:
//we than need to specify a custom compare function wrapper
//that only checks for key side of the pair, using the custom compare
//functor that the user passed in
using KeyType = vtkm::cont::ArrayHandle<T, StorageT>;
;
using ValueType = vtkm::cont::ArrayHandle<U, StorageU>;
typedef vtkm::cont::ArrayHandleZip<KeyType, ValueType> ZipHandleType;
ZipHandleType zipHandle = vtkm::cont::make_ArrayHandleZip(keys, values);
auto zipHandle = vtkm::cont::make_ArrayHandleZip(keys, values);
Sort(zipHandle, internal::KeyCompare<T, U, BinaryCompare>(binary_compare));
}
@ -337,11 +295,8 @@ public:
{
/// More efficient sort:
/// Move value indexes when sorting and reorder the value array at last
using ValueType = vtkm::cont::ArrayHandle<U, StorageU>;
using IndexType = vtkm::cont::ArrayHandle<vtkm::Id>;
IndexType indexArray;
ValueType valuesScattered;
vtkm::cont::ArrayHandle<vtkm::Id> indexArray;
vtkm::cont::ArrayHandle<U, StorageU> valuesScattered;
Copy(ArrayHandleIndex(keys.GetNumberOfValues()), indexArray);
SortByKeyDirect(keys, indexArray, wrappedCompare);
@ -364,11 +319,8 @@ public:
VTKM_CONT static void Sort(vtkm::cont::ArrayHandle<T, Storage>& values,
BinaryCompare binary_compare)
{
typedef typename vtkm::cont::ArrayHandle<T, Storage>::template ExecutionTypes<Device>::Portal
PortalType;
PortalType arrayPortal = values.PrepareForInPlace(Device());
vtkm::cont::ArrayPortalToIterators<PortalType> iterators(arrayPortal);
auto arrayPortal = values.PrepareForInPlace(Device());
vtkm::cont::ArrayPortalToIterators<decltype(arrayPortal)> iterators(arrayPortal);
internal::WrappedBinaryOperator<bool, BinaryCompare> wrappedCompare(binary_compare);
std::sort(iterators.GetBegin(), iterators.GetEnd(), wrappedCompare);