mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-09-16 17:22:55 +00:00
Generalize the TBB radix sort implementation.
The core algorithm will be shared by OpenMP.
This commit is contained in:
parent
d602784348
commit
e621b6ba3c
@ -37,6 +37,9 @@ set(headers
|
||||
DynamicTransform.h
|
||||
FunctorsGeneral.h
|
||||
IteratorFromArrayPortal.h
|
||||
KXSort.h
|
||||
ParallelRadixSort.h
|
||||
ParallelRadixSortInterface.h
|
||||
SimplePolymorphicContainer.h
|
||||
StorageError.h
|
||||
VirtualObjectTransfer.h
|
||||
|
@ -1,3 +1,25 @@
|
||||
//=============================================================================
|
||||
//
|
||||
// Copyright (c) Kitware, Inc.
|
||||
// All rights reserved.
|
||||
// See LICENSE.txt for details.
|
||||
//
|
||||
// This software is distributed WITHOUT ANY WARRANTY; without even
|
||||
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
|
||||
// PURPOSE. See the above copyright notice for more information.
|
||||
//
|
||||
// Copyright 2018 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
|
||||
// Copyright 2018 UT-Battelle, LLC.
|
||||
// Copyright 2018 Los Alamos National Security.
|
||||
//
|
||||
// Under the terms of Contract DE-NA0003525 with NTESS,
|
||||
// the U.S. Government retains certain rights in this software.
|
||||
// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
|
||||
// Laboratory (LANL), the U.S. Government retains certain rights in
|
||||
// this software.
|
||||
//
|
||||
//=============================================================================
|
||||
|
||||
/* The MIT License
|
||||
Copyright (c) 2016 Dinghua Li <voutcn@gmail.com>
|
||||
|
1069
vtkm/cont/internal/ParallelRadixSort.h
Normal file
1069
vtkm/cont/internal/ParallelRadixSort.h
Normal file
File diff suppressed because it is too large
Load Diff
166
vtkm/cont/internal/ParallelRadixSortInterface.h
Normal file
166
vtkm/cont/internal/ParallelRadixSortInterface.h
Normal file
@ -0,0 +1,166 @@
|
||||
//============================================================================
|
||||
// Copyright (c) Kitware, Inc.
|
||||
// All rights reserved.
|
||||
// See LICENSE.txt for details.
|
||||
// This software is distributed WITHOUT ANY WARRANTY; without even
|
||||
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
|
||||
// PURPOSE. See the above copyright notice for more information.
|
||||
//
|
||||
// Copyright 2017 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
|
||||
// Copyright 2017 UT-Battelle, LLC.
|
||||
// Copyright 2017 Los Alamos National Security.
|
||||
//
|
||||
// Under the terms of Contract DE-NA0003525 with NTESS,
|
||||
// the U.S. Government retains certain rights in this software.
|
||||
//
|
||||
// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
|
||||
// Laboratory (LANL), the U.S. Government retains certain rights in
|
||||
// this software.
|
||||
//============================================================================
|
||||
|
||||
#ifndef vtk_m_cont_internal_ParallelRadixSortInterface_h
|
||||
#define vtk_m_cont_internal_ParallelRadixSortInterface_h
|
||||
|
||||
#include <vtkm/BinaryPredicates.h>
|
||||
#include <vtkm/cont/ArrayHandle.h>
|
||||
|
||||
#include <functional>
|
||||
#include <type_traits>
|
||||
|
||||
namespace vtkm
|
||||
{
|
||||
namespace cont
|
||||
{
|
||||
namespace internal
|
||||
{
|
||||
namespace radix
|
||||
{
|
||||
|
||||
// Size thresholds, in bytes of key data, used to scale radix-sort parallelism.
// NOTE(review): the consumers of these constants are not part of this header;
// presumably inputs smaller than MIN_BYTES_FOR_PARALLEL are handled by a single
// thread and inputs of at least BYTES_FOR_MAX_PARALLELISM may use every
// available core -- confirm against the threader implementation.
const size_t MIN_BYTES_FOR_PARALLEL = 400000;
const size_t BYTES_FOR_MAX_PARALLELISM = 4000000;
|
||||
|
||||
// Empty dispatch tag: produced by sort_tag_type / sortbykey_tag_type when the
// radix sort implementation can be used.
struct RadixSortTag
{
};
|
||||
|
||||
// Empty dispatch tag: produced by sort_tag_type / sortbykey_tag_type whenever
// radix sort is not applicable, selecting the fallback sort path.
struct PSortTag
{
};
|
||||
|
||||
// Detect supported functors for radix sort.
//
// The primary template rejects every type; comparison functors opt in through
// specializations. std::less<T> and std::greater<T> are accepted for any T.
template <typename T>
struct is_valid_compare_type : std::false_type
{
};
template <typename T>
struct is_valid_compare_type<std::less<T>> : std::true_type
{
};
template <typename T>
struct is_valid_compare_type<std::greater<T>> : std::true_type
{
};
|
||||
template <>
|
||||
struct is_valid_compare_type<vtkm::SortLess> : std::integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
template <>
|
||||
struct is_valid_compare_type<vtkm::SortGreater> : std::integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
|
||||
// Convert vtkm::Sort[Less|Greater] to the std:: equivalents.
//
// Generic fallback: any other comparison functor is perfectly forwarded back
// to the caller unchanged. The second argument is ignored by this overload.
//
// The result is a reference to `b`, not a copy: when `b` is a temporary, copy
// the result before the end of the full expression instead of binding it to a
// reference.
template <typename BComp, typename T>
BComp&& get_std_compare(BComp&& b, T&&)
{
  return std::forward<BComp>(b);
}
|
||||
template <typename T>
|
||||
std::less<T> get_std_compare(vtkm::SortLess, T&&)
|
||||
{
|
||||
return std::less<T>{};
|
||||
}
|
||||
template <typename T>
|
||||
std::greater<T> get_std_compare(vtkm::SortGreater, T&&)
|
||||
{
|
||||
return std::greater<T>{};
|
||||
}
|
||||
|
||||
// Determine if radix sort can be used for a given ValueType, StorageType, and
|
||||
// comparison functor.
|
||||
template <typename T, typename StorageTag, typename BinaryCompare>
|
||||
struct sort_tag_type
|
||||
{
|
||||
using type = PSortTag;
|
||||
};
|
||||
template <typename T, typename BinaryCompare>
|
||||
struct sort_tag_type<T, vtkm::cont::StorageTagBasic, BinaryCompare>
|
||||
{
|
||||
using PrimT = std::is_arithmetic<T>;
|
||||
using LongDT = std::is_same<T, long double>;
|
||||
using BComp = is_valid_compare_type<BinaryCompare>;
|
||||
using type = typename std::conditional<PrimT::value && BComp::value && !LongDT::value,
|
||||
RadixSortTag,
|
||||
PSortTag>::type;
|
||||
};
|
||||
|
||||
template <typename KeyType,
|
||||
typename ValueType,
|
||||
typename KeyStorageTagType,
|
||||
typename ValueStorageTagType,
|
||||
class BinaryCompare>
|
||||
struct sortbykey_tag_type
|
||||
{
|
||||
using type = PSortTag;
|
||||
};
|
||||
template <typename KeyType, typename ValueType, class BinaryCompare>
|
||||
struct sortbykey_tag_type<KeyType,
|
||||
ValueType,
|
||||
vtkm::cont::StorageTagBasic,
|
||||
vtkm::cont::StorageTagBasic,
|
||||
BinaryCompare>
|
||||
{
|
||||
using PrimKey = std::is_arithmetic<KeyType>;
|
||||
using PrimValue = std::is_arithmetic<ValueType>;
|
||||
using LongDKey = std::is_same<KeyType, long double>;
|
||||
using BComp = is_valid_compare_type<BinaryCompare>;
|
||||
using type = typename std::conditional<PrimKey::value && PrimValue::value && BComp::value &&
|
||||
!LongDKey::value,
|
||||
RadixSortTag,
|
||||
PSortTag>::type;
|
||||
};
|
||||
|
||||
// Declares the exported radix-sort entry points for one key type: a key-only
// sort and a key/value sort (values are vtkm::Id), each overloaded for
// std::greater<key_type> and std::less<key_type>.
#define VTKM_INTERNAL_RADIX_SORT_DECLARE(key_type)                                                 \
  VTKM_CONT_EXPORT void parallel_radix_sort(                                                       \
    key_type* data, size_t num_elems, const std::greater<key_type>& comp);                         \
  VTKM_CONT_EXPORT void parallel_radix_sort(                                                       \
    key_type* data, size_t num_elems, const std::less<key_type>& comp);                            \
  VTKM_CONT_EXPORT void parallel_radix_sort_key_values(                                            \
    key_type* keys, vtkm::Id* vals, size_t num_elems, const std::greater<key_type>& comp);         \
  VTKM_CONT_EXPORT void parallel_radix_sort_key_values(                                            \
    key_type* keys, vtkm::Id* vals, size_t num_elems, const std::less<key_type>& comp);
|
||||
|
||||
// Generate radix sort interfaces for key and key value sorts.
// Expands VTKM_INTERNAL_RADIX_SORT_DECLARE once for every supported key type:
// the integer and character types plus float and double (bool and long double
// are not listed).
#define VTKM_DECLARE_RADIX_SORT()                                                                  \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(short int)                                                      \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(unsigned short int)                                             \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(int)                                                            \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(unsigned int)                                                   \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(long int)                                                       \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(unsigned long int)                                              \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(long long int)                                                  \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(unsigned long long int)                                         \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(unsigned char)                                                  \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(signed char)                                                    \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(char)                                                           \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(char16_t)                                                       \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(char32_t)                                                       \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(wchar_t)                                                        \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(float)                                                          \
  VTKM_INTERNAL_RADIX_SORT_DECLARE(double)
|
||||
}
|
||||
}
|
||||
}
|
||||
} // end vtkm::cont::internal::radix
|
||||
|
||||
#endif // vtk_m_cont_internal_ParallelRadixSortInterface_h
|
@ -251,21 +251,21 @@ public:
|
||||
{
|
||||
//this is required to get sort to work with zip handles
|
||||
std::less<T> lessOp;
|
||||
vtkm::cont::tbb::internal::parallel_sort(values, lessOp);
|
||||
vtkm::cont::tbb::sort::parallel_sort(values, lessOp);
|
||||
}
|
||||
|
||||
template <typename T, class Container, class BinaryCompare>
|
||||
VTKM_CONT static void Sort(vtkm::cont::ArrayHandle<T, Container>& values,
|
||||
BinaryCompare binary_compare)
|
||||
{
|
||||
vtkm::cont::tbb::internal::parallel_sort(values, binary_compare);
|
||||
vtkm::cont::tbb::sort::parallel_sort(values, binary_compare);
|
||||
}
|
||||
|
||||
template <typename T, typename U, class StorageT, class StorageU>
|
||||
VTKM_CONT static void SortByKey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
|
||||
vtkm::cont::ArrayHandle<U, StorageU>& values)
|
||||
{
|
||||
vtkm::cont::tbb::internal::parallel_sort_bykey(keys, values, std::less<T>());
|
||||
vtkm::cont::tbb::sort::parallel_sort_bykey(keys, values, std::less<T>());
|
||||
}
|
||||
|
||||
template <typename T, typename U, class StorageT, class StorageU, class BinaryCompare>
|
||||
@ -273,7 +273,7 @@ public:
|
||||
vtkm::cont::ArrayHandle<U, StorageU>& values,
|
||||
BinaryCompare binary_compare)
|
||||
{
|
||||
vtkm::cont::tbb::internal::parallel_sort_bykey(keys, values, binary_compare);
|
||||
vtkm::cont::tbb::sort::parallel_sort_bykey(keys, values, binary_compare);
|
||||
}
|
||||
|
||||
template <typename T, class Storage>
|
||||
|
@ -49,36 +49,7 @@
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
// Modifications of Takuya Akiba's original GitHub source code for inclusion
|
||||
// in VTK-m:
|
||||
//
|
||||
// - Changed parallel threading from OpenMP to TBB tasks
|
||||
// - Added minimum threshold for parallel, will instead invoke serial radix sort (kxsort)
|
||||
// - Added std::greater<T> and std::less<T> to interface for descending order sorts
|
||||
// - Added linear scaling of threads used by the algorithm for more stable performance
|
||||
// on machines with lots of available threads (KNL and Haswell)
|
||||
//
|
||||
// This file contains an implementation of Satish parallel radix sort
|
||||
// as documented in the following citation:
|
||||
//
|
||||
// Fast sort on CPUs and GPUs: a case for bandwidth oblivious SIMD sort.
|
||||
// N. Satish, C. Kim, J. Chhugani, A. D. Nguyen, V. W. Lee, D. Kim, and P. Dubey.
|
||||
// In Proc. SIGMOD, pages 351–362, 2010
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <stdint.h>
|
||||
#include <utility>
|
||||
|
||||
#include <vtkm/Types.h>
|
||||
#include <vtkm/cont/tbb/internal/ParallelSortTBB.h>
|
||||
|
||||
VTKM_THIRDPARTY_PRE_INCLUDE
|
||||
|
||||
#include <vtkm/cont/tbb/internal/kxsort.h>
|
||||
#include <vtkm/cont/internal/ParallelRadixSort.h>
|
||||
|
||||
#if defined(VTKM_MSVC)
|
||||
|
||||
@ -107,852 +78,57 @@ namespace cont
|
||||
{
|
||||
namespace tbb
|
||||
{
|
||||
namespace internal
|
||||
namespace sort
|
||||
{
|
||||
|
||||
const size_t MIN_BYTES_FOR_PARALLEL = 400000;
|
||||
const size_t BYTES_FOR_MAX_PARALLELISM = 4000000;
|
||||
const size_t MAX_CORES = ::tbb::tbb_thread::hardware_concurrency();
|
||||
const double CORES_PER_BYTE =
|
||||
double(MAX_CORES - 1) / double(BYTES_FOR_MAX_PARALLELISM - MIN_BYTES_FOR_PARALLEL);
|
||||
const double Y_INTERCEPT = 1.0 - CORES_PER_BYTE * MIN_BYTES_FOR_PARALLEL;
|
||||
|
||||
namespace utility
|
||||
// Simple TBB task wrapper around a generic functor.
|
||||
template <typename FunctorType>
|
||||
struct TaskWrapper : public ::tbb::task
|
||||
{
|
||||
// Return the number of threads that would be executed in parallel regions
|
||||
inline size_t GetMaxThreads(size_t num_bytes)
|
||||
{
|
||||
size_t num_cores = (size_t)(CORES_PER_BYTE * double(num_bytes) + Y_INTERCEPT);
|
||||
if (num_cores < 1)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
if (num_cores > MAX_CORES)
|
||||
{
|
||||
return MAX_CORES;
|
||||
}
|
||||
return num_cores;
|
||||
}
|
||||
} // namespace utility
|
||||
FunctorType Functor;
|
||||
|
||||
namespace internal
|
||||
{
|
||||
// Size of the software managed buffer: the number of elements staged per
// (thread, bucket) pair before they are written to the destination array.
const size_t kOutBufferSize = 32;
|
||||
|
||||
// Ascending order radix sort is a no-op: the digit is used as extracted.
template <typename PlainType,
          typename UnsignedType,
          typename CompareType,
          typename ValueManager,
          unsigned int Base>
struct ParallelRadixCompareInternal
{
  inline static void reverse(UnsignedType& t) { (void)t; }
};

// Handle descending order radix sort: selected when the comparison functor is
// std::greater<PlainType>. Each Base-bit digit t is mirrored to
// (2^Base - 1) - t so that larger digits land in earlier buckets.
template <typename PlainType, typename UnsignedType, typename ValueManager, unsigned int Base>
struct ParallelRadixCompareInternal<PlainType,
                                    UnsignedType,
                                    std::greater<PlainType>,
                                    ValueManager,
                                    Base>
{
  inline static void reverse(UnsignedType& t)
  {
    t = static_cast<UnsignedType>(((1 << Base) - 1) - t);
  }
};
|
||||
|
||||
// The algorithm is implemented in this internal class
|
||||
template <typename PlainType,
|
||||
typename CompareType,
|
||||
typename UnsignedType,
|
||||
typename Encoder,
|
||||
typename ValueManager,
|
||||
unsigned int Base>
|
||||
class ParallelRadixSortInternal
|
||||
{
|
||||
public:
|
||||
using CompareInternal =
|
||||
ParallelRadixCompareInternal<PlainType, UnsignedType, CompareType, ValueManager, Base>;
|
||||
|
||||
ParallelRadixSortInternal();
|
||||
~ParallelRadixSortInternal();
|
||||
|
||||
void Init(PlainType* data, size_t num_elems);
|
||||
|
||||
PlainType* Sort(PlainType* data, ValueManager* value_manager);
|
||||
|
||||
static void InitAndSort(PlainType* data, size_t num_elems, ValueManager* value_manager);
|
||||
|
||||
private:
|
||||
CompareInternal compare_internal_;
|
||||
size_t num_elems_;
|
||||
size_t num_threads_;
|
||||
|
||||
UnsignedType* tmp_;
|
||||
size_t** histo_;
|
||||
UnsignedType*** out_buf_;
|
||||
size_t** out_buf_n_;
|
||||
|
||||
size_t *pos_bgn_, *pos_end_;
|
||||
ValueManager* value_manager_;
|
||||
|
||||
void DeleteAll();
|
||||
|
||||
UnsignedType* SortInternal(UnsignedType* data, ValueManager* value_manager);
|
||||
|
||||
// Compute |pos_bgn_| and |pos_end_| (associated ranges for each threads)
|
||||
void ComputeRanges();
|
||||
|
||||
// First step of each iteration of sorting
|
||||
// Compute the histogram of |src| using bits in [b, b + Base)
|
||||
void ComputeHistogram(unsigned int b, UnsignedType* src);
|
||||
|
||||
// Second step of each iteration of sorting
|
||||
// Scatter elements of |src| to |dst| using the histogram
|
||||
void Scatter(unsigned int b, UnsignedType* src, UnsignedType* dst);
|
||||
};
|
||||
|
||||
template <typename PlainType,
|
||||
typename CompareType,
|
||||
typename UnsignedType,
|
||||
typename Encoder,
|
||||
typename ValueManager,
|
||||
unsigned int Base>
|
||||
ParallelRadixSortInternal<PlainType, CompareType, UnsignedType, Encoder, ValueManager, Base>::
|
||||
ParallelRadixSortInternal()
|
||||
: num_elems_(0)
|
||||
, num_threads_(0)
|
||||
, tmp_(NULL)
|
||||
, histo_(NULL)
|
||||
, out_buf_(NULL)
|
||||
, out_buf_n_(NULL)
|
||||
, pos_bgn_(NULL)
|
||||
, pos_end_(NULL)
|
||||
{
|
||||
assert(sizeof(PlainType) == sizeof(UnsignedType));
|
||||
}
|
||||
|
||||
template <typename PlainType,
|
||||
typename CompareType,
|
||||
typename UnsignedType,
|
||||
typename Encoder,
|
||||
typename ValueManager,
|
||||
unsigned int Base>
|
||||
ParallelRadixSortInternal<PlainType, CompareType, UnsignedType, Encoder, ValueManager, Base>::
|
||||
~ParallelRadixSortInternal()
|
||||
{
|
||||
DeleteAll();
|
||||
}
|
||||
|
||||
template <typename PlainType,
|
||||
typename CompareType,
|
||||
typename UnsignedType,
|
||||
typename Encoder,
|
||||
typename ValueManager,
|
||||
unsigned int Base>
|
||||
void ParallelRadixSortInternal<PlainType, CompareType, UnsignedType, Encoder, ValueManager, Base>::
|
||||
DeleteAll()
|
||||
{
|
||||
delete[] tmp_;
|
||||
tmp_ = NULL;
|
||||
|
||||
for (size_t i = 0; i < num_threads_; ++i)
|
||||
delete[] histo_[i];
|
||||
delete[] histo_;
|
||||
histo_ = NULL;
|
||||
|
||||
for (size_t i = 0; i < num_threads_; ++i)
|
||||
{
|
||||
for (size_t j = 0; j < 1 << Base; ++j)
|
||||
{
|
||||
delete[] out_buf_[i][j];
|
||||
}
|
||||
delete[] out_buf_n_[i];
|
||||
delete[] out_buf_[i];
|
||||
}
|
||||
delete[] out_buf_;
|
||||
delete[] out_buf_n_;
|
||||
out_buf_ = NULL;
|
||||
out_buf_n_ = NULL;
|
||||
|
||||
delete[] pos_bgn_;
|
||||
delete[] pos_end_;
|
||||
pos_bgn_ = pos_end_ = NULL;
|
||||
|
||||
num_elems_ = 0;
|
||||
num_threads_ = 0;
|
||||
}
|
||||
|
||||
template <typename PlainType,
|
||||
typename CompareType,
|
||||
typename UnsignedType,
|
||||
typename Encoder,
|
||||
typename ValueManager,
|
||||
unsigned int Base>
|
||||
void ParallelRadixSortInternal<PlainType, CompareType, UnsignedType, Encoder, ValueManager, Base>::
|
||||
Init(PlainType* data, size_t num_elems)
|
||||
{
|
||||
(void)data;
|
||||
DeleteAll();
|
||||
|
||||
num_elems_ = num_elems;
|
||||
|
||||
num_threads_ = utility::GetMaxThreads(num_elems_ * sizeof(PlainType));
|
||||
|
||||
tmp_ = new UnsignedType[num_elems_];
|
||||
histo_ = new size_t*[num_threads_];
|
||||
for (size_t i = 0; i < num_threads_; ++i)
|
||||
{
|
||||
histo_[i] = new size_t[1 << Base];
|
||||
}
|
||||
|
||||
out_buf_ = new UnsignedType**[num_threads_];
|
||||
out_buf_n_ = new size_t*[num_threads_];
|
||||
for (size_t i = 0; i < num_threads_; ++i)
|
||||
{
|
||||
out_buf_[i] = new UnsignedType*[1 << Base];
|
||||
out_buf_n_[i] = new size_t[1 << Base];
|
||||
for (size_t j = 0; j < 1 << Base; ++j)
|
||||
{
|
||||
out_buf_[i][j] = new UnsignedType[kOutBufferSize];
|
||||
}
|
||||
}
|
||||
|
||||
pos_bgn_ = new size_t[num_threads_];
|
||||
pos_end_ = new size_t[num_threads_];
|
||||
}
|
||||
|
||||
template <typename PlainType,
|
||||
typename CompareType,
|
||||
typename UnsignedType,
|
||||
typename Encoder,
|
||||
typename ValueManager,
|
||||
unsigned int Base>
|
||||
PlainType*
|
||||
ParallelRadixSortInternal<PlainType, CompareType, UnsignedType, Encoder, ValueManager, Base>::Sort(
|
||||
PlainType* data,
|
||||
ValueManager* value_manager)
|
||||
{
|
||||
UnsignedType* src = reinterpret_cast<UnsignedType*>(data);
|
||||
UnsignedType* res = SortInternal(src, value_manager);
|
||||
return reinterpret_cast<PlainType*>(res);
|
||||
}
|
||||
|
||||
template <typename PlainType,
|
||||
typename CompareType,
|
||||
typename UnsignedType,
|
||||
typename Encoder,
|
||||
typename ValueManager,
|
||||
unsigned int Base>
|
||||
void ParallelRadixSortInternal<PlainType, CompareType, UnsignedType, Encoder, ValueManager, Base>::
|
||||
InitAndSort(PlainType* data, size_t num_elems, ValueManager* value_manager)
|
||||
{
|
||||
ParallelRadixSortInternal prs;
|
||||
prs.Init(data, num_elems);
|
||||
const PlainType* res = prs.Sort(data, value_manager);
|
||||
if (res != data)
|
||||
{
|
||||
for (size_t i = 0; i < num_elems; ++i)
|
||||
data[i] = res[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PlainType,
|
||||
typename CompareType,
|
||||
typename UnsignedType,
|
||||
typename Encoder,
|
||||
typename ValueManager,
|
||||
unsigned int Base>
|
||||
UnsignedType*
|
||||
ParallelRadixSortInternal<PlainType, CompareType, UnsignedType, Encoder, ValueManager, Base>::
|
||||
SortInternal(UnsignedType* data, ValueManager* value_manager)
|
||||
{
|
||||
|
||||
value_manager_ = value_manager;
|
||||
|
||||
// Compute |pos_bgn_| and |pos_end_|
|
||||
ComputeRanges();
|
||||
|
||||
// Iterate from lower bits to higher bits
|
||||
const size_t bits = CHAR_BIT * sizeof(UnsignedType);
|
||||
UnsignedType *src = data, *dst = tmp_;
|
||||
for (unsigned int b = 0; b < bits; b += Base)
|
||||
{
|
||||
ComputeHistogram(b, src);
|
||||
Scatter(b, src, dst);
|
||||
|
||||
std::swap(src, dst);
|
||||
value_manager->Next();
|
||||
}
|
||||
|
||||
return src;
|
||||
}
|
||||
|
||||
template <typename PlainType,
|
||||
typename CompareType,
|
||||
typename UnsignedType,
|
||||
typename Encoder,
|
||||
typename ValueManager,
|
||||
unsigned int Base>
|
||||
void ParallelRadixSortInternal<PlainType, CompareType, UnsignedType, Encoder, ValueManager, Base>::
|
||||
ComputeRanges()
|
||||
{
|
||||
pos_bgn_[0] = 0;
|
||||
for (size_t i = 0; i < num_threads_ - 1; ++i)
|
||||
{
|
||||
const size_t t = (num_elems_ - pos_bgn_[i]) / (num_threads_ - i);
|
||||
pos_bgn_[i + 1] = pos_end_[i] = pos_bgn_[i] + t;
|
||||
}
|
||||
pos_end_[num_threads_ - 1] = num_elems_;
|
||||
}
|
||||
|
||||
template <typename PlainType,
|
||||
typename UnsignedType,
|
||||
typename Encoder,
|
||||
unsigned int Base,
|
||||
typename Function>
|
||||
class RunTask : public ::tbb::task
|
||||
{
|
||||
public:
|
||||
RunTask(size_t binary_tree_height,
|
||||
size_t binary_tree_position,
|
||||
Function f,
|
||||
size_t num_elems,
|
||||
size_t num_threads)
|
||||
: binary_tree_height_(binary_tree_height)
|
||||
, binary_tree_position_(binary_tree_position)
|
||||
, f_(f)
|
||||
, num_elems_(num_elems)
|
||||
, num_threads_(num_threads)
|
||||
TaskWrapper(FunctorType f)
|
||||
: Functor(f)
|
||||
{
|
||||
}
|
||||
|
||||
::tbb::task* execute()
|
||||
{
|
||||
size_t num_nodes_at_current_height = (size_t)pow(2, (double)binary_tree_height_);
|
||||
if (num_threads_ <= num_nodes_at_current_height)
|
||||
{
|
||||
const size_t my_id = binary_tree_position_ - num_nodes_at_current_height;
|
||||
if (my_id < num_threads_)
|
||||
{
|
||||
f_(my_id);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
else
|
||||
{
|
||||
::tbb::empty_task& p = *new (task::allocate_continuation())::tbb::empty_task();
|
||||
RunTask& left = *new (p.allocate_child()) RunTask(
|
||||
binary_tree_height_ + 1, 2 * binary_tree_position_, f_, num_elems_, num_threads_);
|
||||
RunTask& right = *new (p.allocate_child()) RunTask(
|
||||
binary_tree_height_ + 1, 2 * binary_tree_position_ + 1, f_, num_elems_, num_threads_);
|
||||
p.set_ref_count(2);
|
||||
task::spawn(left);
|
||||
task::spawn(right);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
size_t binary_tree_height_;
|
||||
size_t binary_tree_position_;
|
||||
Function f_;
|
||||
size_t num_elems_;
|
||||
size_t num_threads_;
|
||||
};
|
||||
|
||||
template <typename PlainType,
|
||||
typename CompareType,
|
||||
typename UnsignedType,
|
||||
typename Encoder,
|
||||
typename ValueManager,
|
||||
unsigned int Base>
|
||||
void ParallelRadixSortInternal<PlainType, CompareType, UnsignedType, Encoder, ValueManager, Base>::
|
||||
ComputeHistogram(unsigned int b, UnsignedType* src)
|
||||
{
|
||||
// Compute local histogram
|
||||
|
||||
auto lambda = [=](const size_t my_id) {
|
||||
const size_t my_bgn = pos_bgn_[my_id];
|
||||
const size_t my_end = pos_end_[my_id];
|
||||
size_t* my_histo = histo_[my_id];
|
||||
|
||||
memset(my_histo, 0, sizeof(size_t) * (1 << Base));
|
||||
for (size_t i = my_bgn; i < my_end; ++i)
|
||||
{
|
||||
const UnsignedType s = Encoder::encode(src[i]);
|
||||
UnsignedType t = (s >> b) & ((1 << Base) - 1);
|
||||
compare_internal_.reverse(t);
|
||||
++my_histo[t];
|
||||
}
|
||||
};
|
||||
|
||||
using RunTaskType = RunTask<PlainType, UnsignedType, Encoder, Base, std::function<void(size_t)>>;
|
||||
|
||||
RunTaskType& root =
|
||||
*new (::tbb::task::allocate_root()) RunTaskType(0, 1, lambda, num_elems_, num_threads_);
|
||||
|
||||
::tbb::task::spawn_root_and_wait(root);
|
||||
|
||||
// Compute global histogram
|
||||
size_t s = 0;
|
||||
for (size_t i = 0; i < 1 << Base; ++i)
|
||||
{
|
||||
for (size_t j = 0; j < num_threads_; ++j)
|
||||
{
|
||||
const size_t t = s + histo_[j][i];
|
||||
histo_[j][i] = s;
|
||||
s = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PlainType,
|
||||
typename CompareType,
|
||||
typename UnsignedType,
|
||||
typename Encoder,
|
||||
typename ValueManager,
|
||||
unsigned int Base>
|
||||
void ParallelRadixSortInternal<PlainType, CompareType, UnsignedType, Encoder, ValueManager, Base>::
|
||||
Scatter(unsigned int b, UnsignedType* src, UnsignedType* dst)
|
||||
{
|
||||
|
||||
auto lambda = [=](const size_t my_id) {
|
||||
const size_t my_bgn = pos_bgn_[my_id];
|
||||
const size_t my_end = pos_end_[my_id];
|
||||
size_t* my_histo = histo_[my_id];
|
||||
UnsignedType** my_buf = out_buf_[my_id];
|
||||
size_t* my_buf_n = out_buf_n_[my_id];
|
||||
|
||||
memset(my_buf_n, 0, sizeof(size_t) * (1 << Base));
|
||||
for (size_t i = my_bgn; i < my_end; ++i)
|
||||
{
|
||||
const UnsignedType s = Encoder::encode(src[i]);
|
||||
UnsignedType t = (s >> b) & ((1 << Base) - 1);
|
||||
compare_internal_.reverse(t);
|
||||
my_buf[t][my_buf_n[t]] = src[i];
|
||||
value_manager_->Push(my_id, t, my_buf_n[t], i);
|
||||
++my_buf_n[t];
|
||||
|
||||
if (my_buf_n[t] == kOutBufferSize)
|
||||
{
|
||||
size_t p = my_histo[t];
|
||||
for (size_t j = 0; j < kOutBufferSize; ++j)
|
||||
{
|
||||
size_t tp = p++;
|
||||
dst[tp] = my_buf[t][j];
|
||||
}
|
||||
value_manager_->Flush(my_id, t, kOutBufferSize, my_histo[t]);
|
||||
|
||||
my_histo[t] += kOutBufferSize;
|
||||
my_buf_n[t] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Flush everything
|
||||
for (size_t i = 0; i < 1 << Base; ++i)
|
||||
{
|
||||
size_t p = my_histo[i];
|
||||
for (size_t j = 0; j < my_buf_n[i]; ++j)
|
||||
{
|
||||
size_t tp = p++;
|
||||
dst[tp] = my_buf[i][j];
|
||||
}
|
||||
value_manager_->Flush(my_id, i, my_buf_n[i], my_histo[i]);
|
||||
}
|
||||
};
|
||||
|
||||
using RunTaskType = RunTask<PlainType, UnsignedType, Encoder, Base, std::function<void(size_t)>>;
|
||||
|
||||
RunTaskType& root =
|
||||
*new (::tbb::task::allocate_root()) RunTaskType(0, 1, lambda, num_elems_, num_threads_);
|
||||
|
||||
::tbb::task::spawn_root_and_wait(root);
|
||||
}
|
||||
} // namespace internal
|
||||
|
||||
// Encoders encode signed/unsigned integers and floating point numbers
|
||||
// to correctly ordered unsigned integers
|
||||
namespace encoder
|
||||
{
|
||||
// Placeholder encoder: the default Encoder argument of KeySort / PairSort.
// It provides no encode() function.
class EncoderDummy
{
};
|
||||
|
||||
// Encoder for unsigned integer keys: their bit pattern already sorts
// correctly, so the value is returned unchanged.
class EncoderUnsigned
{
public:
  template <typename UnsignedType>
  inline static UnsignedType encode(UnsignedType x)
  {
    return x;
  }
};
|
||||
|
||||
class EncoderSigned
|
||||
struct RadixThreaderTBB
|
||||
{
|
||||
public:
|
||||
template <typename UnsignedType>
|
||||
inline static UnsignedType encode(UnsignedType x)
|
||||
size_t GetAvailableCores() const { return MAX_CORES; }
|
||||
|
||||
template <typename TaskType>
|
||||
void RunParentTask(TaskType task)
|
||||
{
|
||||
return x ^ (UnsignedType(1) << (CHAR_BIT * sizeof(UnsignedType) - 1));
|
||||
using Task = TaskWrapper<TaskType>;
|
||||
Task& root = *new (::tbb::task::allocate_root()) Task(task);
|
||||
::tbb::task::spawn_root_and_wait(root);
|
||||
}
|
||||
|
||||
template <typename TaskType>
|
||||
void RunChildTasks(TaskWrapper<TaskType>* wrapper, TaskType left, TaskType right)
|
||||
{
|
||||
using Task = TaskWrapper<TaskType>;
|
||||
::tbb::empty_task& p = *new (wrapper->allocate_continuation())::tbb::empty_task();
|
||||
|
||||
Task& lchild = *new (p.allocate_child()) Task(left);
|
||||
Task& rchild = *new (p.allocate_child()) Task(right);
|
||||
p.set_ref_count(2);
|
||||
::tbb::task::spawn(lchild);
|
||||
::tbb::task::spawn(rchild);
|
||||
}
|
||||
};
|
||||
|
||||
// Encoder for floating point keys, applied to their raw bit pattern:
//  - sign bit set   -> every bit is inverted,
//  - sign bit clear -> only the sign bit is flipped.
// The resulting unsigned values compare in the same order as the original
// floating point numbers.
class EncoderDecimal
{
public:
  template <typename UnsignedType>
  inline static UnsignedType encode(UnsignedType x)
  {
    static const size_t bits = CHAR_BIT * sizeof(UnsignedType);
    const UnsignedType sign_bit = static_cast<UnsignedType>(UnsignedType(1) << (bits - 1));
    const UnsignedType all_bits = static_cast<UnsignedType>(~UnsignedType(0));
    const UnsignedType mask = (x & sign_bit) ? all_bits : sign_bit;
    return static_cast<UnsignedType>(x ^ mask);
  }
};
|
||||
} // namespace encoder
|
||||
|
||||
// Value managers are used to generalize the sorting algorithm
|
||||
// to sorting of keys and sorting of pairs
|
||||
namespace value_manager
|
||||
{
|
||||
// Value manager for key-only sorts: accepts the same notifications as
// PairValueManager but there are no values to move, so every call is a no-op.
class DummyValueManager
{
public:
  inline void Push(int thread, size_t bucket, size_t num, size_t from_pos)
  {
    (void)thread;
    (void)bucket;
    (void)num;
    (void)from_pos;
  }

  inline void Flush(int thread, size_t bucket, size_t num, size_t to_pos)
  {
    (void)thread;
    (void)bucket;
    (void)num;
    (void)to_pos;
  }

  void Next() {}
};
|
||||
|
||||
template <typename PlainType, typename ValueType, int Base>
|
||||
class PairValueManager
|
||||
{
|
||||
public:
|
||||
PairValueManager()
|
||||
: max_elems_(0)
|
||||
, max_threads_(0)
|
||||
, original_(NULL)
|
||||
, tmp_(NULL)
|
||||
, src_(NULL)
|
||||
, dst_(NULL)
|
||||
, out_buf_(NULL)
|
||||
{
|
||||
}
|
||||
|
||||
~PairValueManager() { DeleteAll(); }
|
||||
|
||||
void Init(size_t max_elems);
|
||||
|
||||
void Start(ValueType* original, size_t num_elems)
|
||||
{
|
||||
assert(num_elems <= max_elems_);
|
||||
src_ = original_ = original;
|
||||
dst_ = tmp_;
|
||||
}
|
||||
|
||||
inline void Push(int thread, size_t bucket, size_t num, size_t from_pos)
|
||||
{
|
||||
out_buf_[thread][bucket][num] = src_[from_pos];
|
||||
}
|
||||
|
||||
inline void Flush(int thread, size_t bucket, size_t num, size_t to_pos)
|
||||
{
|
||||
for (size_t i = 0; i < num; ++i)
|
||||
{
|
||||
dst_[to_pos++] = out_buf_[thread][bucket][i];
|
||||
}
|
||||
}
|
||||
|
||||
void Next() { std::swap(src_, dst_); }
|
||||
|
||||
ValueType* GetResult() { return src_; }
|
||||
private:
|
||||
size_t max_elems_;
|
||||
int max_threads_;
|
||||
|
||||
static constexpr size_t kOutBufferSize = internal::kOutBufferSize;
|
||||
ValueType *original_, *tmp_;
|
||||
ValueType *src_, *dst_;
|
||||
ValueType*** out_buf_;
|
||||
|
||||
void DeleteAll();
|
||||
};
|
||||
|
||||
template <typename PlainType, typename ValueType, int Base>
|
||||
void PairValueManager<PlainType, ValueType, Base>::Init(size_t max_elems)
|
||||
{
|
||||
DeleteAll();
|
||||
|
||||
max_elems_ = max_elems;
|
||||
max_threads_ = utility::GetMaxThreads(max_elems_ * sizeof(PlainType));
|
||||
|
||||
tmp_ = new ValueType[max_elems];
|
||||
|
||||
out_buf_ = new ValueType**[max_threads_];
|
||||
for (int i = 0; i < max_threads_; ++i)
|
||||
{
|
||||
out_buf_[i] = new ValueType*[1 << Base];
|
||||
for (size_t j = 0; j < 1 << Base; ++j)
|
||||
{
|
||||
out_buf_[i][j] = new ValueType[kOutBufferSize];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PlainType, typename ValueType, int Base>
|
||||
void PairValueManager<PlainType, ValueType, Base>::DeleteAll()
|
||||
{
|
||||
delete[] tmp_;
|
||||
tmp_ = NULL;
|
||||
|
||||
for (int i = 0; i < max_threads_; ++i)
|
||||
{
|
||||
for (size_t j = 0; j < 1 << Base; ++j)
|
||||
{
|
||||
delete[] out_buf_[i][j];
|
||||
}
|
||||
delete[] out_buf_[i];
|
||||
}
|
||||
delete[] out_buf_;
|
||||
out_buf_ = NULL;
|
||||
|
||||
max_elems_ = 0;
|
||||
max_threads_ = 0;
|
||||
}
|
||||
} // namespace value_manager
|
||||
|
||||
// Frontend class for sorting keys
|
||||
template <typename PlainType,
|
||||
typename CompareType,
|
||||
typename UnsignedType = PlainType,
|
||||
typename Encoder = encoder::EncoderDummy,
|
||||
unsigned int Base = 8>
|
||||
class KeySort
|
||||
{
|
||||
using DummyValueManager = value_manager::DummyValueManager;
|
||||
using Internal = internal::ParallelRadixSortInternal<PlainType,
|
||||
CompareType,
|
||||
UnsignedType,
|
||||
Encoder,
|
||||
DummyValueManager,
|
||||
Base>;
|
||||
|
||||
public:
|
||||
void InitAndSort(PlainType* data, size_t num_elems, const CompareType& comp)
|
||||
{
|
||||
(void)comp;
|
||||
DummyValueManager dvm;
|
||||
Internal::InitAndSort(data, num_elems, &dvm);
|
||||
}
|
||||
};
|
||||
|
||||
// Frontend class for sorting pairs
|
||||
template <typename PlainType,
|
||||
typename ValueType,
|
||||
typename CompareType,
|
||||
typename UnsignedType = PlainType,
|
||||
typename Encoder = encoder::EncoderDummy,
|
||||
int Base = 8>
|
||||
class PairSort
|
||||
{
|
||||
using ValueManager = value_manager::PairValueManager<PlainType, ValueType, Base>;
|
||||
using Internal = internal::
|
||||
ParallelRadixSortInternal<PlainType, CompareType, UnsignedType, Encoder, ValueManager, Base>;
|
||||
|
||||
public:
|
||||
void InitAndSort(PlainType* keys, ValueType* vals, size_t num_elems, const CompareType& comp)
|
||||
{
|
||||
(void)comp;
|
||||
ValueManager vm;
|
||||
vm.Init(num_elems);
|
||||
vm.Start(vals, num_elems);
|
||||
Internal::InitAndSort(keys, num_elems, &vm);
|
||||
ValueType* res_vals = vm.GetResult();
|
||||
if (res_vals != vals)
|
||||
{
|
||||
for (size_t i = 0; i < num_elems; ++i)
|
||||
{
|
||||
vals[i] = res_vals[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
#define KEY_SORT_CASE(plain_type, compare_type, unsigned_type, encoder_type) \
|
||||
template <> \
|
||||
class KeySort<plain_type, compare_type> \
|
||||
: public KeySort<plain_type, compare_type, unsigned_type, encoder::Encoder##encoder_type> \
|
||||
{ \
|
||||
}; \
|
||||
template <typename V> \
|
||||
class PairSort<plain_type, V, compare_type> \
|
||||
: public PairSort<plain_type, V, compare_type, unsigned_type, encoder::Encoder##encoder_type> \
|
||||
{ \
|
||||
};
|
||||
|
||||
// Unsigned integers
|
||||
KEY_SORT_CASE(unsigned int, std::less<unsigned int>, unsigned int, Unsigned);
|
||||
KEY_SORT_CASE(unsigned int, std::greater<unsigned int>, unsigned int, Unsigned);
|
||||
KEY_SORT_CASE(unsigned short int, std::less<unsigned short int>, unsigned short int, Unsigned);
|
||||
KEY_SORT_CASE(unsigned short int, std::greater<unsigned short int>, unsigned short int, Unsigned);
|
||||
KEY_SORT_CASE(unsigned long int, std::less<unsigned long int>, unsigned long int, Unsigned);
|
||||
KEY_SORT_CASE(unsigned long int, std::greater<unsigned long int>, unsigned long int, Unsigned);
|
||||
KEY_SORT_CASE(unsigned long long int,
|
||||
std::less<unsigned long long int>,
|
||||
unsigned long long int,
|
||||
Unsigned);
|
||||
KEY_SORT_CASE(unsigned long long int,
|
||||
std::greater<unsigned long long int>,
|
||||
unsigned long long int,
|
||||
Unsigned);
|
||||
|
||||
// Unsigned char
|
||||
KEY_SORT_CASE(unsigned char, std::less<unsigned char>, unsigned char, Unsigned);
|
||||
KEY_SORT_CASE(unsigned char, std::greater<unsigned char>, unsigned char, Unsigned);
|
||||
KEY_SORT_CASE(char16_t, std::less<char16_t>, uint16_t, Unsigned);
|
||||
KEY_SORT_CASE(char16_t, std::greater<char16_t>, uint16_t, Unsigned);
|
||||
KEY_SORT_CASE(char32_t, std::less<char32_t>, uint32_t, Unsigned);
|
||||
KEY_SORT_CASE(char32_t, std::greater<char32_t>, uint32_t, Unsigned);
|
||||
KEY_SORT_CASE(wchar_t, std::less<wchar_t>, uint32_t, Unsigned);
|
||||
KEY_SORT_CASE(wchar_t, std::greater<wchar_t>, uint32_t, Unsigned);
|
||||
|
||||
// Signed integers
|
||||
KEY_SORT_CASE(char, std::less<char>, unsigned char, Signed);
|
||||
KEY_SORT_CASE(char, std::greater<char>, unsigned char, Signed);
|
||||
KEY_SORT_CASE(short, std::less<short>, unsigned short, Signed);
|
||||
KEY_SORT_CASE(short, std::greater<short>, unsigned short, Signed);
|
||||
KEY_SORT_CASE(int, std::less<int>, unsigned int, Signed);
|
||||
KEY_SORT_CASE(int, std::greater<int>, unsigned int, Signed);
|
||||
KEY_SORT_CASE(long, std::less<long>, unsigned long, Signed);
|
||||
KEY_SORT_CASE(long, std::greater<long>, unsigned long, Signed);
|
||||
KEY_SORT_CASE(long long, std::less<long long>, unsigned long long, Signed);
|
||||
KEY_SORT_CASE(long long, std::greater<long long>, unsigned long long, Signed);
|
||||
|
||||
// |signed char| and |char| are treated as different types
|
||||
KEY_SORT_CASE(signed char, std::less<signed char>, unsigned char, Signed);
|
||||
KEY_SORT_CASE(signed char, std::greater<signed char>, unsigned char, Signed);
|
||||
|
||||
// Floating point numbers
|
||||
KEY_SORT_CASE(float, std::less<float>, uint32_t, Decimal);
|
||||
KEY_SORT_CASE(float, std::greater<float>, uint32_t, Decimal);
|
||||
KEY_SORT_CASE(double, std::less<double>, uint64_t, Decimal);
|
||||
KEY_SORT_CASE(double, std::greater<double>, uint64_t, Decimal);
|
||||
|
||||
#undef KEY_SORT_CASE
|
||||
|
||||
/// Serial sort dispatcher, generic fallback.
///
/// The primary template sorts `num_elems` elements starting at `data` with
/// std::sort using the supplied comparator. Specializations for particular
/// key/comparator combinations may substitute a different serial algorithm.
template <typename T, typename CompareType>
struct run_kx_radix_sort_keys
{
  /// @param data      pointer to the first element to sort (sorted in place)
  /// @param num_elems number of elements in the range
  /// @param comp      strict-weak-ordering comparator passed to std::sort
  static void run(T* data, size_t num_elems, const CompareType& comp)
  {
    std::sort(data, data + num_elems, comp);
  }
};
|
||||
|
||||
#define KX_SORT_KEYS(key_type) \
|
||||
template <> \
|
||||
struct run_kx_radix_sort_keys<key_type, std::less<key_type>> \
|
||||
{ \
|
||||
static void run(key_type* data, size_t num_elems, const std::less<key_type>& comp) \
|
||||
{ \
|
||||
(void)comp; \
|
||||
kx::radix_sort(data, data + num_elems); \
|
||||
} \
|
||||
};
|
||||
|
||||
KX_SORT_KEYS(unsigned short int);
|
||||
KX_SORT_KEYS(int);
|
||||
KX_SORT_KEYS(unsigned int);
|
||||
KX_SORT_KEYS(long int);
|
||||
KX_SORT_KEYS(unsigned long int);
|
||||
KX_SORT_KEYS(long long int);
|
||||
KX_SORT_KEYS(unsigned long long int);
|
||||
KX_SORT_KEYS(unsigned char);
|
||||
|
||||
#undef KX_SORT_KEYS
|
||||
|
||||
template <typename T, typename CompareType>
|
||||
bool use_serial_sort_keys(T* data, size_t num_elems, const CompareType& comp)
|
||||
{
|
||||
size_t total_bytes = (num_elems) * sizeof(T);
|
||||
if (total_bytes < MIN_BYTES_FOR_PARALLEL)
|
||||
{
|
||||
run_kx_radix_sort_keys<T, CompareType>::run(data, num_elems, comp);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Generate radix sort interfaces for key and key value sorts.
|
||||
#define VTKM_TBB_SORT_EXPORT(key_type) \
|
||||
void parallel_radix_sort_key_values( \
|
||||
key_type* keys, vtkm::Id* vals, size_t num_elems, const std::greater<key_type>& comp) \
|
||||
{ \
|
||||
PairSort<key_type, vtkm::Id, std::greater<key_type>> ps; \
|
||||
ps.InitAndSort(keys, vals, num_elems, comp); \
|
||||
} \
|
||||
void parallel_radix_sort_key_values( \
|
||||
key_type* keys, vtkm::Id* vals, size_t num_elems, const std::less<key_type>& comp) \
|
||||
{ \
|
||||
PairSort<key_type, vtkm::Id, std::less<key_type>> ps; \
|
||||
ps.InitAndSort(keys, vals, num_elems, comp); \
|
||||
} \
|
||||
void parallel_radix_sort(key_type* data, size_t num_elems, const std::greater<key_type>& comp) \
|
||||
{ \
|
||||
if (!use_serial_sort_keys(data, num_elems, comp)) \
|
||||
{ \
|
||||
KeySort<key_type, std::greater<key_type>> ks; \
|
||||
ks.InitAndSort(data, num_elems, comp); \
|
||||
} \
|
||||
} \
|
||||
void parallel_radix_sort(key_type* data, size_t num_elems, const std::less<key_type>& comp) \
|
||||
{ \
|
||||
if (!use_serial_sort_keys(data, num_elems, comp)) \
|
||||
{ \
|
||||
KeySort<key_type, std::less<key_type>> ks; \
|
||||
ks.InitAndSort(data, num_elems, comp); \
|
||||
} \
|
||||
}
|
||||
|
||||
VTKM_TBB_SORT_EXPORT(short int);
|
||||
VTKM_TBB_SORT_EXPORT(unsigned short int);
|
||||
VTKM_TBB_SORT_EXPORT(int);
|
||||
VTKM_TBB_SORT_EXPORT(unsigned int);
|
||||
VTKM_TBB_SORT_EXPORT(long int);
|
||||
VTKM_TBB_SORT_EXPORT(unsigned long int);
|
||||
VTKM_TBB_SORT_EXPORT(long long int);
|
||||
VTKM_TBB_SORT_EXPORT(unsigned long long int);
|
||||
VTKM_TBB_SORT_EXPORT(unsigned char);
|
||||
VTKM_TBB_SORT_EXPORT(signed char);
|
||||
VTKM_TBB_SORT_EXPORT(char);
|
||||
VTKM_TBB_SORT_EXPORT(char16_t);
|
||||
VTKM_TBB_SORT_EXPORT(char32_t);
|
||||
VTKM_TBB_SORT_EXPORT(wchar_t);
|
||||
VTKM_TBB_SORT_EXPORT(float);
|
||||
VTKM_TBB_SORT_EXPORT(double);
|
||||
|
||||
#undef VTKM_TBB_SORT_EXPORT
|
||||
|
||||
VTKM_THIRDPARTY_POST_INCLUDE
|
||||
}
|
||||
VTKM_INSTANTIATE_RADIX_SORT_FOR_THREADER(RadixThreaderTBB)
|
||||
}
|
||||
}
|
||||
}
|
||||
} // vtkm::cont::tbb::sort
|
||||
|
@ -24,6 +24,7 @@
|
||||
#include <vtkm/BinaryPredicates.h>
|
||||
#include <vtkm/cont/ArrayHandle.h>
|
||||
#include <vtkm/cont/ArrayHandleZip.h>
|
||||
#include <vtkm/cont/internal/ParallelRadixSortInterface.h>
|
||||
|
||||
#include <vtkm/cont/tbb/internal/ArrayManagerExecutionTBB.h>
|
||||
#include <vtkm/cont/tbb/internal/DeviceAdapterTagTBB.h>
|
||||
@ -38,128 +39,27 @@ namespace cont
|
||||
{
|
||||
namespace tbb
|
||||
{
|
||||
namespace internal
|
||||
namespace sort
|
||||
{
|
||||
// Empty tag types used to pick a sort implementation through overload
// resolution: RadixSortTag selects the radix sort overloads, PSortTag the
// comparison-based parallel sort overloads.
struct RadixSortTag
{
};

struct PSortTag
{
};
|
||||
|
||||
template <typename T>
|
||||
struct is_valid_compare_type : std::integral_constant<bool, false>
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
struct is_valid_compare_type<std::less<T>> : std::integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
struct is_valid_compare_type<std::greater<T>> : std::integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
template <>
|
||||
struct is_valid_compare_type<vtkm::SortLess> : std::integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
template <>
|
||||
struct is_valid_compare_type<vtkm::SortGreater> : std::integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
template <typename BComp, typename T>
|
||||
BComp&& get_std_compare(BComp&& b, T&&)
|
||||
{
|
||||
return std::forward<BComp>(b);
|
||||
}
|
||||
template <typename T>
|
||||
std::less<T> get_std_compare(vtkm::SortLess, T&&)
|
||||
{
|
||||
return std::less<T>{};
|
||||
}
|
||||
template <typename T>
|
||||
std::greater<T> get_std_compare(vtkm::SortGreater, T&&)
|
||||
{
|
||||
return std::greater<T>{};
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename StorageTag, typename BinaryCompare>
|
||||
struct sort_tag_type
|
||||
{
|
||||
using type = PSortTag;
|
||||
};
|
||||
template <typename T, typename BinaryCompare>
|
||||
struct sort_tag_type<T, vtkm::cont::StorageTagBasic, BinaryCompare>
|
||||
{
|
||||
using PrimT = std::is_arithmetic<T>;
|
||||
using LongDT = std::is_same<T, long double>;
|
||||
using BComp = is_valid_compare_type<BinaryCompare>;
|
||||
using type = typename std::conditional<PrimT::value && BComp::value && !LongDT::value,
|
||||
RadixSortTag,
|
||||
PSortTag>::type;
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename StorageTagT, typename StorageTagU, class BinaryCompare>
|
||||
struct sortbykey_tag_type
|
||||
{
|
||||
using type = PSortTag;
|
||||
};
|
||||
template <typename T, typename U, typename BinaryCompare>
|
||||
struct sortbykey_tag_type<T,
|
||||
U,
|
||||
vtkm::cont::StorageTagBasic,
|
||||
vtkm::cont::StorageTagBasic,
|
||||
BinaryCompare>
|
||||
{
|
||||
using PrimT = std::is_arithmetic<T>;
|
||||
using PrimU = std::is_arithmetic<U>;
|
||||
using LongDT = std::is_same<T, long double>;
|
||||
using BComp = is_valid_compare_type<BinaryCompare>;
|
||||
using type =
|
||||
typename std::conditional<PrimT::value && PrimU::value && BComp::value && !LongDT::value,
|
||||
RadixSortTag,
|
||||
PSortTag>::type;
|
||||
};
|
||||
|
||||
|
||||
#define VTKM_TBB_SORT_EXPORT(key_type) \
|
||||
VTKM_CONT_EXPORT void parallel_radix_sort( \
|
||||
key_type* data, size_t num_elems, const std::greater<key_type>& comp); \
|
||||
VTKM_CONT_EXPORT void parallel_radix_sort( \
|
||||
key_type* data, size_t num_elems, const std::less<key_type>& comp); \
|
||||
VTKM_CONT_EXPORT void parallel_radix_sort_key_values( \
|
||||
key_type* keys, vtkm::Id* vals, size_t num_elems, const std::greater<key_type>& comp); \
|
||||
VTKM_CONT_EXPORT void parallel_radix_sort_key_values( \
|
||||
key_type* keys, vtkm::Id* vals, size_t num_elems, const std::less<key_type>& comp);
|
||||
|
||||
// Generate radix sort interfaces for key and key value sorts.
|
||||
VTKM_TBB_SORT_EXPORT(short int);
|
||||
VTKM_TBB_SORT_EXPORT(unsigned short int);
|
||||
VTKM_TBB_SORT_EXPORT(int);
|
||||
VTKM_TBB_SORT_EXPORT(unsigned int);
|
||||
VTKM_TBB_SORT_EXPORT(long int);
|
||||
VTKM_TBB_SORT_EXPORT(unsigned long int);
|
||||
VTKM_TBB_SORT_EXPORT(long long int);
|
||||
VTKM_TBB_SORT_EXPORT(unsigned long long int);
|
||||
VTKM_TBB_SORT_EXPORT(unsigned char);
|
||||
VTKM_TBB_SORT_EXPORT(signed char);
|
||||
VTKM_TBB_SORT_EXPORT(char);
|
||||
VTKM_TBB_SORT_EXPORT(char16_t);
|
||||
VTKM_TBB_SORT_EXPORT(char32_t);
|
||||
VTKM_TBB_SORT_EXPORT(wchar_t);
|
||||
VTKM_TBB_SORT_EXPORT(float);
|
||||
VTKM_TBB_SORT_EXPORT(double);
|
||||
#undef VTKM_TBB_SORT_EXPORT
|
||||
// Declare the compiled radix sort specializations:
|
||||
VTKM_DECLARE_RADIX_SORT()
|
||||
|
||||
// Forward declare entry points (See stack overflow discussion 7255281 --
|
||||
// templated overloads of template functions are not specialization, and will
|
||||
// be resolved during the first phase of two part lookup).
|
||||
template <typename T, typename Container, class BinaryCompare>
|
||||
void parallel_sort(vtkm::cont::ArrayHandle<T, Container>& values, BinaryCompare binary_compare)
|
||||
{
|
||||
using SortAlgorithmTag = typename sort_tag_type<T, Container, BinaryCompare>::type;
|
||||
parallel_sort(values, binary_compare, SortAlgorithmTag{});
|
||||
}
|
||||
void parallel_sort(vtkm::cont::ArrayHandle<T, Container>&, BinaryCompare);
|
||||
template <typename T, typename StorageT, typename U, typename StorageU, class BinaryCompare>
|
||||
void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>&,
|
||||
vtkm::cont::ArrayHandle<U, StorageU>&,
|
||||
BinaryCompare);
|
||||
|
||||
// Quicksort values:
|
||||
template <typename HandleType, class BinaryCompare>
|
||||
void parallel_sort(HandleType& values, BinaryCompare binary_compare, PSortTag)
|
||||
void parallel_sort(HandleType& values,
|
||||
BinaryCompare binary_compare,
|
||||
vtkm::cont::internal::radix::PSortTag)
|
||||
{
|
||||
auto arrayPortal = values.PrepareForInPlace(vtkm::cont::DeviceAdapterTagTBB());
|
||||
|
||||
@ -169,31 +69,37 @@ void parallel_sort(HandleType& values, BinaryCompare binary_compare, PSortTag)
|
||||
internal::WrappedBinaryOperator<bool, BinaryCompare> wrappedCompare(binary_compare);
|
||||
::tbb::parallel_sort(iterators.GetBegin(), iterators.GetEnd(), wrappedCompare);
|
||||
}
|
||||
|
||||
// Radix sort values:
|
||||
template <typename T, typename StorageT, class BinaryCompare>
|
||||
void parallel_sort(vtkm::cont::ArrayHandle<T, StorageT>& values,
|
||||
BinaryCompare binary_compare,
|
||||
RadixSortTag)
|
||||
vtkm::cont::internal::radix::RadixSortTag)
|
||||
{
|
||||
using namespace vtkm::cont::internal::radix;
|
||||
auto c = get_std_compare(binary_compare, T{});
|
||||
parallel_radix_sort(
|
||||
values.GetStorage().GetArray(), static_cast<std::size_t>(values.GetNumberOfValues()), c);
|
||||
}
|
||||
|
||||
template <typename T, typename StorageT, typename U, typename StorageU, class BinaryCompare>
|
||||
void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
|
||||
vtkm::cont::ArrayHandle<U, StorageU>& values,
|
||||
BinaryCompare binary_compare)
|
||||
// Value sort -- static switch between quicksort and radix sort
|
||||
template <typename T, typename Container, class BinaryCompare>
|
||||
void parallel_sort(vtkm::cont::ArrayHandle<T, Container>& values, BinaryCompare binary_compare)
|
||||
{
|
||||
using SortAlgorithmTag =
|
||||
typename sortbykey_tag_type<T, U, StorageT, StorageU, BinaryCompare>::type;
|
||||
parallel_sort_bykey(keys, values, binary_compare, SortAlgorithmTag{});
|
||||
using namespace vtkm::cont::internal::radix;
|
||||
using SortAlgorithmTag = typename sort_tag_type<T, Container, BinaryCompare>::type;
|
||||
parallel_sort(values, binary_compare, SortAlgorithmTag{});
|
||||
}
|
||||
|
||||
|
||||
// Quicksort by key
|
||||
template <typename T, typename StorageT, typename U, typename StorageU, class BinaryCompare>
|
||||
void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
|
||||
vtkm::cont::ArrayHandle<U, StorageU>& values,
|
||||
BinaryCompare binary_compare,
|
||||
PSortTag)
|
||||
vtkm::cont::internal::radix::PSortTag)
|
||||
{
|
||||
using namespace vtkm::cont::internal::radix;
|
||||
using KeyType = vtkm::cont::ArrayHandle<T, StorageT>;
|
||||
constexpr bool larger_than_64bits = sizeof(U) > sizeof(vtkm::Int64);
|
||||
if (larger_than_64bits)
|
||||
@ -243,23 +149,28 @@ void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
|
||||
zipHandle, vtkm::cont::internal::KeyCompare<T, U, BinaryCompare>(binary_compare), PSortTag{});
|
||||
}
|
||||
}
|
||||
|
||||
// Radix sort by key -- Specialize for vtkm::Id values:
|
||||
template <typename T, typename StorageT, typename StorageU, class BinaryCompare>
|
||||
void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
|
||||
vtkm::cont::ArrayHandle<vtkm::Id, StorageU>& values,
|
||||
BinaryCompare binary_compare,
|
||||
RadixSortTag)
|
||||
vtkm::cont::internal::radix::RadixSortTag)
|
||||
{
|
||||
using namespace vtkm::cont::internal::radix;
|
||||
auto c = get_std_compare(binary_compare, T{});
|
||||
parallel_radix_sort_key_values(keys.GetStorage().GetArray(),
|
||||
values.GetStorage().GetArray(),
|
||||
static_cast<std::size_t>(keys.GetNumberOfValues()),
|
||||
c);
|
||||
}
|
||||
|
||||
// Radix sort by key -- Generic impl:
|
||||
template <typename T, typename StorageT, typename U, typename StorageU, class BinaryCompare>
|
||||
void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
|
||||
vtkm::cont::ArrayHandle<U, StorageU>& values,
|
||||
BinaryCompare binary_compare,
|
||||
RadixSortTag)
|
||||
vtkm::cont::internal::radix::RadixSortTag)
|
||||
{
|
||||
using KeyType = vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic>;
|
||||
using ValueType = vtkm::cont::ArrayHandle<U, vtkm::cont::StorageTagBasic>;
|
||||
@ -287,7 +198,7 @@ void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
|
||||
ZipHandleType zipHandle = vtkm::cont::make_ArrayHandleZip(keys, indexArray);
|
||||
parallel_sort(zipHandle,
|
||||
vtkm::cont::internal::KeyCompare<T, vtkm::Id, BinaryCompare>(binary_compare),
|
||||
PSortTag{});
|
||||
vtkm::cont::internal::radix::PSortTag{});
|
||||
}
|
||||
|
||||
tbb::ScatterPortal(values.PrepareForInput(vtkm::cont::DeviceAdapterTagTBB()),
|
||||
@ -301,9 +212,21 @@ void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
|
||||
tbb::CopyPortals(inputPortal, outputPortal, 0, 0, valuesScattered.GetNumberOfValues());
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by key -- static switch between radix and quick sort:
|
||||
template <typename T, typename StorageT, typename U, typename StorageU, class BinaryCompare>
|
||||
void parallel_sort_bykey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
|
||||
vtkm::cont::ArrayHandle<U, StorageU>& values,
|
||||
BinaryCompare binary_compare)
|
||||
{
|
||||
using namespace vtkm::cont::internal::radix;
|
||||
using SortAlgorithmTag =
|
||||
typename sortbykey_tag_type<T, U, StorageT, StorageU, BinaryCompare>::type;
|
||||
parallel_sort_bykey(keys, values, binary_compare, SortAlgorithmTag{});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // end namespace vtkm::cont::tbb::sort
|
||||
|
||||
#endif // vtk_m_cont_tbb_internal_ParallelSort_h
|
||||
|
Loading…
Reference in New Issue
Block a user