//============================================================================ // Copyright (c) Kitware, Inc. // All rights reserved. // See LICENSE.txt for details. // This software is distributed WITHOUT ANY WARRANTY; without even // the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR // PURPOSE. See the above copyright notice for more information. // // Copyright 2014 Sandia Corporation. // Copyright 2014 UT-Battelle, LLC. // Copyright 2014 Los Alamos National Security. // // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation, // the U.S. Government retains certain rights in this software. // // Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National // Laboratory (LANL), the U.S. Government retains certain rights in // this software. //============================================================================ #ifndef vtk_m_cont_internal_DeviceAdapterAlgorithmGeneral_h #define vtk_m_cont_internal_DeviceAdapterAlgorithmGeneral_h #include #include #include #include #include #include #include #include #include #include namespace vtkm { namespace cont { namespace internal { // Binary function object wrapper which can detect and handle calling the // wrapped operator with complex value types such as // IteratorFromArrayPortalValue which happen when passed an input array that // is implicit. 
template<typename ResultType, typename Function>
struct WrappedBinaryOperator
{
  // The wrapped binary operator; every overload below forwards to it.
  Function m_f;

  VTKM_CONT_EXPORT
  WrappedBinaryOperator(const Function &f)
    : m_f(f)
  { }

  // Neither argument is an iterator proxy: call the wrapped operator as is.
  template<typename Argument1, typename Argument2>
  VTKM_CONT_EXPORT ResultType operator()(const Argument1 &x,
                                         const Argument2 &y) const
  {
    return m_f(x, y);
  }

  // Both arguments are IteratorFromArrayPortalValue proxies: convert each one
  // to its ValueType before calling the wrapped operator.
  template<typename Argument1, typename Argument2>
  VTKM_CONT_EXPORT ResultType operator()(
    const detail::IteratorFromArrayPortalValue<Argument1> &x,
    const detail::IteratorFromArrayPortalValue<Argument2> &y) const
  {
    typedef typename detail::IteratorFromArrayPortalValue<Argument1>::ValueType
        ValueTypeX;
    typedef typename detail::IteratorFromArrayPortalValue<Argument2>::ValueType
        ValueTypeY;
    return m_f(static_cast<ValueTypeX>(x), static_cast<ValueTypeY>(y));
  }

  // Only the second argument is a proxy.
  template<typename Argument1, typename Argument2>
  VTKM_CONT_EXPORT ResultType operator()(
    const Argument1 &x,
    const detail::IteratorFromArrayPortalValue<Argument2> &y) const
  {
    typedef typename detail::IteratorFromArrayPortalValue<Argument2>::ValueType
        ValueTypeY;
    return m_f(x, static_cast<ValueTypeY>(y));
  }

  // Only the first argument is a proxy.
  template<typename Argument1, typename Argument2>
  VTKM_CONT_EXPORT ResultType operator()(
    const detail::IteratorFromArrayPortalValue<Argument1> &x,
    const Argument2 &y) const
  {
    typedef typename detail::IteratorFromArrayPortalValue<Argument1>::ValueType
        ValueTypeX;
    return m_f(static_cast<ValueTypeX>(x), y);
  }
};

