From dcc12e733fcc8de77a3f35279b7a74e5201152d5 Mon Sep 17 00:00:00 2001 From: Chun-Ming Chen Date: Fri, 12 Jun 2015 16:38:03 -0400 Subject: [PATCH] Fix ReduceByKey general algorithm to work with parallel ScanInclusive. Use bit representation for states. --- .../internal/DeviceAdapterAlgorithmGeneral.h | 52 ++++++++----------- vtkm/cont/testing/TestingDeviceAdapter.h | 2 +- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/vtkm/cont/internal/DeviceAdapterAlgorithmGeneral.h b/vtkm/cont/internal/DeviceAdapterAlgorithmGeneral.h index 8393bd19f..60cd8d3ac 100644 --- a/vtkm/cont/internal/DeviceAdapterAlgorithmGeneral.h +++ b/vtkm/cont/internal/DeviceAdapterAlgorithmGeneral.h @@ -491,15 +491,15 @@ private: struct ReduceKeySeriesStates { - //It is needed that END and START_AND_END are both odd numbers - //so that the first bit of both are 1 - enum { MIDDLE=0, END=1, START=2, START_AND_END=3}; + bool fStart; // START of a segment + bool fEnd; // END of a segment + ReduceKeySeriesStates(bool start=false, bool end=false) : fStart(start), fEnd(end) {} }; template struct ReduceStencilGeneration : vtkm::exec::FunctorBase { - typedef typename vtkm::cont::ArrayHandle< vtkm::UInt8 >::template ExecutionTypes + typedef typename vtkm::cont::ArrayHandle< ReduceKeySeriesStates >::template ExecutionTypes ::Portal KeyStatePortalType; InputPortalType Input; @@ -540,8 +540,7 @@ private: //just need to check if we are END const ValueType centerValue = this->Input.Get(centerIndex); const ValueType rightValue = this->Input.Get(rightIndex); - const KeyStateType state = (rightValue == centerValue) ? States::START : - States::START_AND_END; + const KeyStateType state = ReduceKeySeriesStates(true, rightValue != centerValue); this->KeyState.Set(centerIndex, state); } else if(rightIndex == this->Input.GetNumberOfValues()) @@ -550,8 +549,7 @@ private: //just need to check if we are START const ValueType centerValue = this->Input.Get(centerIndex); const ValueType leftValue = this->Input.Get(leftIndex); - const KeyStateType state = (leftValue == centerValue) ? States::END : - States::START_AND_END; + const KeyStateType state = ReduceKeySeriesStates(leftValue != centerValue, true); this->KeyState.Set(centerIndex, state); } else @@ -561,19 +559,7 @@ private: const bool rightMatches(this->Input.Get(rightIndex) == centerValue); //assume it is the middle, and check for the other use-case - KeyStateType state = States::MIDDLE; - if(!leftMatches && rightMatches) - { - state = States::START; - } - else if(leftMatches && !rightMatches) - { - state = States::END; - } - else if(!leftMatches && !rightMatches) - { - state = States::START_AND_END; - } + KeyStateType state = ReduceKeySeriesStates(!leftMatches, !rightMatches); this->KeyState.Set(centerIndex, state); } } @@ -582,16 +568,21 @@ private: struct ReduceByKeyAdd { template - vtkm::Pair operator()(const vtkm::Pair& a, - const vtkm::Pair& b) const + vtkm::Pair operator()(const vtkm::Pair& a, + const vtkm::Pair& b) const { - typedef vtkm::Pair ReturnType; + typedef vtkm::Pair ReturnType; typedef ReduceKeySeriesStates States; //need too handle how we are going to add two numbers together //based on the keyStates that they have //need to optimize this logic, we can use a bit mask to determine //the secondary value. +#if 1 // Make it work for parallel inclusive scan. Will end up with all start bits = 1 + if (!b.second.fStart) // is b is not START, then it's safe to add. Propagate a's start flag to b + return ReturnType(a.first + b.first, ReduceKeySeriesStates(a.second.fStart, b.second.fEnd)); + return b; +#else // Works only for sequencial scan if(a.second == States::START && b.second == States::END) { return ReturnType(a.first + b.first, States::START_AND_END); //with second type as START_AND_END @@ -607,17 +598,16 @@ private: { return b; } +#endif } }; struct ReduceByKeyUnaryStencilOp { - bool operator()(vtkm::UInt8 keySeriesState) const + bool operator()(ReduceKeySeriesStates keySeriesState) const { - typedef ReduceKeySeriesStates States; - return (keySeriesState == States::END || - keySeriesState == States::START_AND_END); + return keySeriesState.fEnd; } }; @@ -647,13 +637,13 @@ public: //we need to determine based on the keys what is the keystate for //each key. The states are start, middle, end of a series and the special //state start and end of a series - vtkm::cont::ArrayHandle< vtkm::UInt8 > keystate; + vtkm::cont::ArrayHandle< ReduceKeySeriesStates > keystate; { typedef typename vtkm::cont::ArrayHandle::template ExecutionTypes ::PortalConst InputPortalType; - typedef typename vtkm::cont::ArrayHandle< vtkm::UInt8 >::template ExecutionTypes + typedef typename vtkm::cont::ArrayHandle< ReduceKeySeriesStates >::template ExecutionTypes ::Portal KeyStatePortalType; InputPortalType inputPortal = keys.PrepareForInput(DeviceAdapterTag()); @@ -672,7 +662,7 @@ public: { typedef vtkm::cont::ArrayHandle ValueInHandleType; typedef vtkm::cont::ArrayHandle ValueOutHandleType; - typedef vtkm::cont::ArrayHandle< vtkm::UInt8> StencilHandleType; + typedef vtkm::cont::ArrayHandle< ReduceKeySeriesStates> StencilHandleType; typedef vtkm::cont::ArrayHandleZip ZipInHandleType; typedef vtkm::cont::ArrayHandleZip