Merge topic '173_tbb_reducebykey'

d174c0fe Add TBB specialization for ReduceByKey.
c9c7149c Fix typo in ReduceByKey docstring.

Acked-by: Kitware Robot <kwrobot@kitware.com>
Acked-by: Kenneth Moreland <kmorel@sandia.gov>
Merge-request: !928
This commit is contained in:
Allison Vacanti 2017-09-18 17:42:51 +00:00 committed by Kitware Robot
commit 6c7c4ddec0
3 changed files with 382 additions and 1 deletions

@ -203,7 +203,7 @@ struct DeviceAdapterAlgorithm
VTKM_CONT static void ReduceByKey(const vtkm::cont::ArrayHandle<T, CKeyIn>& keys,
const vtkm::cont::ArrayHandle<U, CValIn>& values,
vtkm::cont::ArrayHandle<T, CKeyOut>& keys_output,
vtkm::cont::ArrayHandle<T, CValOut>& values_output,
vtkm::cont::ArrayHandle<U, CValOut>& values_output,
BinaryFunctor binary_functor);
/// \brief Compute an inclusive prefix sum operation on the input ArrayHandle.

@ -60,6 +60,31 @@ public:
input.PrepareForInput(vtkm::cont::DeviceAdapterTagTBB()), initialValue, binary_functor);
}
template <typename T,
typename U,
class CKeyIn,
class CValIn,
class CKeyOut,
class CValOut,
class BinaryFunctor>
VTKM_CONT static void ReduceByKey(const vtkm::cont::ArrayHandle<T, CKeyIn>& keys,
const vtkm::cont::ArrayHandle<U, CValIn>& values,
vtkm::cont::ArrayHandle<T, CKeyOut>& keys_output,
vtkm::cont::ArrayHandle<U, CValOut>& values_output,
BinaryFunctor binary_functor)
{
vtkm::Id inputSize = keys.GetNumberOfValues();
VTKM_ASSERT(inputSize == values.GetNumberOfValues());
vtkm::Id outputSize =
tbb::ReduceByKeyPortals(keys.PrepareForInput(DeviceAdapterTagTBB()),
values.PrepareForInput(DeviceAdapterTagTBB()),
keys_output.PrepareForOutput(inputSize, DeviceAdapterTagTBB()),
values_output.PrepareForOutput(inputSize, DeviceAdapterTagTBB()),
binary_functor);
keys_output.Shrink(outputSize);
values_output.Shrink(outputSize);
}
template <typename T, class CIn, class COut>
VTKM_CONT static T ScanInclusive(const vtkm::cont::ArrayHandle<T, CIn>& input,
vtkm::cont::ArrayHandle<T, COut>& output)