/// \brief General implementations of the device adapter algorithms.
///
/// This struct provides algorithms that implement "general" device adapter
/// algorithms. If a device adapter provides implementations for Schedule
/// and Synchronize, the rest of the algorithms can be implemented by calling
/// these functions.
///
/// It should be noted that we recommend that you also implement Sort,
/// ScanInclusive, and ScanExclusive for improved performance.
///
/// An easy way to implement the DeviceAdapterAlgorithm specialization is to
/// subclass this and override the implementation of methods as necessary.
/// As an example, the code would look something like this.
///
/// \code{.cpp}
/// template<>
/// struct DeviceAdapterAlgorithm<DeviceAdapterTagFoo>
///    : DeviceAdapterAlgorithmGeneral<DeviceAdapterAlgorithm<DeviceAdapterTagFoo>,
///                                    DeviceAdapterTagFoo>
/// {
///   template<class Functor>
///   VTKM_CONT_EXPORT static void Schedule(Functor functor,
///                                        vtkm::Id numInstances)
///   {
///     ...
/// } /// /// template /// VTKM_CONT_EXPORT static void Schedule(Functor functor, /// vtkm::Id3 maxRange) /// { /// ... /// } /// /// VTKM_CONT_EXPORT static void Synchronize() /// { /// ... /// } /// }; /// \endcode /// /// You might note that DeviceAdapterAlgorithmGeneral has two template /// parameters that are redundant. Although the first parameter, the class for /// the actual DeviceAdapterAlgorithm class containing Schedule, and /// Synchronize is the same as DeviceAdapterAlgorithm, it is /// made a separate template parameter to avoid a recursive dependence between /// DeviceAdapterAlgorithmGeneral.h and DeviceAdapterAlgorithm.h /// template struct DeviceAdapterAlgorithmGeneral { //-------------------------------------------------------------------------- // Get Execution Value // This method is used internally to get a single element from the execution // array. Might want to expose this and/or allow actual device adapter // implementations to provide one. private: template VTKM_CONT_EXPORT static T GetExecutionValue(const vtkm::cont::ArrayHandle &input, vtkm::Id index) { typedef vtkm::cont::ArrayHandle InputArrayType; typedef vtkm::cont::ArrayHandle OutputArrayType; OutputArrayType output; CopyKernel< typename InputArrayType::template ExecutionTypes::PortalConst, typename OutputArrayType::template ExecutionTypes::Portal> kernel(input.PrepareForInput(DeviceAdapterTag()), output.PrepareForOutput(1, DeviceAdapterTag()), index); DerivedAlgorithm::Schedule(kernel, 1); return output.GetPortalConstControl().Get(0); } //-------------------------------------------------------------------------- // Copy private: template struct CopyKernel { InputPortalType InputPortal; OutputPortalType OutputPortal; vtkm::Id InputOffset; vtkm::Id OutputOffset; VTKM_CONT_EXPORT CopyKernel(InputPortalType inputPortal, OutputPortalType outputPortal, vtkm::Id inputOffset = 0, vtkm::Id outputOffset = 0) : InputPortal(inputPortal), OutputPortal(outputPortal), InputOffset(inputOffset), 
OutputOffset(outputOffset) { } VTKM_EXEC_EXPORT void operator()(vtkm::Id index) const { this->OutputPortal.Set( index + this->OutputOffset, this->InputPortal.Get(index + this->InputOffset)); } VTKM_CONT_EXPORT void SetErrorMessageBuffer(const vtkm::exec::internal::ErrorMessageBuffer &) { } }; public: template VTKM_CONT_EXPORT static void Copy(const vtkm::cont::ArrayHandle &input, vtkm::cont::ArrayHandle &output) { vtkm::Id arraySize = input.GetNumberOfValues(); CopyKernel< typename vtkm::cont::ArrayHandle::template ExecutionTypes::PortalConst, typename vtkm::cont::ArrayHandle::template ExecutionTypes::Portal> kernel(input.PrepareForInput(DeviceAdapterTag()), output.PrepareForOutput(arraySize, DeviceAdapterTag())); DerivedAlgorithm::Schedule(kernel, arraySize); } //-------------------------------------------------------------------------- // Lower Bounds private: template struct LowerBoundsKernel { InputPortalType InputPortal; ValuesPortalType ValuesPortal; OutputPortalType OutputPortal; VTKM_CONT_EXPORT LowerBoundsKernel(InputPortalType inputPortal, ValuesPortalType valuesPortal, OutputPortalType outputPortal) : InputPortal(inputPortal), ValuesPortal(valuesPortal), OutputPortal(outputPortal) { } VTKM_EXEC_EXPORT void operator()(vtkm::Id index) const { // This method assumes that (1) InputPortalType can return working // iterators in the execution environment and that (2) methods not // specified with VTKM_EXEC_EXPORT (such as the STL algorithms) can be // called from the execution environment. Neither one of these is // necessarily true, but it is true for the current uses of this general // function and I don't want to compete with STL if I don't have to. 
typedef vtkm::cont::ArrayPortalToIterators InputIteratorsType; InputIteratorsType inputIterators(this->InputPortal); typename InputIteratorsType::IteratorType resultPos = std::lower_bound(inputIterators.GetBegin(), inputIterators.GetEnd(), this->ValuesPortal.Get(index)); vtkm::Id resultIndex = static_cast( std::distance(inputIterators.GetBegin(), resultPos)); this->OutputPortal.Set(index, resultIndex); } VTKM_CONT_EXPORT void SetErrorMessageBuffer(const vtkm::exec::internal::ErrorMessageBuffer &) { } }; template struct LowerBoundsComparisonKernel { InputPortalType InputPortal; ValuesPortalType ValuesPortal; OutputPortalType OutputPortal; BinaryCompare CompareFunctor; VTKM_CONT_EXPORT LowerBoundsComparisonKernel(InputPortalType inputPortal, ValuesPortalType valuesPortal, OutputPortalType outputPortal, BinaryCompare binary_compare) : InputPortal(inputPortal), ValuesPortal(valuesPortal), OutputPortal(outputPortal), CompareFunctor(binary_compare) { } VTKM_EXEC_EXPORT void operator()(vtkm::Id index) const { // This method assumes that (1) InputPortalType can return working // iterators in the execution environment and that (2) methods not // specified with VTKM_EXEC_EXPORT (such as the STL algorithms) can be // called from the execution environment. Neither one of these is // necessarily true, but it is true for the current uses of this general // function and I don't want to compete with STL if I don't have to. 
typedef vtkm::cont::ArrayPortalToIterators InputIteratorsType; InputIteratorsType inputIterators(this->InputPortal); typename InputIteratorsType::IteratorType resultPos = std::lower_bound(inputIterators.GetBegin(), inputIterators.GetEnd(), this->ValuesPortal.Get(index), this->CompareFunctor); vtkm::Id resultIndex = static_cast( std::distance(inputIterators.GetBegin(), resultPos)); this->OutputPortal.Set(index, resultIndex); } VTKM_CONT_EXPORT void SetErrorMessageBuffer(const vtkm::exec::internal::ErrorMessageBuffer &) { } }; public: template VTKM_CONT_EXPORT static void LowerBounds( const vtkm::cont::ArrayHandle &input, const vtkm::cont::ArrayHandle &values, vtkm::cont::ArrayHandle &output) { vtkm::Id arraySize = values.GetNumberOfValues(); LowerBoundsKernel< typename vtkm::cont::ArrayHandle::template ExecutionTypes::PortalConst, typename vtkm::cont::ArrayHandle::template ExecutionTypes::PortalConst, typename vtkm::cont::ArrayHandle::template ExecutionTypes::Portal> kernel(input.PrepareForInput(DeviceAdapterTag()), values.PrepareForInput(DeviceAdapterTag()), output.PrepareForOutput(arraySize, DeviceAdapterTag())); DerivedAlgorithm::Schedule(kernel, arraySize); } template VTKM_CONT_EXPORT static void LowerBounds( const vtkm::cont::ArrayHandle &input, const vtkm::cont::ArrayHandle &values, vtkm::cont::ArrayHandle &output, BinaryCompare binary_compare) { vtkm::Id arraySize = values.GetNumberOfValues(); LowerBoundsComparisonKernel< typename vtkm::cont::ArrayHandle::template ExecutionTypes::PortalConst, typename vtkm::cont::ArrayHandle::template ExecutionTypes::PortalConst, typename vtkm::cont::ArrayHandle::template ExecutionTypes::Portal, BinaryCompare> kernel(input.PrepareForInput(DeviceAdapterTag()), values.PrepareForInput(DeviceAdapterTag()), output.PrepareForOutput(arraySize, DeviceAdapterTag()), binary_compare); DerivedAlgorithm::Schedule(kernel, arraySize); } template VTKM_CONT_EXPORT static void LowerBounds( const vtkm::cont::ArrayHandle &input, 
vtkm::cont::ArrayHandle &values_output) { DeviceAdapterAlgorithmGeneral< DerivedAlgorithm,DeviceAdapterTag>::LowerBounds(input, values_output, values_output); } //-------------------------------------------------------------------------- // Reduce private: template struct ReduceKernel : vtkm::exec::FunctorBase { typedef typename ArrayType::template ExecutionTypes< DeviceAdapterTag> ExecutionTypes; typedef typename ExecutionTypes::PortalConst PortalConst; PortalConst Portal; BinaryFunctor BinaryOperator; vtkm::Id ArrayLength; VTKM_CONT_EXPORT ReduceKernel() : Portal(), BinaryOperator(), ArrayLength(0) { } VTKM_CONT_EXPORT ReduceKernel(const ArrayType &array, BinaryFunctor binary_functor) : Portal(array.PrepareForInput( DeviceAdapterTag() ) ), BinaryOperator(binary_functor), ArrayLength( array.GetNumberOfValues() ) { } VTKM_EXEC_EXPORT T operator()(vtkm::Id index) const { const vtkm::Id offset = index * ReduceWidth; //at least the first value access to the portal will be valid //only the rest could be invalid T partialSum = this->Portal.Get( offset ); if( offset + ReduceWidth >= this->ArrayLength ) { vtkm::Id currentIndex = offset + 1; while( currentIndex < this->ArrayLength) { partialSum = BinaryOperator(partialSum, this->Portal.Get(currentIndex)); ++currentIndex; } } else { //optimize the usecase where all values are valid and we don't //need to check that we might go out of bounds for(int i=1; i < ReduceWidth; ++i) { partialSum = BinaryOperator(partialSum, this->Portal.Get( offset + i ) ); } } return partialSum; } }; //-------------------------------------------------------------------------- // Reduce public: template VTKM_CONT_EXPORT static T Reduce( const vtkm::cont::ArrayHandle &input, T initialValue) { return DerivedAlgorithm::Reduce(input, initialValue, vtkm::internal::Add()); } template VTKM_CONT_EXPORT static T Reduce( const vtkm::cont::ArrayHandle &input, T initialValue, BinaryFunctor binary_functor) { //Crazy Idea: //We create a implicit array handle 
that wraps the input //array handle. The implicit functor is passed the input array handle, and //the number of elements it needs to sum. This way the implicit handle //acts as the first level reduction. Say for example reducing 16 values //at a time. // //Now that we have an implicit array that is 1/16 the length of full array //we can use scan inclusive to compute the final sum typedef ReduceKernel< 16, T, vtkm::cont::ArrayHandle, BinaryFunctor > ReduceKernelType; typedef vtkm::cont::ArrayHandleImplicit< T, ReduceKernelType > ReduceHandleType; typedef vtkm::cont::ArrayHandle< T, vtkm::cont::StorageTagBasic> TempArrayType; ReduceKernelType kernel(input, binary_functor); vtkm::Id length = (input.GetNumberOfValues() / 16); length += (input.GetNumberOfValues() % 16 == 0) ? 0 : 1; ReduceHandleType reduced = vtkm::cont::make_ArrayHandleImplicit(kernel, length); TempArrayType inclusiveScanStorage; T scanResult = DerivedAlgorithm::ScanInclusive(reduced, inclusiveScanStorage, binary_functor); return binary_functor(initialValue, scanResult); } //-------------------------------------------------------------------------- // Reduce By Key private: struct ReduceKeySeriesStates { bool fStart; // START of a segment bool fEnd; // END of a segment ReduceKeySeriesStates(bool start=false, bool end=false) : fStart(start), fEnd(end) {} }; template struct ReduceStencilGeneration : vtkm::exec::FunctorBase { typedef typename vtkm::cont::ArrayHandle< ReduceKeySeriesStates >::template ExecutionTypes ::Portal KeyStatePortalType; InputPortalType Input; KeyStatePortalType KeyState; VTKM_CONT_EXPORT ReduceStencilGeneration(const InputPortalType &input, const KeyStatePortalType &kstate) : Input(input), KeyState(kstate) { } VTKM_EXEC_EXPORT void operator()(vtkm::Id centerIndex) const { typedef typename InputPortalType::ValueType ValueType; typedef typename KeyStatePortalType::ValueType KeyStateType; const vtkm::Id leftIndex = centerIndex - 1; const vtkm::Id rightIndex = centerIndex + 1; //we 
need to determine which of three states this //index is. It can be: // 1. Middle of a set of equivalent keys. // 2. Start of a set of equivalent keys. // 3. End of a set of equivalent keys. // 4. Both the start and end of a set of keys //we don't have to worry about an array of length 1, as //the calling code handles that use case if(centerIndex == 0) { //this means we are at the start of the array //means we are automatically START //just need to check if we are END const ValueType centerValue = this->Input.Get(centerIndex); const ValueType rightValue = this->Input.Get(rightIndex); const KeyStateType state = ReduceKeySeriesStates(true, rightValue != centerValue); this->KeyState.Set(centerIndex, state); } else if(rightIndex == this->Input.GetNumberOfValues()) { //this means we are at the end, so we are at least END //just need to check if we are START const ValueType centerValue = this->Input.Get(centerIndex); const ValueType leftValue = this->Input.Get(leftIndex); const KeyStateType state = ReduceKeySeriesStates(leftValue != centerValue, true); this->KeyState.Set(centerIndex, state); } else { const ValueType centerValue = this->Input.Get(centerIndex); const bool leftMatches(this->Input.Get(leftIndex) == centerValue); const bool rightMatches(this->Input.Get(rightIndex) == centerValue); //assume it is the middle, and check for the other use-case KeyStateType state = ReduceKeySeriesStates(!leftMatches, !rightMatches); this->KeyState.Set(centerIndex, state); } } }; template struct ReduceByKeyAdd { BinaryFunctor BinaryOperator; ReduceByKeyAdd(BinaryFunctor binary_functor): BinaryOperator( binary_functor ) { } template vtkm::Pair operator()(const vtkm::Pair& a, const vtkm::Pair& b) const { typedef vtkm::Pair ReturnType; //need too handle how we are going to add two numbers together //based on the keyStates that they have // Make it work for parallel inclusive scan. 
Will end up with all start bits = 1 // the following logic should change if you use a different parallel scan algorithm. if (!b.second.fStart) { // if b is not START, then it's safe to sum a & b. // Propagate a's start flag to b // so that later when b's START bit is set, it means there must exists a START between a and b return ReturnType(this->BinaryOperator(a.first , b.first), ReduceKeySeriesStates(a.second.fStart, b.second.fEnd)); } return b; } }; struct ReduceByKeyUnaryStencilOp { bool operator()(ReduceKeySeriesStates keySeriesState) const { return keySeriesState.fEnd; } }; public: template VTKM_CONT_EXPORT static void ReduceByKey( const vtkm::cont::ArrayHandle &keys, const vtkm::cont::ArrayHandle &values, vtkm::cont::ArrayHandle &keys_output, vtkm::cont::ArrayHandle &values_output, BinaryFunctor binary_functor) { VTKM_ASSERT_CONT(keys.GetNumberOfValues() == values.GetNumberOfValues()); const vtkm::Id numberOfKeys = keys.GetNumberOfValues(); if(numberOfKeys <= 1) { //we only have a single key/value so that is our output DerivedAlgorithm::Copy(keys, keys_output); DerivedAlgorithm::Copy(values, values_output); return; } //we need to determine based on the keys what is the keystate for //each key. The states are start, middle, end of a series and the special //state start and end of a series vtkm::cont::ArrayHandle< ReduceKeySeriesStates > keystate; { typedef typename vtkm::cont::ArrayHandle::template ExecutionTypes ::PortalConst InputPortalType; typedef typename vtkm::cont::ArrayHandle< ReduceKeySeriesStates >::template ExecutionTypes ::Portal KeyStatePortalType; InputPortalType inputPortal = keys.PrepareForInput(DeviceAdapterTag()); KeyStatePortalType keyStatePortal = keystate.PrepareForOutput(numberOfKeys, DeviceAdapterTag()); ReduceStencilGeneration kernel(inputPortal, keyStatePortal); DerivedAlgorithm::Schedule(kernel, numberOfKeys); } //next step is we need to reduce the values for each key. 
This is done //by running an inclusive scan over the values array using the stencil. // // this inclusive scan will write out two values, the first being // the value summed currently, the second being 0 or 1, with 1 being used // when this is a value of a key we need to write ( END or START_AND_END) { typedef vtkm::cont::ArrayHandle ValueInHandleType; typedef vtkm::cont::ArrayHandle ValueOutHandleType; typedef vtkm::cont::ArrayHandle< ReduceKeySeriesStates> StencilHandleType; typedef vtkm::cont::ArrayHandleZip ZipInHandleType; typedef vtkm::cont::ArrayHandleZip ZipOutHandleType; StencilHandleType stencil; ValueOutHandleType reducedValues; ZipInHandleType scanInput( values, keystate); ZipOutHandleType scanOutput( reducedValues, stencil); DerivedAlgorithm::ScanInclusive(scanInput, scanOutput, ReduceByKeyAdd(binary_functor) ); //at this point we are done with keystate, so free the memory keystate.ReleaseResources(); // all we need know is an efficient way of doing the write back to the // reduced global memory. 
this is done by using StreamCompact with the // stencil and values we just created with the inclusive scan DerivedAlgorithm::StreamCompact( reducedValues, stencil, values_output, ReduceByKeyUnaryStencilOp()); } //release all temporary memory //find all the unique keys DerivedAlgorithm::Copy(keys,keys_output); DerivedAlgorithm::Unique(keys_output); } //-------------------------------------------------------------------------- // Scan Exclusive private: template struct SetConstantKernel { typedef typename PortalType::ValueType ValueType; PortalType Portal; ValueType Value; VTKM_CONT_EXPORT SetConstantKernel(const PortalType &portal, ValueType value) : Portal(portal), Value(value) { } VTKM_EXEC_EXPORT void operator()(vtkm::Id index) const { this->Portal.Set(index, this->Value); } VTKM_CONT_EXPORT void SetErrorMessageBuffer(const vtkm::exec::internal::ErrorMessageBuffer &) { } }; public: template VTKM_CONT_EXPORT static T ScanExclusive( const vtkm::cont::ArrayHandle &input, vtkm::cont::ArrayHandle& output) { typedef vtkm::cont::ArrayHandle TempArrayType; typedef vtkm::cont::ArrayHandle OutputArrayType; TempArrayType inclusiveScan; T result = DerivedAlgorithm::ScanInclusive(input, inclusiveScan); vtkm::Id numValues = inclusiveScan.GetNumberOfValues(); if (numValues < 1) { return result; } typedef typename TempArrayType::template ExecutionTypes ::PortalConst SrcPortalType; SrcPortalType srcPortal = inclusiveScan.PrepareForInput(DeviceAdapterTag()); typedef typename OutputArrayType::template ExecutionTypes ::Portal DestPortalType; DestPortalType destPortal = output.PrepareForOutput(numValues, DeviceAdapterTag()); // Set first value in output (always 0). DerivedAlgorithm::Schedule( SetConstantKernel( destPortal, vtkm::TypeTraits::ZeroInitialization()), 1); // Shift remaining values over by one. 
DerivedAlgorithm::Schedule( CopyKernel(srcPortal, destPortal, 0, 1), numValues - 1); return result; } //-------------------------------------------------------------------------- // Scan Inclusive private: template struct ScanKernel : vtkm::exec::FunctorBase { PortalType Portal; BinaryFunctor BinaryOperator; vtkm::Id Stride; vtkm::Id Offset; vtkm::Id Distance; VTKM_CONT_EXPORT ScanKernel(const PortalType &portal, BinaryFunctor binary_functor, vtkm::Id stride, vtkm::Id offset) : Portal(portal), BinaryOperator(binary_functor), Stride(stride), Offset(offset), Distance(stride/2) { } VTKM_EXEC_EXPORT void operator()(vtkm::Id index) const { typedef typename PortalType::ValueType ValueType; vtkm::Id leftIndex = this->Offset + index*this->Stride; vtkm::Id rightIndex = leftIndex + this->Distance; if (rightIndex < this->Portal.GetNumberOfValues()) { ValueType leftValue = this->Portal.Get(leftIndex); ValueType rightValue = this->Portal.Get(rightIndex); this->Portal.Set(rightIndex, BinaryOperator(leftValue,rightValue) ); } } }; public: template VTKM_CONT_EXPORT static T ScanInclusive( const vtkm::cont::ArrayHandle &input, vtkm::cont::ArrayHandle& output) { return DerivedAlgorithm::ScanInclusive(input, output, vtkm::internal::Add()); } template VTKM_CONT_EXPORT static T ScanInclusive( const vtkm::cont::ArrayHandle &input, vtkm::cont::ArrayHandle& output, BinaryFunctor binary_functor) { typedef typename vtkm::cont::ArrayHandle ::template ExecutionTypes::Portal PortalType; typedef ScanKernel ScanKernelType; DerivedAlgorithm::Copy(input, output); vtkm::Id numValues = output.GetNumberOfValues(); if (numValues < 1) { return output.GetPortalConstControl().Get(0); } PortalType portal = output.PrepareForInPlace(DeviceAdapterTag()); vtkm::Id stride; for (stride = 2; stride-1 < numValues; stride *= 2) { ScanKernelType kernel(portal, binary_functor, stride, stride/2 - 1); DerivedAlgorithm::Schedule(kernel, numValues/stride); } // Do reverse operation on odd indices. 
Start at stride we were just at. for (stride /= 2; stride > 1; stride /= 2) { ScanKernelType kernel(portal, binary_functor, stride, stride - 1); DerivedAlgorithm::Schedule(kernel, numValues/stride); } return GetExecutionValue(output, numValues-1); } //-------------------------------------------------------------------------- // Sort private: template struct BitonicSortMergeKernel : vtkm::exec::FunctorBase { PortalType Portal; BinaryCompare Compare; vtkm::Id GroupSize; VTKM_CONT_EXPORT BitonicSortMergeKernel(const PortalType &portal, const BinaryCompare &compare, vtkm::Id groupSize) : Portal(portal), Compare(compare), GroupSize(groupSize) { } VTKM_EXEC_EXPORT void operator()(vtkm::Id index) const { typedef typename PortalType::ValueType ValueType; vtkm::Id groupIndex = index%this->GroupSize; vtkm::Id blockSize = 2*this->GroupSize; vtkm::Id blockIndex = index/this->GroupSize; vtkm::Id lowIndex = blockIndex * blockSize + groupIndex; vtkm::Id highIndex = lowIndex + this->GroupSize; if (highIndex < this->Portal.GetNumberOfValues()) { ValueType lowValue = this->Portal.Get(lowIndex); ValueType highValue = this->Portal.Get(highIndex); if (this->Compare(highValue, lowValue)) { this->Portal.Set(highIndex, lowValue); this->Portal.Set(lowIndex, highValue); } } } }; template struct BitonicSortCrossoverKernel : vtkm::exec::FunctorBase { PortalType Portal; BinaryCompare Compare; vtkm::Id GroupSize; VTKM_CONT_EXPORT BitonicSortCrossoverKernel(const PortalType &portal, const BinaryCompare &compare, vtkm::Id groupSize) : Portal(portal), Compare(compare), GroupSize(groupSize) { } VTKM_EXEC_EXPORT void operator()(vtkm::Id index) const { typedef typename PortalType::ValueType ValueType; vtkm::Id groupIndex = index%this->GroupSize; vtkm::Id blockSize = 2*this->GroupSize; vtkm::Id blockIndex = index/this->GroupSize; vtkm::Id lowIndex = blockIndex*blockSize + groupIndex; vtkm::Id highIndex = blockIndex*blockSize + (blockSize - groupIndex - 1); if (highIndex < 
this->Portal.GetNumberOfValues()) { ValueType lowValue = this->Portal.Get(lowIndex); ValueType highValue = this->Portal.Get(highIndex); if (this->Compare(highValue, lowValue)) { this->Portal.Set(highIndex, lowValue); this->Portal.Set(lowIndex, highValue); } } } }; struct DefaultCompareFunctor { template VTKM_EXEC_EXPORT bool operator()(const T& first, const T& second) const { return first < second; } }; public: template VTKM_CONT_EXPORT static void Sort( vtkm::cont::ArrayHandle &values, BinaryCompare binary_compare) { typedef typename vtkm::cont::ArrayHandle ArrayType; typedef typename ArrayType::template ExecutionTypes ::Portal PortalType; vtkm::Id numValues = values.GetNumberOfValues(); if (numValues < 2) { return; } PortalType portal = values.PrepareForInPlace(DeviceAdapterTag()); vtkm::Id numThreads = 1; while (numThreads < numValues) { numThreads *= 2; } numThreads /= 2; typedef BitonicSortMergeKernel MergeKernel; typedef BitonicSortCrossoverKernel CrossoverKernel; for (vtkm::Id crossoverSize = 1; crossoverSize < numValues; crossoverSize *= 2) { DerivedAlgorithm::Schedule(CrossoverKernel(portal,binary_compare,crossoverSize), numThreads); for (vtkm::Id mergeSize = crossoverSize/2; mergeSize > 0; mergeSize /= 2) { DerivedAlgorithm::Schedule(MergeKernel(portal,binary_compare,mergeSize), numThreads); } } } template VTKM_CONT_EXPORT static void Sort( vtkm::cont::ArrayHandle &values) { DerivedAlgorithm::Sort(values, DefaultCompareFunctor()); } //-------------------------------------------------------------------------- // Sort by Key protected: template struct KeyCompare { KeyCompare(): CompareFunctor() {} explicit KeyCompare(BinaryCompare c): CompareFunctor(c) {} VTKM_EXEC_EXPORT bool operator()(const vtkm::Pair& a, const vtkm::Pair& b) const { return CompareFunctor(a.first,b.first); } private: BinaryCompare CompareFunctor; }; public: template VTKM_CONT_EXPORT static void SortByKey( vtkm::cont::ArrayHandle &keys, vtkm::cont::ArrayHandle &values) { //combine the 
keys and values into a ZipArrayHandle //we than need to specify a custom compare function wrapper //that only checks for key side of the pair, using a custom compare functor. typedef vtkm::cont::ArrayHandle KeyType; typedef vtkm::cont::ArrayHandle ValueType; typedef vtkm::cont::ArrayHandleZip ZipHandleType; ZipHandleType zipHandle = vtkm::cont::make_ArrayHandleZip(keys,values); DerivedAlgorithm::Sort(zipHandle,KeyCompare()); } template VTKM_CONT_EXPORT static void SortByKey( vtkm::cont::ArrayHandle &keys, vtkm::cont::ArrayHandle &values, BinaryCompare binary_compare) { //combine the keys and values into a ZipArrayHandle //we than need to specify a custom compare function wrapper //that only checks for key side of the pair, using the custom compare //functor that the user passed in typedef vtkm::cont::ArrayHandle KeyType; typedef vtkm::cont::ArrayHandle ValueType; typedef vtkm::cont::ArrayHandleZip ZipHandleType; ZipHandleType zipHandle = vtkm::cont::make_ArrayHandleZip(keys,values); DerivedAlgorithm::Sort(zipHandle,KeyCompare(binary_compare)); } //-------------------------------------------------------------------------- // Stream Compact private: template struct StencilToIndexFlagKernel { typedef typename StencilPortalType::ValueType StencilValueType; StencilPortalType StencilPortal; OutputPortalType OutputPortal; UnaryPredicate Predicate; VTKM_CONT_EXPORT StencilToIndexFlagKernel(StencilPortalType stencilPortal, OutputPortalType outputPortal, UnaryPredicate unary_predicate) : StencilPortal(stencilPortal), OutputPortal(outputPortal), Predicate(unary_predicate) { } VTKM_EXEC_EXPORT void operator()(vtkm::Id index) const { StencilValueType value = this->StencilPortal.Get(index); this->OutputPortal.Set(index, this->Predicate(value) ? 
1 : 0); } VTKM_CONT_EXPORT void SetErrorMessageBuffer(const vtkm::exec::internal::ErrorMessageBuffer &) { } }; template struct CopyIfKernel { InputPortalType InputPortal; StencilPortalType StencilPortal; IndexPortalType IndexPortal; OutputPortalType OutputPortal; PredicateOperator Predicate; VTKM_CONT_EXPORT CopyIfKernel(InputPortalType inputPortal, StencilPortalType stencilPortal, IndexPortalType indexPortal, OutputPortalType outputPortal, PredicateOperator unary_predicate) : InputPortal(inputPortal), StencilPortal(stencilPortal), IndexPortal(indexPortal), OutputPortal(outputPortal), Predicate(unary_predicate) { } VTKM_EXEC_EXPORT void operator()(vtkm::Id index) const { typedef typename StencilPortalType::ValueType StencilValueType; StencilValueType stencilValue = this->StencilPortal.Get(index); if (Predicate(stencilValue)) { vtkm::Id outputIndex = this->IndexPortal.Get(index); typedef typename OutputPortalType::ValueType OutputValueType; OutputValueType value = this->InputPortal.Get(index); this->OutputPortal.Set(outputIndex, value); } } VTKM_CONT_EXPORT void SetErrorMessageBuffer(const vtkm::exec::internal::ErrorMessageBuffer &) { } }; public: template VTKM_CONT_EXPORT static void StreamCompact( const vtkm::cont::ArrayHandle& input, const vtkm::cont::ArrayHandle& stencil, vtkm::cont::ArrayHandle& output, UnaryPredicate unary_predicate) { VTKM_ASSERT_CONT(input.GetNumberOfValues() == stencil.GetNumberOfValues()); vtkm::Id arrayLength = stencil.GetNumberOfValues(); typedef vtkm::cont::ArrayHandle< vtkm::Id, vtkm::cont::StorageTagBasic> IndexArrayType; IndexArrayType indices; typedef typename vtkm::cont::ArrayHandle ::template ExecutionTypes::PortalConst StencilPortalType; StencilPortalType stencilPortal = stencil.PrepareForInput(DeviceAdapterTag()); typedef typename IndexArrayType ::template ExecutionTypes::Portal IndexPortalType; IndexPortalType indexPortal = indices.PrepareForOutput(arrayLength, DeviceAdapterTag()); StencilToIndexFlagKernel< StencilPortalType, 
IndexPortalType, UnaryPredicate> indexKernel(stencilPortal, indexPortal, unary_predicate); DerivedAlgorithm::Schedule(indexKernel, arrayLength); vtkm::Id outArrayLength = DerivedAlgorithm::ScanExclusive(indices, indices); typedef typename vtkm::cont::ArrayHandle ::template ExecutionTypes::PortalConst InputPortalType; InputPortalType inputPortal = input.PrepareForInput(DeviceAdapterTag()); typedef typename vtkm::cont::ArrayHandle ::template ExecutionTypes::Portal OutputPortalType; OutputPortalType outputPortal = output.PrepareForOutput(outArrayLength, DeviceAdapterTag()); CopyIfKernel< InputPortalType, StencilPortalType, IndexPortalType, OutputPortalType, UnaryPredicate> copyKernel(inputPortal, stencilPortal, indexPortal, outputPortal, unary_predicate); DerivedAlgorithm::Schedule(copyKernel, arrayLength); } template VTKM_CONT_EXPORT static void StreamCompact( const vtkm::cont::ArrayHandle& input, const vtkm::cont::ArrayHandle& stencil, vtkm::cont::ArrayHandle& output) { ::vtkm::not_default_constructor unary_predicate; DerivedAlgorithm::StreamCompact(input, stencil, output, unary_predicate); } template VTKM_CONT_EXPORT static void StreamCompact( const vtkm::cont::ArrayHandle &stencil, vtkm::cont::ArrayHandle &output) { typedef vtkm::cont::ArrayHandleCounting CountingHandleType; CountingHandleType input = vtkm::cont::make_ArrayHandleCounting(vtkm::Id(0), stencil.GetNumberOfValues()); DerivedAlgorithm::StreamCompact(input, stencil, output); } //-------------------------------------------------------------------------- // Unique private: template struct ClassifyUniqueKernel { InputPortalType InputPortal; StencilPortalType StencilPortal; VTKM_CONT_EXPORT ClassifyUniqueKernel(InputPortalType inputPortal, StencilPortalType stencilPortal) : InputPortal(inputPortal), StencilPortal(stencilPortal) { } VTKM_EXEC_EXPORT void operator()(vtkm::Id index) const { typedef typename StencilPortalType::ValueType ValueType; if (index == 0) { // Always copy first value. 
// Continuation of ClassifyUniqueKernel::operator(): the statement below is the
// body of its `index == 0` branch, which always flags the first entry.
this->StencilPortal.Set(index, ValueType(1));
      }
      else
      {
        // The flag is the result of comparing this value with its predecessor
        // using !=, converted to the stencil's value type.
        ValueType flag = ValueType(this->InputPortal.Get(index-1) !=
                                   this->InputPortal.Get(index));
        this->StencilPortal.Set(index, flag);
      }
    }

    // Accepts the error message buffer and discards it (empty body).
    VTKM_CONT_EXPORT
    void SetErrorMessageBuffer(const vtkm::exec::internal::ErrorMessageBuffer &)
    { }
  };

  /// \brief Execution functor that builds the stencil for Unique when a
  /// comparison functor is supplied.
  ///
  /// Index 0 is always flagged with ValueType(1). For any other index the
  /// flag is the negation of CompareFunctor(previous value, current value),
  /// converted to the stencil's value type.
  ///
  /// NOTE(review): the template parameter list after `template` is missing in
  /// this copy of the file.
  ///
  /// NOTE(review): in operator() the local that holds the negated result is
  /// named `same`, while the comment next to it says the predicate returns
  /// true on a match. The value stored is therefore true when the predicate
  /// reports a mismatch; consider renaming the local.
  template struct ClassifyUniqueComparisonKernel
  {
    InputPortalType InputPortal;
    StencilPortalType StencilPortal;
    BinaryCompare CompareFunctor;

    /// Stores copies of the two portals and of the comparison functor.
    VTKM_CONT_EXPORT
    ClassifyUniqueComparisonKernel(InputPortalType inputPortal,
                                   StencilPortalType stencilPortal,
                                   BinaryCompare binary_compare):
      InputPortal(inputPortal),
      StencilPortal(stencilPortal),
      CompareFunctor(binary_compare) { }

    VTKM_EXEC_EXPORT
    void operator()(vtkm::Id index) const
    {
      typedef typename StencilPortalType::ValueType ValueType;
      if (index == 0)
      {
        // NOTE(review): the original line breaks are missing in this copy of
        // the file, so the remainder of the next line is lexically one
        // comment. It holds the rest of this operator(), this struct's
        // SetErrorMessageBuffer, the `public:` label, the whole of
        // Unique(values) and the header of Unique(values, binary_compare).
        // Unique(values) classifies with ClassifyUniqueKernel, compacts with
        // DerivedAlgorithm::StreamCompact and copies the result back into
        // `values` with DerivedAlgorithm::Copy. Kept verbatim; restore the
        // line breaks from upstream so that this code is compiled again.
        // Always copy first value. this->StencilPortal.Set(index, ValueType(1)); } else { //comparison predicate returns true when they match const bool same = !(this->CompareFunctor(this->InputPortal.Get(index-1), this->InputPortal.Get(index))); ValueType flag = ValueType(same); this->StencilPortal.Set(index, flag); } } VTKM_CONT_EXPORT void SetErrorMessageBuffer(const vtkm::exec::internal::ErrorMessageBuffer &) { } }; public: template VTKM_CONT_EXPORT static void Unique( vtkm::cont::ArrayHandle &values) { vtkm::cont::ArrayHandle stencilArray; vtkm::Id inputSize = values.GetNumberOfValues(); ClassifyUniqueKernel< typename vtkm::cont::ArrayHandle::template ExecutionTypes::PortalConst, typename vtkm::cont::ArrayHandle::template ExecutionTypes::Portal> classifyKernel(values.PrepareForInput(DeviceAdapterTag()), stencilArray.PrepareForOutput(inputSize, DeviceAdapterTag())); DerivedAlgorithm::Schedule(classifyKernel, inputSize); vtkm::cont::ArrayHandle outputArray; DerivedAlgorithm::StreamCompact(values, stencilArray, outputArray); DerivedAlgorithm::Copy(outputArray, values); } template VTKM_CONT_EXPORT static void Unique( vtkm::cont::ArrayHandle &values, BinaryCompare binary_compare) { 
    // Body of Unique(values, binary_compare); its header is at the end of the
    // preceding line. The function rewrites `values` in three steps:
    // classify, compact, copy back.

    // Step 1: build a stencil with one entry per value.
    vtkm::cont::ArrayHandle stencilArray;
    vtkm::Id inputSize = values.GetNumberOfValues();

    ClassifyUniqueComparisonKernel<
        typename vtkm::cont::ArrayHandle::template ExecutionTypes::PortalConst,
        typename vtkm::cont::ArrayHandle::template ExecutionTypes::Portal,
        BinaryCompare>
        classifyKernel(values.PrepareForInput(DeviceAdapterTag()),
                       stencilArray.PrepareForOutput(inputSize, DeviceAdapterTag()),
                       binary_compare);
    DerivedAlgorithm::Schedule(classifyKernel, inputSize);

    // Step 2: compact `values` through the stencil into a temporary array.
    vtkm::cont::ArrayHandle outputArray;

    DerivedAlgorithm::StreamCompact(values, stencilArray, outputArray);

    // Step 3: copy the temporary back, so `values` now holds the compacted
    // result.
    DerivedAlgorithm::Copy(outputArray, values);
  }

  // NOTE(review): the original line breaks are missing in this copy of the
  // file, so everything on the next line, from the dashed rule onward, is
  // lexically one comment. It holds the "Upper bounds" section banner, the
  // `private:` label and the first part of UpperBoundsKernel: its three
  // portal members, its constructor, and the opening of operator() with the
  // author's remark about using STL algorithms in the execution environment.
  // Kept verbatim; restore the line breaks from upstream so that this code is
  // compiled again.
  //-------------------------------------------------------------------------- // Upper bounds private: template struct UpperBoundsKernel { InputPortalType InputPortal; ValuesPortalType ValuesPortal; OutputPortalType OutputPortal; VTKM_CONT_EXPORT UpperBoundsKernel(InputPortalType inputPortal, ValuesPortalType valuesPortal, OutputPortalType outputPortal) : InputPortal(inputPortal), ValuesPortal(valuesPortal), OutputPortal(outputPortal) { } VTKM_EXEC_EXPORT void operator()(vtkm::Id index) const { // This method assumes that (1) InputPortalType can return working // iterators in the execution environment and that (2) methods not // specified with VTKM_EXEC_EXPORT (such as the STL algorithms) can be // called from the execution environment. Neither one of these is // necessarily true, but it is true for the current uses of this general // function and I don't want to compete with STL if I don't have to. 
      // Continuation of UpperBoundsKernel::operator(), whose header is at the
      // end of the preceding line.

      // Wrap the input portal so that it can be handed to an STL algorithm.
      typedef vtkm::cont::ArrayPortalToIterators InputIteratorsType;
      InputIteratorsType inputIterators(this->InputPortal);

      // Search the whole input range for the value stored at `index` of the
      // values portal, using the default ordering of std::upper_bound.
      typename InputIteratorsType::IteratorType resultPos =
          std::upper_bound(inputIterators.GetBegin(),
                           inputIterators.GetEnd(),
                           this->ValuesPortal.Get(index));

      // Store the found position as an offset from the start of the input.
      vtkm::Id resultIndex =
          static_cast(
            std::distance(inputIterators.GetBegin(), resultPos));
      this->OutputPortal.Set(index, resultIndex);
    }

    // Accepts the error message buffer and discards it (empty body).
    VTKM_CONT_EXPORT
    void SetErrorMessageBuffer(const vtkm::exec::internal::ErrorMessageBuffer &)
    { }
  };

  /// \brief Execution functor for UpperBounds with a caller-supplied ordering.
  ///
  /// Holds the input, values and output portals plus the comparison functor.
  /// Its operator() performs the same search as UpperBoundsKernel but passes
  /// \c CompareFunctor to std::upper_bound as the ordering.
  ///
  /// NOTE(review): the template parameter list after `template` is missing in
  /// this copy of the file.
  template struct UpperBoundsKernelComparisonKernel
  {
    InputPortalType InputPortal;
    ValuesPortalType ValuesPortal;
    OutputPortalType OutputPortal;
    BinaryCompare CompareFunctor;

    /// Stores copies of the three portals and of the comparison functor.
    VTKM_CONT_EXPORT
    UpperBoundsKernelComparisonKernel(InputPortalType inputPortal,
                                      ValuesPortalType valuesPortal,
                                      OutputPortalType outputPortal,
                                      BinaryCompare binary_compare)
      : InputPortal(inputPortal),
        ValuesPortal(valuesPortal),
        OutputPortal(outputPortal),
        CompareFunctor(binary_compare) { }

    VTKM_EXEC_EXPORT
    void operator()(vtkm::Id index) const
    {
      // This method assumes that (1) InputPortalType can return working // iterators in the execution environment and that (2) methods not // specified with VTKM_EXEC_EXPORT (such as the STL algorithms) can be // called from the execution environment. Neither one of these is // necessarily true, but it is true for the current uses of this general // function and I don't want to compete with STL if I don't have to. 
      // Continuation of UpperBoundsKernelComparisonKernel::operator(), whose
      // header is at the end of the preceding line.

      // Wrap the input portal so that it can be handed to an STL algorithm.
      typedef vtkm::cont::ArrayPortalToIterators InputIteratorsType;
      InputIteratorsType inputIterators(this->InputPortal);

      // Search the whole input range for the value stored at `index` of the
      // values portal, ordering elements with the stored comparison functor.
      typename InputIteratorsType::IteratorType resultPos =
          std::upper_bound(inputIterators.GetBegin(),
                           inputIterators.GetEnd(),
                           this->ValuesPortal.Get(index),
                           this->CompareFunctor);

      // Store the found position as an offset from the start of the input.
      vtkm::Id resultIndex =
          static_cast(
            std::distance(inputIterators.GetBegin(), resultPos));
      this->OutputPortal.Set(index, resultIndex);
    }

    // Accepts the error message buffer and discards it (empty body).
    VTKM_CONT_EXPORT
    void SetErrorMessageBuffer(const vtkm::exec::internal::ErrorMessageBuffer &)
    { }
  };

