Fix ReduceByKey general algorithm to work with parallel ScanInclusive. Use bit representation for states.

This commit is contained in:
Chun-Ming Chen 2015-06-12 16:38:03 -04:00 committed by Chunming Chen
parent 50ffcec3bb
commit dcc12e733f
2 changed files with 22 additions and 32 deletions

@ -491,15 +491,15 @@ private:
struct ReduceKeySeriesStates
{
//It is needed that END and START_AND_END are both odd numbers
//so that the first bit of both are 1
enum { MIDDLE=0, END=1, START=2, START_AND_END=3};
bool fStart; // START of a segment
bool fEnd; // END of a segment
ReduceKeySeriesStates(bool start=false, bool end=false) : fStart(start), fEnd(end) {}
};
template<typename InputPortalType>
struct ReduceStencilGeneration : vtkm::exec::FunctorBase
{
typedef typename vtkm::cont::ArrayHandle< vtkm::UInt8 >::template ExecutionTypes<DeviceAdapterTag>
typedef typename vtkm::cont::ArrayHandle< ReduceKeySeriesStates >::template ExecutionTypes<DeviceAdapterTag>
::Portal KeyStatePortalType;
InputPortalType Input;
@ -540,8 +540,7 @@ private:
//just need to check if we are END
const ValueType centerValue = this->Input.Get(centerIndex);
const ValueType rightValue = this->Input.Get(rightIndex);
const KeyStateType state = (rightValue == centerValue) ? States::START :
States::START_AND_END;
const KeyStateType state = ReduceKeySeriesStates(true, rightValue != centerValue);
this->KeyState.Set(centerIndex, state);
}
else if(rightIndex == this->Input.GetNumberOfValues())
@ -550,8 +549,7 @@ private:
//just need to check if we are START
const ValueType centerValue = this->Input.Get(centerIndex);
const ValueType leftValue = this->Input.Get(leftIndex);
const KeyStateType state = (leftValue == centerValue) ? States::END :
States::START_AND_END;
const KeyStateType state = ReduceKeySeriesStates(leftValue != centerValue, true);
this->KeyState.Set(centerIndex, state);
}
else
@ -561,19 +559,7 @@ private:
const bool rightMatches(this->Input.Get(rightIndex) == centerValue);
//assume it is the middle, and check for the other use-case
KeyStateType state = States::MIDDLE;
if(!leftMatches && rightMatches)
{
state = States::START;
}
else if(leftMatches && !rightMatches)
{
state = States::END;
}
else if(!leftMatches && !rightMatches)
{
state = States::START_AND_END;
}
KeyStateType state = ReduceKeySeriesStates(!leftMatches, !rightMatches);
this->KeyState.Set(centerIndex, state);
}
}
@ -582,16 +568,21 @@ private:
struct ReduceByKeyAdd
{
template<typename T>
vtkm::Pair<T, vtkm::UInt8> operator()(const vtkm::Pair<T, vtkm::UInt8>& a,
const vtkm::Pair<T, vtkm::UInt8>& b) const
vtkm::Pair<T, ReduceKeySeriesStates> operator()(const vtkm::Pair<T, ReduceKeySeriesStates>& a,
const vtkm::Pair<T, ReduceKeySeriesStates>& b) const
{
typedef vtkm::Pair<T, vtkm::UInt8> ReturnType;
typedef vtkm::Pair<T, ReduceKeySeriesStates> ReturnType;
typedef ReduceKeySeriesStates States;
//need too handle how we are going to add two numbers together
//based on the keyStates that they have
//need to optimize this logic, we can use a bit mask to determine
//the secondary value.
#if 1 // Make it work for parallel inclusive scan. Will end up with all start bits = 1
if (!b.second.fStart) // is b is not START, then it's safe to add. Propagate a's start flag to b
return ReturnType(a.first + b.first, ReduceKeySeriesStates(a.second.fStart, b.second.fEnd));
return b;
#else // Works only for sequencial scan
if(a.second == States::START && b.second == States::END)
{
return ReturnType(a.first + b.first, States::START_AND_END); //with second type as START_AND_END
@ -607,17 +598,16 @@ private:
{
return b;
}
#endif
}
};
struct ReduceByKeyUnaryStencilOp
{
bool operator()(vtkm::UInt8 keySeriesState) const
bool operator()(ReduceKeySeriesStates keySeriesState) const
{
typedef ReduceKeySeriesStates States;
return (keySeriesState == States::END ||
keySeriesState == States::START_AND_END);
return keySeriesState.fEnd;
}
};
@ -647,13 +637,13 @@ public:
//we need to determine based on the keys what is the keystate for
//each key. The states are start, middle, end of a series and the special
//state start and end of a series
vtkm::cont::ArrayHandle< vtkm::UInt8 > keystate;
vtkm::cont::ArrayHandle< ReduceKeySeriesStates > keystate;
{
typedef typename vtkm::cont::ArrayHandle<T,KIn>::template ExecutionTypes<DeviceAdapterTag>
::PortalConst InputPortalType;
typedef typename vtkm::cont::ArrayHandle< vtkm::UInt8 >::template ExecutionTypes<DeviceAdapterTag>
typedef typename vtkm::cont::ArrayHandle< ReduceKeySeriesStates >::template ExecutionTypes<DeviceAdapterTag>
::Portal KeyStatePortalType;
InputPortalType inputPortal = keys.PrepareForInput(DeviceAdapterTag());
@ -672,7 +662,7 @@ public:
{
typedef vtkm::cont::ArrayHandle<U,VIn> ValueInHandleType;
typedef vtkm::cont::ArrayHandle<U,VOut> ValueOutHandleType;
typedef vtkm::cont::ArrayHandle< vtkm::UInt8> StencilHandleType;
typedef vtkm::cont::ArrayHandle< ReduceKeySeriesStates> StencilHandleType;
typedef vtkm::cont::ArrayHandleZip<ValueInHandleType,
StencilHandleType> ZipInHandleType;
typedef vtkm::cont::ArrayHandleZip<ValueOutHandleType,

@ -1073,7 +1073,7 @@ private:
const vtkm::Id k = keysOut.GetPortalConstControl().Get(i);
const vtkm::Id v = valuesOut.GetPortalConstControl().Get(i);
VTKM_TEST_ASSERT( expectedKeys[i] == k, "Incorrect reduced key");
VTKM_TEST_ASSERT( expectedValues[i] == v, "Incorrect reduced vale");
VTKM_TEST_ASSERT( expectedValues[i] == v, "Incorrect reduced value");
}
}