Merge topic '173_tbb_reducebykey'
d174c0fe Add TBB specialization for ReduceByKey. c9c7149c Fix typo in ReduceByKey docstring. Acked-by: Kitware Robot <kwrobot@kitware.com> Acked-by: Kenneth Moreland <kmorel@sandia.gov> Merge-request: !928
This commit is contained in:
commit
6c7c4ddec0
@ -203,7 +203,7 @@ struct DeviceAdapterAlgorithm
|
||||
VTKM_CONT static void ReduceByKey(const vtkm::cont::ArrayHandle<T, CKeyIn>& keys,
|
||||
const vtkm::cont::ArrayHandle<U, CValIn>& values,
|
||||
vtkm::cont::ArrayHandle<T, CKeyOut>& keys_output,
|
||||
vtkm::cont::ArrayHandle<T, CValOut>& values_output,
|
||||
vtkm::cont::ArrayHandle<U, CValOut>& values_output,
|
||||
BinaryFunctor binary_functor);
|
||||
|
||||
/// \brief Compute an inclusive prefix sum operation on the input ArrayHandle.
|
||||
|
@ -60,6 +60,31 @@ public:
|
||||
input.PrepareForInput(vtkm::cont::DeviceAdapterTagTBB()), initialValue, binary_functor);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename U,
|
||||
class CKeyIn,
|
||||
class CValIn,
|
||||
class CKeyOut,
|
||||
class CValOut,
|
||||
class BinaryFunctor>
|
||||
VTKM_CONT static void ReduceByKey(const vtkm::cont::ArrayHandle<T, CKeyIn>& keys,
|
||||
const vtkm::cont::ArrayHandle<U, CValIn>& values,
|
||||
vtkm::cont::ArrayHandle<T, CKeyOut>& keys_output,
|
||||
vtkm::cont::ArrayHandle<U, CValOut>& values_output,
|
||||
BinaryFunctor binary_functor)
|
||||
{
|
||||
vtkm::Id inputSize = keys.GetNumberOfValues();
|
||||
VTKM_ASSERT(inputSize == values.GetNumberOfValues());
|
||||
vtkm::Id outputSize =
|
||||
tbb::ReduceByKeyPortals(keys.PrepareForInput(DeviceAdapterTagTBB()),
|
||||
values.PrepareForInput(DeviceAdapterTagTBB()),
|
||||
keys_output.PrepareForOutput(inputSize, DeviceAdapterTagTBB()),
|
||||
values_output.PrepareForOutput(inputSize, DeviceAdapterTagTBB()),
|
||||
binary_functor);
|
||||
keys_output.Shrink(outputSize);
|
||||
values_output.Shrink(outputSize);
|
||||
}
|
||||
|
||||
template <typename T, class CIn, class COut>
|
||||
VTKM_CONT static T ScanInclusive(const vtkm::cont::ArrayHandle<T, CIn>& input,
|
||||
vtkm::cont::ArrayHandle<T, COut>& output)
|
||||
|
@ -27,6 +27,9 @@
|
||||
#include <vtkm/cont/internal/FunctorsGeneral.h>
|
||||
#include <vtkm/exec/internal/ErrorMessageBuffer.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
|
||||
VTKM_THIRDPARTY_PRE_INCLUDE
|
||||
|
||||
#if defined(VTKM_MSVC)
|
||||
@ -195,6 +198,359 @@ VTKM_CONT static T ReducePortals(InputPortalType inputPortal,
|
||||
}
|
||||
}
|
||||
|
||||
// Define this to print out timing information from the reduction and join
|
||||
// operations in the tbb ReduceByKey algorithm:
|
||||
//#define _VTKM_DEBUG_TBB_RBK
|
||||
|
||||
template <typename KeysInPortalType,
|
||||
typename ValuesInPortalType,
|
||||
typename KeysOutPortalType,
|
||||
typename ValuesOutPortalType,
|
||||
class BinaryOperationType>
|
||||
struct ReduceByKeyBody
|
||||
{
|
||||
using KeyType = typename KeysInPortalType::ValueType;
|
||||
using ValueType = typename ValuesInPortalType::ValueType;
|
||||
|
||||
struct Range
|
||||
{
|
||||
vtkm::Id InputBegin;
|
||||
vtkm::Id InputEnd;
|
||||
vtkm::Id OutputBegin;
|
||||
vtkm::Id OutputEnd;
|
||||
|
||||
VTKM_EXEC_CONT
|
||||
Range()
|
||||
: InputBegin(-1)
|
||||
, InputEnd(-1)
|
||||
, OutputBegin(-1)
|
||||
, OutputEnd(-1)
|
||||
{
|
||||
}
|
||||
|
||||
VTKM_EXEC_CONT
|
||||
Range(vtkm::Id inputBegin, vtkm::Id inputEnd, vtkm::Id outputBegin, vtkm::Id outputEnd)
|
||||
: InputBegin(inputBegin)
|
||||
, InputEnd(inputEnd)
|
||||
, OutputBegin(outputBegin)
|
||||
, OutputEnd(outputEnd)
|
||||
{
|
||||
this->AssertSane();
|
||||
}
|
||||
|
||||
VTKM_EXEC_CONT
|
||||
void AssertSane() const
|
||||
{
|
||||
VTKM_ASSERT("Input begin precedes end" && this->InputBegin <= this->InputEnd);
|
||||
VTKM_ASSERT("Output begin precedes end" && this->OutputBegin <= this->OutputEnd);
|
||||
VTKM_ASSERT("Output not past input" && this->OutputBegin <= this->InputBegin &&
|
||||
this->OutputEnd <= this->InputEnd);
|
||||
VTKM_ASSERT("Output smaller than input" &&
|
||||
(this->OutputEnd - this->OutputBegin) <= (this->InputEnd - this->InputBegin));
|
||||
}
|
||||
|
||||
VTKM_EXEC_CONT
|
||||
bool IsNext(const Range& next) const { return this->InputEnd == next.InputBegin; }
|
||||
};
|
||||
|
||||
KeysInPortalType KeysInPortal;
|
||||
ValuesInPortalType ValuesInPortal;
|
||||
KeysOutPortalType KeysOutPortal;
|
||||
ValuesOutPortalType ValuesOutPortal;
|
||||
BinaryOperationType BinaryOperation;
|
||||
Range Ranges;
|
||||
#ifdef _VTKM_DEBUG_TBB_RBK
|
||||
double ReduceTime;
|
||||
double JoinTime;
|
||||
#endif
|
||||
|
||||
VTKM_CONT
|
||||
ReduceByKeyBody(const KeysInPortalType& keysInPortal,
|
||||
const ValuesInPortalType& valuesInPortal,
|
||||
const KeysOutPortalType& keysOutPortal,
|
||||
const ValuesOutPortalType& valuesOutPortal,
|
||||
BinaryOperationType binaryOperation)
|
||||
: KeysInPortal(keysInPortal)
|
||||
, ValuesInPortal(valuesInPortal)
|
||||
, KeysOutPortal(keysOutPortal)
|
||||
, ValuesOutPortal(valuesOutPortal)
|
||||
, BinaryOperation(binaryOperation)
|
||||
#ifdef _VTKM_DEBUG_TBB_RBK
|
||||
, ReduceTime(0)
|
||||
, JoinTime(0)
|
||||
#endif
|
||||
{
|
||||
}
|
||||
|
||||
VTKM_EXEC_CONT
|
||||
ReduceByKeyBody(const ReduceByKeyBody& body, ::tbb::split)
|
||||
: KeysInPortal(body.KeysInPortal)
|
||||
, ValuesInPortal(body.ValuesInPortal)
|
||||
, KeysOutPortal(body.KeysOutPortal)
|
||||
, ValuesOutPortal(body.ValuesOutPortal)
|
||||
, BinaryOperation(body.BinaryOperation)
|
||||
#ifdef _VTKM_DEBUG_TBB_RBK
|
||||
, ReduceTime(0)
|
||||
, JoinTime(0)
|
||||
#endif
|
||||
{
|
||||
}
|
||||
|
||||
VTKM_SUPPRESS_EXEC_WARNINGS
|
||||
VTKM_EXEC
|
||||
void operator()(const ::tbb::blocked_range<vtkm::Id>& range)
|
||||
{
|
||||
#ifdef _VTKM_DEBUG_TBB_RBK
|
||||
::tbb::tick_count startTime = ::tbb::tick_count::now();
|
||||
#endif // _VTKM_DEBUG_TBB_RBK
|
||||
if (range.empty())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
bool firstRun = this->Ranges.OutputBegin < 0; // First use of this body object
|
||||
if (firstRun)
|
||||
{
|
||||
this->Ranges.InputBegin = range.begin();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Must be a continuation of the previous input range:
|
||||
VTKM_ASSERT(this->Ranges.InputEnd == range.begin());
|
||||
}
|
||||
this->Ranges.InputEnd = range.end();
|
||||
this->Ranges.AssertSane();
|
||||
|
||||
using KeysInIteratorsType = vtkm::cont::ArrayPortalToIterators<KeysInPortalType>;
|
||||
using ValuesInIteratorsType = vtkm::cont::ArrayPortalToIterators<ValuesInPortalType>;
|
||||
using KeysOutIteratorsType = vtkm::cont::ArrayPortalToIterators<KeysOutPortalType>;
|
||||
using ValuesOutIteratorsType = vtkm::cont::ArrayPortalToIterators<ValuesOutPortalType>;
|
||||
|
||||
KeysInIteratorsType keysInIters(this->KeysInPortal);
|
||||
ValuesInIteratorsType valuesInIters(this->ValuesInPortal);
|
||||
KeysOutIteratorsType keysOutIters(this->KeysOutPortal);
|
||||
ValuesOutIteratorsType valuesOutIters(this->ValuesOutPortal);
|
||||
|
||||
using KeysInIteratorType = typename KeysInIteratorsType::IteratorType;
|
||||
using ValuesInIteratorType = typename ValuesInIteratorsType::IteratorType;
|
||||
using KeysOutIteratorType = typename KeysOutIteratorsType::IteratorType;
|
||||
using ValuesOutIteratorType = typename ValuesOutIteratorsType::IteratorType;
|
||||
|
||||
KeysInIteratorType keysIn = keysInIters.GetBegin();
|
||||
ValuesInIteratorType valuesIn = valuesInIters.GetBegin();
|
||||
KeysOutIteratorType keysOut = keysOutIters.GetBegin();
|
||||
ValuesOutIteratorType valuesOut = valuesOutIters.GetBegin();
|
||||
|
||||
vtkm::Id readPos = range.begin();
|
||||
const vtkm::Id readEnd = range.end();
|
||||
|
||||
// Determine output index. If we're reusing the body, pick up where the
|
||||
// last block left off. If not, use the input range.
|
||||
vtkm::Id writePos;
|
||||
if (firstRun)
|
||||
{
|
||||
this->Ranges.OutputBegin = range.begin();
|
||||
this->Ranges.OutputEnd = range.begin();
|
||||
writePos = range.begin();
|
||||
}
|
||||
else
|
||||
{
|
||||
writePos = this->Ranges.OutputEnd;
|
||||
}
|
||||
this->Ranges.AssertSane();
|
||||
|
||||
// We're either writing at the end of a previous block, or at the input
|
||||
// location. Either way, the write position will never be greater than
|
||||
// the read position.
|
||||
VTKM_ASSERT(writePos <= readPos);
|
||||
|
||||
// Initialize reduction variables:
|
||||
BinaryOperationType functor(this->BinaryOperation);
|
||||
KeyType currentKey = keysIn[readPos];
|
||||
ValueType currentValue = valuesIn[readPos];
|
||||
++readPos;
|
||||
|
||||
// If the start of the current range continues a previous key block,
|
||||
// initialize with the previous result and decrement the write index.
|
||||
VTKM_ASSERT(firstRun || writePos > 0);
|
||||
if (!firstRun && keysOut[writePos - 1] == currentKey)
|
||||
{
|
||||
// Ensure that we'll overwrite the continued key values:
|
||||
--writePos;
|
||||
|
||||
// Update our accumulator with the partial value:
|
||||
currentValue = functor(valuesOut[writePos], currentValue);
|
||||
}
|
||||
|
||||
// Special case: single value in range
|
||||
if (readPos >= readEnd)
|
||||
{
|
||||
keysOut[writePos] = currentKey;
|
||||
valuesOut[writePos] = currentValue;
|
||||
++writePos;
|
||||
|
||||
this->Ranges.OutputEnd = writePos;
|
||||
return;
|
||||
}
|
||||
|
||||
for (;;)
|
||||
{
|
||||
while (readPos < readEnd && currentKey == keysIn[readPos])
|
||||
{
|
||||
currentValue = functor(currentValue, valuesIn[readPos]);
|
||||
++readPos;
|
||||
}
|
||||
|
||||
VTKM_ASSERT(writePos <= readPos);
|
||||
keysOut[writePos] = currentKey;
|
||||
valuesOut[writePos] = currentValue;
|
||||
++writePos;
|
||||
|
||||
if (readPos < readEnd)
|
||||
{
|
||||
currentKey = keysIn[readPos];
|
||||
currentValue = valuesIn[readPos];
|
||||
++readPos;
|
||||
continue;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
this->Ranges.OutputEnd = writePos;
|
||||
|
||||
#ifdef _VTKM_DEBUG_TBB_RBK
|
||||
::tbb::tick_count endTime = ::tbb::tick_count::now();
|
||||
double time = (endTime - startTime).seconds();
|
||||
this->ReduceTime += time;
|
||||
std::ostringstream out;
|
||||
out << "Reduced " << range.size() << " key/value pairs in " << time << "s. "
|
||||
<< "InRange: " << this->Ranges.InputBegin << " " << this->Ranges.InputEnd << " "
|
||||
<< "OutRange: " << this->Ranges.OutputBegin << " " << this->Ranges.OutputEnd << "\n";
|
||||
std::cerr << out.str();
|
||||
#endif
|
||||
}
|
||||
|
||||
VTKM_SUPPRESS_EXEC_WARNINGS
|
||||
VTKM_EXEC_CONT
|
||||
void join(const ReduceByKeyBody& rhs)
|
||||
{
|
||||
using KeysIteratorsType = vtkm::cont::ArrayPortalToIterators<KeysOutPortalType>;
|
||||
using ValuesIteratorsType = vtkm::cont::ArrayPortalToIterators<ValuesOutPortalType>;
|
||||
using KeysIteratorType = typename KeysIteratorsType::IteratorType;
|
||||
using ValuesIteratorType = typename ValuesIteratorsType::IteratorType;
|
||||
|
||||
#ifdef _VTKM_DEBUG_TBB_RBK
|
||||
::tbb::tick_count startTime = ::tbb::tick_count::now();
|
||||
#endif
|
||||
|
||||
this->Ranges.AssertSane();
|
||||
rhs.Ranges.AssertSane();
|
||||
// Ensure that we're joining two consecutive subsets of the input:
|
||||
VTKM_ASSERT(this->Ranges.IsNext(rhs.Ranges));
|
||||
|
||||
KeysIteratorsType keysIters(this->KeysOutPortal);
|
||||
ValuesIteratorsType valuesIters(this->ValuesOutPortal);
|
||||
KeysIteratorType keys = keysIters.GetBegin();
|
||||
ValuesIteratorType values = valuesIters.GetBegin();
|
||||
|
||||
const vtkm::Id dstBegin = this->Ranges.OutputEnd;
|
||||
const vtkm::Id lastDstIdx = this->Ranges.OutputEnd - 1;
|
||||
|
||||
vtkm::Id srcBegin = rhs.Ranges.OutputBegin;
|
||||
const vtkm::Id srcEnd = rhs.Ranges.OutputEnd;
|
||||
|
||||
// Merge boundaries if needed:
|
||||
if (keys[srcBegin] == keys[lastDstIdx])
|
||||
{
|
||||
values[lastDstIdx] = this->BinaryOperation(values[lastDstIdx], values[srcBegin]);
|
||||
++srcBegin; // Don't copy the key/value we just reduced
|
||||
}
|
||||
|
||||
// move data:
|
||||
if (srcBegin != dstBegin && srcBegin != srcEnd)
|
||||
{
|
||||
// Sanity check:
|
||||
VTKM_ASSERT(srcBegin < srcEnd);
|
||||
// Not necessary for the copy call to be safe, but if the src range
|
||||
// overlaps with the dst range there's a problem with the algorithm:
|
||||
VTKM_ASSERT(dstBegin + (srcEnd - srcBegin) <= srcBegin);
|
||||
std::copy(keys + srcBegin, keys + srcEnd, keys + dstBegin);
|
||||
std::copy(values + srcBegin, values + srcEnd, values + dstBegin);
|
||||
}
|
||||
|
||||
this->Ranges.InputEnd = rhs.Ranges.InputEnd;
|
||||
this->Ranges.OutputEnd += srcEnd - srcBegin;
|
||||
this->Ranges.AssertSane();
|
||||
|
||||
#ifdef _VTKM_DEBUG_TBB_RBK
|
||||
::tbb::tick_count endTime = ::tbb::tick_count::now();
|
||||
double time = (endTime - startTime).seconds();
|
||||
this->JoinTime += rhs.JoinTime + time;
|
||||
std::ostringstream out;
|
||||
out << "Joined " << (srcEnd - srcBegin) << " rhs values into body in " << time << "s. "
|
||||
<< "InRange: " << this->Ranges.InputBegin << " " << this->Ranges.InputEnd << " "
|
||||
<< "OutRange: " << this->Ranges.OutputBegin << " " << this->Ranges.OutputEnd << "\n";
|
||||
std::cerr << out.str();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/// \brief Runs the TBB reduce-by-key over the given portals.
///
/// Builds a ReduceByKeyBody around the portals, wraps \c binaryOperation in
/// internal::WrappedBinaryOperator, and executes the body with
/// ::tbb::parallel_reduce over [0, number of input keys) using TBB_GRAIN_SIZE.
///
/// \return The number of key/value pairs written to the output portals
///         (0 for empty input, in which case nothing is executed).
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename KeysInPortalType,
          typename ValuesInPortalType,
          typename KeysOutPortalType,
          typename ValuesOutPortalType,
          typename BinaryOperationType>
VTKM_CONT vtkm::Id ReduceByKeyPortals(KeysInPortalType keysInPortal,
                                      ValuesInPortalType valuesInPortal,
                                      KeysOutPortalType keysOutPortal,
                                      ValuesOutPortalType valuesOutPortal,
                                      BinaryOperationType binaryOperation)
{
  const vtkm::Id inputLength = keysInPortal.GetNumberOfValues();
  VTKM_ASSERT(inputLength == valuesInPortal.GetNumberOfValues());

  // Nothing to reduce.
  if (inputLength == 0)
  {
    return 0;
  }

  using ValueType = typename ValuesInPortalType::ValueType;
  using WrappedBinaryOp = internal::WrappedBinaryOperator<ValueType, BinaryOperationType>;
  using BodyType = ReduceByKeyBody<KeysInPortalType,
                                   ValuesInPortalType,
                                   KeysOutPortalType,
                                   ValuesOutPortalType,
                                   WrappedBinaryOp>;

  BodyType body(
    keysInPortal, valuesInPortal, keysOutPortal, valuesOutPortal, WrappedBinaryOp(binaryOperation));
  ::tbb::blocked_range<vtkm::Id> wholeInput(0, inputLength, TBB_GRAIN_SIZE);

#ifdef _VTKM_DEBUG_TBB_RBK
  std::cerr << "\n\nTBB ReduceByKey:\n";
#endif

  ::tbb::parallel_reduce(wholeInput, body);

#ifdef _VTKM_DEBUG_TBB_RBK
  std::cerr << "Total reduce time: " << body.ReduceTime << "s\n";
  std::cerr << "Total join time: " << body.JoinTime << "s\n";
  std::cerr << "\nend\n";
#endif

  // After the reduction the root body must cover the whole input and its
  // output must start at index 0.
  body.Ranges.AssertSane();
  VTKM_ASSERT(body.Ranges.InputBegin == 0 && body.Ranges.InputEnd == inputLength &&
              body.Ranges.OutputBegin == 0 && body.Ranges.OutputEnd <= inputLength);

  return body.Ranges.OutputEnd;
}
|
||||
|
||||
#ifdef _VTKM_DEBUG_TBB_RBK
|
||||
#undef _VTKM_DEBUG_TBB_RBK
|
||||
#endif
|
||||
|
||||
template <class InputPortalType, class OutputPortalType, class BinaryOperationType>
|
||||
struct ScanInclusiveBody
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user