public:

  /// \brief For each entry of \c values, finds its upper-bound position in
  /// \c input.
  ///
  /// \c output is prepared with one entry per entry of \c values. Entry i
  /// receives the offset returned by std::upper_bound when the whole of
  /// \c input is searched for values[i] (see UpperBoundsKernel). The search
  /// is only meaningful if \c input satisfies the ordering precondition of
  /// std::upper_bound; nothing here checks it.
  ///
  /// NOTE(review): the template parameter list after `template` and the
  /// ArrayHandle / ExecutionTypes argument lists are missing in this copy of
  /// the file.
  template
  VTKM_CONT_EXPORT static void UpperBounds(
      const vtkm::cont::ArrayHandle &input,
      const vtkm::cont::ArrayHandle &values,
      vtkm::cont::ArrayHandle &output)
  {
    // One kernel instance, and one output entry, per searched value.
    vtkm::Id arraySize = values.GetNumberOfValues();

    UpperBoundsKernel<
        typename vtkm::cont::ArrayHandle::template ExecutionTypes::PortalConst,
        typename vtkm::cont::ArrayHandle::template ExecutionTypes::PortalConst,
        typename vtkm::cont::ArrayHandle::template ExecutionTypes::Portal>
        kernel(input.PrepareForInput(DeviceAdapterTag()),
               values.PrepareForInput(DeviceAdapterTag()),
               output.PrepareForOutput(arraySize, DeviceAdapterTag()));

    DerivedAlgorithm::Schedule(kernel, arraySize);
  }

  /// \brief UpperBounds with a caller-supplied ordering.
  ///
  /// Same as the three-argument overload, except that \c binary_compare is
  /// handed to the kernel and used as the ordering for std::upper_bound
  /// (see UpperBoundsKernelComparisonKernel).
  template
  VTKM_CONT_EXPORT static void UpperBounds(
      const vtkm::cont::ArrayHandle &input,
      const vtkm::cont::ArrayHandle &values,
      vtkm::cont::ArrayHandle &output,
      BinaryCompare binary_compare)
  {
    // One kernel instance, and one output entry, per searched value.
    vtkm::Id arraySize = values.GetNumberOfValues();

    UpperBoundsKernelComparisonKernel<
        typename vtkm::cont::ArrayHandle::template ExecutionTypes::PortalConst,
        typename vtkm::cont::ArrayHandle::template ExecutionTypes::PortalConst,
        typename vtkm::cont::ArrayHandle::template ExecutionTypes::Portal,
        BinaryCompare>
        kernel(input.PrepareForInput(DeviceAdapterTag()),
               values.PrepareForInput(DeviceAdapterTag()),
               output.PrepareForOutput(arraySize, DeviceAdapterTag()),
               binary_compare);

    DerivedAlgorithm::Schedule(kernel, arraySize);
  }

  /// \brief UpperBounds that reads the searched values from, and writes the
  /// results to, the same array.
  ///
  /// The parameter list and body continue on the following line.
  template
  VTKM_CONT_EXPORT static void UpperBounds(
      const vtkm::cont::ArrayHandle &input,
      vtkm::cont::ArrayHandle &values_output)
  {
    // The same array is passed as both the values to search for and the
    // destination of the results, so its contents are replaced by the
    // positions found.
    //
    // NOTE(review): unlike the other entry points in this part of the file,
    // this call names DeviceAdapterAlgorithmGeneral directly instead of
    // going through DerivedAlgorithm, so an UpperBounds provided by the
    // derived class would not be picked up here. Confirm this is intended.
    DeviceAdapterAlgorithmGeneral::UpperBounds(input, values_output, values_output);
  }

};

}
}
}
// NOTE(review): the original line break before `#endif` is missing in this
// copy of the file, so the directive on the next line sits inside the comment
// and is not seen by the preprocessor. Kept verbatim; restore the line break
// from upstream.
// namespace vtkm::cont::internal #endif //vtk_m_cont_internal_DeviceAdapterAlgorithmGeneral_h