Fix ReduceByKey general algorithm to work with parallel ScanInclusive. Use bit representation for states.
This commit is contained in:
parent
50ffcec3bb
commit
dcc12e733f
@ -491,15 +491,15 @@ private:
|
||||
|
||||
struct ReduceKeySeriesStates
|
||||
{
|
||||
//It is needed that END and START_AND_END are both odd numbers
|
||||
//so that the first bit of both are 1
|
||||
enum { MIDDLE=0, END=1, START=2, START_AND_END=3};
|
||||
bool fStart; // START of a segment
|
||||
bool fEnd; // END of a segment
|
||||
ReduceKeySeriesStates(bool start=false, bool end=false) : fStart(start), fEnd(end) {}
|
||||
};
|
||||
|
||||
template<typename InputPortalType>
|
||||
struct ReduceStencilGeneration : vtkm::exec::FunctorBase
|
||||
{
|
||||
typedef typename vtkm::cont::ArrayHandle< vtkm::UInt8 >::template ExecutionTypes<DeviceAdapterTag>
|
||||
typedef typename vtkm::cont::ArrayHandle< ReduceKeySeriesStates >::template ExecutionTypes<DeviceAdapterTag>
|
||||
::Portal KeyStatePortalType;
|
||||
|
||||
InputPortalType Input;
|
||||
@ -540,8 +540,7 @@ private:
|
||||
//just need to check if we are END
|
||||
const ValueType centerValue = this->Input.Get(centerIndex);
|
||||
const ValueType rightValue = this->Input.Get(rightIndex);
|
||||
const KeyStateType state = (rightValue == centerValue) ? States::START :
|
||||
States::START_AND_END;
|
||||
const KeyStateType state = ReduceKeySeriesStates(true, rightValue != centerValue);
|
||||
this->KeyState.Set(centerIndex, state);
|
||||
}
|
||||
else if(rightIndex == this->Input.GetNumberOfValues())
|
||||
@ -550,8 +549,7 @@ private:
|
||||
//just need to check if we are START
|
||||
const ValueType centerValue = this->Input.Get(centerIndex);
|
||||
const ValueType leftValue = this->Input.Get(leftIndex);
|
||||
const KeyStateType state = (leftValue == centerValue) ? States::END :
|
||||
States::START_AND_END;
|
||||
const KeyStateType state = ReduceKeySeriesStates(leftValue != centerValue, true);
|
||||
this->KeyState.Set(centerIndex, state);
|
||||
}
|
||||
else
|
||||
@ -561,19 +559,7 @@ private:
|
||||
const bool rightMatches(this->Input.Get(rightIndex) == centerValue);
|
||||
|
||||
//assume it is the middle, and check for the other use-case
|
||||
KeyStateType state = States::MIDDLE;
|
||||
if(!leftMatches && rightMatches)
|
||||
{
|
||||
state = States::START;
|
||||
}
|
||||
else if(leftMatches && !rightMatches)
|
||||
{
|
||||
state = States::END;
|
||||
}
|
||||
else if(!leftMatches && !rightMatches)
|
||||
{
|
||||
state = States::START_AND_END;
|
||||
}
|
||||
KeyStateType state = ReduceKeySeriesStates(!leftMatches, !rightMatches);
|
||||
this->KeyState.Set(centerIndex, state);
|
||||
}
|
||||
}
|
||||
@ -582,16 +568,21 @@ private:
|
||||
struct ReduceByKeyAdd
|
||||
{
|
||||
template<typename T>
|
||||
vtkm::Pair<T, vtkm::UInt8> operator()(const vtkm::Pair<T, vtkm::UInt8>& a,
|
||||
const vtkm::Pair<T, vtkm::UInt8>& b) const
|
||||
vtkm::Pair<T, ReduceKeySeriesStates> operator()(const vtkm::Pair<T, ReduceKeySeriesStates>& a,
|
||||
const vtkm::Pair<T, ReduceKeySeriesStates>& b) const
|
||||
{
|
||||
typedef vtkm::Pair<T, vtkm::UInt8> ReturnType;
|
||||
typedef vtkm::Pair<T, ReduceKeySeriesStates> ReturnType;
|
||||
typedef ReduceKeySeriesStates States;
|
||||
//need too handle how we are going to add two numbers together
|
||||
//based on the keyStates that they have
|
||||
|
||||
//need to optimize this logic, we can use a bit mask to determine
|
||||
//the secondary value.
|
||||
#if 1 // Make it work for parallel inclusive scan. Will end up with all start bits = 1
|
||||
if (!b.second.fStart) // is b is not START, then it's safe to add. Propagate a's start flag to b
|
||||
return ReturnType(a.first + b.first, ReduceKeySeriesStates(a.second.fStart, b.second.fEnd));
|
||||
return b;
|
||||
#else // Works only for sequencial scan
|
||||
if(a.second == States::START && b.second == States::END)
|
||||
{
|
||||
return ReturnType(a.first + b.first, States::START_AND_END); //with second type as START_AND_END
|
||||
@ -607,17 +598,16 @@ private:
|
||||
{
|
||||
return b;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct ReduceByKeyUnaryStencilOp
|
||||
{
|
||||
bool operator()(vtkm::UInt8 keySeriesState) const
|
||||
bool operator()(ReduceKeySeriesStates keySeriesState) const
|
||||
{
|
||||
typedef ReduceKeySeriesStates States;
|
||||
return (keySeriesState == States::END ||
|
||||
keySeriesState == States::START_AND_END);
|
||||
return keySeriesState.fEnd;
|
||||
}
|
||||
|
||||
};
|
||||
@ -647,13 +637,13 @@ public:
|
||||
//we need to determine based on the keys what is the keystate for
|
||||
//each key. The states are start, middle, end of a series and the special
|
||||
//state start and end of a series
|
||||
vtkm::cont::ArrayHandle< vtkm::UInt8 > keystate;
|
||||
vtkm::cont::ArrayHandle< ReduceKeySeriesStates > keystate;
|
||||
|
||||
{
|
||||
typedef typename vtkm::cont::ArrayHandle<T,KIn>::template ExecutionTypes<DeviceAdapterTag>
|
||||
::PortalConst InputPortalType;
|
||||
|
||||
typedef typename vtkm::cont::ArrayHandle< vtkm::UInt8 >::template ExecutionTypes<DeviceAdapterTag>
|
||||
typedef typename vtkm::cont::ArrayHandle< ReduceKeySeriesStates >::template ExecutionTypes<DeviceAdapterTag>
|
||||
::Portal KeyStatePortalType;
|
||||
|
||||
InputPortalType inputPortal = keys.PrepareForInput(DeviceAdapterTag());
|
||||
@ -672,7 +662,7 @@ public:
|
||||
{
|
||||
typedef vtkm::cont::ArrayHandle<U,VIn> ValueInHandleType;
|
||||
typedef vtkm::cont::ArrayHandle<U,VOut> ValueOutHandleType;
|
||||
typedef vtkm::cont::ArrayHandle< vtkm::UInt8> StencilHandleType;
|
||||
typedef vtkm::cont::ArrayHandle< ReduceKeySeriesStates> StencilHandleType;
|
||||
typedef vtkm::cont::ArrayHandleZip<ValueInHandleType,
|
||||
StencilHandleType> ZipInHandleType;
|
||||
typedef vtkm::cont::ArrayHandleZip<ValueOutHandleType,
|
||||
|
@ -1073,7 +1073,7 @@ private:
|
||||
const vtkm::Id k = keysOut.GetPortalConstControl().Get(i);
|
||||
const vtkm::Id v = valuesOut.GetPortalConstControl().Get(i);
|
||||
VTKM_TEST_ASSERT( expectedKeys[i] == k, "Incorrect reduced key");
|
||||
VTKM_TEST_ASSERT( expectedValues[i] == v, "Incorrect reduced vale");
|
||||
VTKM_TEST_ASSERT( expectedValues[i] == v, "Incorrect reduced value");
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user