@ -27,6 +27,9 @@
#include <vtkm/cont/internal/FunctorsGeneral.h>
#include <vtkm/exec/internal/ErrorMessageBuffer.h>
#include <algorithm>
#include <sstream>
VTKM_THIRDPARTY_PRE_INCLUDE
#if defined(VTKM_MSVC)
@ -195,6 +198,359 @@ VTKM_CONT static T ReducePortals(InputPortalType inputPortal,
}
}
// Define this to print out timing information from the reduction and join
// operations in the tbb ReduceByKey algorithm:
//#define _VTKM_DEBUG_TBB_RBK
template <typename KeysInPortalType,
typename ValuesInPortalType,
typename KeysOutPortalType,
typename ValuesOutPortalType,
class BinaryOperationType>
struct ReduceByKeyBody
{
using KeyType = typename KeysInPortalType::ValueType;
using ValueType = typename ValuesInPortalType::ValueType;
struct Range
{
vtkm::Id InputBegin;
vtkm::Id InputEnd;
vtkm::Id OutputBegin;
vtkm::Id OutputEnd;
VTKM_EXEC_CONT
Range()
: InputBegin(-1)
, InputEnd(-1)
, OutputBegin(-1)
, OutputEnd(-1)
{
}
VTKM_EXEC_CONT
Range(vtkm::Id inputBegin, vtkm::Id inputEnd, vtkm::Id outputBegin, vtkm::Id outputEnd)
: InputBegin(inputBegin)
, InputEnd(inputEnd)
, OutputBegin(outputBegin)
, OutputEnd(outputEnd)
{
this->AssertSane();
}
VTKM_EXEC_CONT
void AssertSane() const
{
VTKM_ASSERT("Input begin precedes end" && this->InputBegin <= this->InputEnd);
VTKM_ASSERT("Output begin precedes end" && this->OutputBegin <= this->OutputEnd);
VTKM_ASSERT("Output not past input" && this->OutputBegin <= this->InputBegin &&
this->OutputEnd <= this->InputEnd);
VTKM_ASSERT("Output smaller than input" &&
(this->OutputEnd - this->OutputBegin) <= (this->InputEnd - this->InputBegin));
}
VTKM_EXEC_CONT
bool IsNext(const Range& next) const { return this->InputEnd == next.InputBegin; }
};
KeysInPortalType KeysInPortal;
ValuesInPortalType ValuesInPortal;
KeysOutPortalType KeysOutPortal;
ValuesOutPortalType ValuesOutPortal;
BinaryOperationType BinaryOperation;
Range Ranges;
#ifdef _VTKM_DEBUG_TBB_RBK
double ReduceTime;
double JoinTime;
#endif
VTKM_CONT
ReduceByKeyBody(const KeysInPortalType& keysInPortal,
const ValuesInPortalType& valuesInPortal,
const KeysOutPortalType& keysOutPortal,
const ValuesOutPortalType& valuesOutPortal,
BinaryOperationType binaryOperation)
: KeysInPortal(keysInPortal)
, ValuesInPortal(valuesInPortal)
, KeysOutPortal(keysOutPortal)
, ValuesOutPortal(valuesOutPortal)
, BinaryOperation(binaryOperation)
#ifdef _VTKM_DEBUG_TBB_RBK
, ReduceTime(0)
, JoinTime(0)
#endif
{
}
VTKM_EXEC_CONT
ReduceByKeyBody(const ReduceByKeyBody& body, ::tbb::split)
: KeysInPortal(body.KeysInPortal)
, ValuesInPortal(body.ValuesInPortal)
, KeysOutPortal(body.KeysOutPortal)
, ValuesOutPortal(body.ValuesOutPortal)
, BinaryOperation(body.BinaryOperation)
#ifdef _VTKM_DEBUG_TBB_RBK
, ReduceTime(0)
, JoinTime(0)
#endif
{
}
VTKM_SUPPRESS_EXEC_WARNINGS
VTKM_EXEC
void operator()(const ::tbb::blocked_range<vtkm::Id>& range)
{
#ifdef _VTKM_DEBUG_TBB_RBK
::tbb::tick_count startTime = ::tbb::tick_count::now();
#endif // _VTKM_DEBUG_TBB_RBK
if (range.empty())
{
return;
}
bool firstRun = this->Ranges.OutputBegin < 0; // First use of this body object
if (firstRun)
{
this->Ranges.InputBegin = range.begin();
}
else
{
// Must be a continuation of the previous input range:
VTKM_ASSERT(this->Ranges.InputEnd == range.begin());
}
this->Ranges.InputEnd = range.end();
this->Ranges.AssertSane();
using KeysInIteratorsType = vtkm::cont::ArrayPortalToIterators<KeysInPortalType>;
using ValuesInIteratorsType = vtkm::cont::ArrayPortalToIterators<ValuesInPortalType>;
using KeysOutIteratorsType = vtkm::cont::ArrayPortalToIterators<KeysOutPortalType>;
using ValuesOutIteratorsType = vtkm::cont::ArrayPortalToIterators<ValuesOutPortalType>;
KeysInIteratorsType keysInIters(this->KeysInPortal);
ValuesInIteratorsType valuesInIters(this->ValuesInPortal);
KeysOutIteratorsType keysOutIters(this->KeysOutPortal);
ValuesOutIteratorsType valuesOutIters(this->ValuesOutPortal);
using KeysInIteratorType = typename KeysInIteratorsType::IteratorType;
using ValuesInIteratorType = typename ValuesInIteratorsType::IteratorType;
using KeysOutIteratorType = typename KeysOutIteratorsType::IteratorType;
using ValuesOutIteratorType = typename ValuesOutIteratorsType::IteratorType;
KeysInIteratorType keysIn = keysInIters.GetBegin();
ValuesInIteratorType valuesIn = valuesInIters.GetBegin();
KeysOutIteratorType keysOut = keysOutIters.GetBegin();
ValuesOutIteratorType valuesOut = valuesOutIters.GetBegin();
vtkm::Id readPos = range.begin();
const vtkm::Id readEnd = range.end();
// Determine output index. If we're reusing the body, pick up where the
// last block left off. If not, use the input range.
vtkm::Id writePos;
if (firstRun)
{
this->Ranges.OutputBegin = range.begin();
this->Ranges.OutputEnd = range.begin();
writePos = range.begin();
}
else
{
writePos = this->Ranges.OutputEnd;
}
this->Ranges.AssertSane();
// We're either writing at the end of a previous block, or at the input
// location. Either way, the write position will never be greater than
// the read position.
VTKM_ASSERT(writePos <= readPos);
// Initialize reduction variables:
BinaryOperationType functor(this->BinaryOperation);
KeyType currentKey = keysIn[readPos];
ValueType currentValue = valuesIn[readPos];
++readPos;
// If the start of the current range continues a previous key block,
// initialize with the previous result and decrement the write index.
VTKM_ASSERT(firstRun || writePos > 0);
if (!firstRun && keysOut[writePos - 1] == currentKey)
{
// Ensure that we'll overwrite the continued key values:
--writePos;
// Update our accumulator with the partial value:
currentValue = functor(valuesOut[writePos], currentValue);
}
// Special case: single value in range
if (readPos >= readEnd)
{
keysOut[writePos] = currentKey;
valuesOut[writePos] = currentValue;
++writePos;
this->Ranges.OutputEnd = writePos;
return;
}
for (;;)
{
while (readPos < readEnd && currentKey == keysIn[readPos])
{
currentValue = functor(currentValue, valuesIn[readPos]);
++readPos;
}
VTKM_ASSERT(writePos <= readPos);
keysOut[writePos] = currentKey;
valuesOut[writePos] = currentValue;
++writePos;
if (readPos < readEnd)
{
currentKey = keysIn[readPos];
currentValue = valuesIn[readPos];
++readPos;
continue;
}
break;
}
this->Ranges.OutputEnd = writePos;
#ifdef _VTKM_DEBUG_TBB_RBK
::tbb::tick_count endTime = ::tbb::tick_count::now();
double time = (endTime - startTime).seconds();
this->ReduceTime += time;
std::ostringstream out;
out << "Reduced " << range.size() << " key/value pairs in " << time << "s. "
<< "InRange: " << this->Ranges.InputBegin << " " << this->Ranges.InputEnd << " "
<< "OutRange: " << this->Ranges.OutputBegin << " " << this->Ranges.OutputEnd << "\n";
std::cerr << out.str();
#endif
}
VTKM_SUPPRESS_EXEC_WARNINGS
VTKM_EXEC_CONT
void join(const ReduceByKeyBody& rhs)
{
using KeysIteratorsType = vtkm::cont::ArrayPortalToIterators<KeysOutPortalType>;
using ValuesIteratorsType = vtkm::cont::ArrayPortalToIterators<ValuesOutPortalType>;
using KeysIteratorType = typename KeysIteratorsType::IteratorType;
using ValuesIteratorType = typename ValuesIteratorsType::IteratorType;
#ifdef _VTKM_DEBUG_TBB_RBK
::tbb::tick_count startTime = ::tbb::tick_count::now();
#endif
this->Ranges.AssertSane();
rhs.Ranges.AssertSane();
// Ensure that we're joining two consecutive subsets of the input:
VTKM_ASSERT(this->Ranges.IsNext(rhs.Ranges));
KeysIteratorsType keysIters(this->KeysOutPortal);
ValuesIteratorsType valuesIters(this->ValuesOutPortal);
KeysIteratorType keys = keysIters.GetBegin();
ValuesIteratorType values = valuesIters.GetBegin();
const vtkm::Id dstBegin = this->Ranges.OutputEnd;
const vtkm::Id lastDstIdx = this->Ranges.OutputEnd - 1;
vtkm::Id srcBegin = rhs.Ranges.OutputBegin;
const vtkm::Id srcEnd = rhs.Ranges.OutputEnd;
// Merge boundaries if needed:
if (keys[srcBegin] == keys[lastDstIdx])
{
values[lastDstIdx] = this->BinaryOperation(values[lastDstIdx], values[srcBegin]);
++srcBegin; // Don't copy the key/value we just reduced
}
// move data:
if (srcBegin != dstBegin && srcBegin != srcEnd)
{
// Sanity check:
VTKM_ASSERT(srcBegin < srcEnd);
// Not necessary for the copy call to be safe, but if the src range
// overlaps with the dst range there's a problem with the algorithm:
VTKM_ASSERT(dstBegin + (srcEnd - srcBegin) <= srcBegin);
std::copy(keys + srcBegin, keys + srcEnd, keys + dstBegin);
std::copy(values + srcBegin, values + srcEnd, values + dstBegin);
}
this->Ranges.InputEnd = rhs.Ranges.InputEnd;
this->Ranges.OutputEnd += srcEnd - srcBegin;
this->Ranges.AssertSane();
#ifdef _VTKM_DEBUG_TBB_RBK
::tbb::tick_count endTime = ::tbb::tick_count::now();
double time = (endTime - startTime).seconds();
this->JoinTime += rhs.JoinTime + time;
std::ostringstream out;
out << "Joined " << (srcEnd - srcBegin) << " rhs values into body in " << time << "s. "
<< "InRange: " << this->Ranges.InputBegin << " " << this->Ranges.InputEnd << " "
<< "OutRange: " << this->Ranges.OutputBegin << " " << this->Ranges.OutputEnd << "\n";
std::cerr << out.str();
#endif
}
};
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename KeysInPortalType,
typename ValuesInPortalType,
typename KeysOutPortalType,
typename ValuesOutPortalType,
typename BinaryOperationType>
VTKM_CONT vtkm::Id ReduceByKeyPortals(KeysInPortalType keysInPortal,
ValuesInPortalType valuesInPortal,
KeysOutPortalType keysOutPortal,
ValuesOutPortalType valuesOutPortal,
BinaryOperationType binaryOperation)
{
const vtkm::Id inputLength = keysInPortal.GetNumberOfValues();
VTKM_ASSERT(inputLength == valuesInPortal.GetNumberOfValues());
if (inputLength == 0)
{
return 0;
}
using ValueType = typename ValuesInPortalType::ValueType;
using WrappedBinaryOp = internal::WrappedBinaryOperator<ValueType, BinaryOperationType>;
WrappedBinaryOp wrappedBinaryOp(binaryOperation);
ReduceByKeyBody<KeysInPortalType,
ValuesInPortalType,
KeysOutPortalType,
ValuesOutPortalType,
WrappedBinaryOp>
body(keysInPortal, valuesInPortal, keysOutPortal, valuesOutPortal, wrappedBinaryOp);
::tbb::blocked_range<vtkm::Id> range(0, inputLength, TBB_GRAIN_SIZE);
#ifdef _VTKM_DEBUG_TBB_RBK
std::cerr << "\n\nTBB ReduceByKey:\n";
#endif
::tbb::parallel_reduce(range, body);
#ifdef _VTKM_DEBUG_TBB_RBK
std::cerr << "Total reduce time: " << body.ReduceTime << "s\n";
std::cerr << "Total join time: " << body.JoinTime << "s\n";
std::cerr << "\nend\n";
#endif
body.Ranges.AssertSane();
VTKM_ASSERT(body.Ranges.InputBegin == 0 && body.Ranges.InputEnd == inputLength &&
body.Ranges.OutputBegin == 0 && body.Ranges.OutputEnd <= inputLength);
return body.Ranges.OutputEnd;
}
#ifdef _VTKM_DEBUG_TBB_RBK
#undef _VTKM_DEBUG_TBB_RBK
#endif
template <class InputPortalType, class OutputPortalType, class BinaryOperationType>
struct ScanInclusiveBody
{