Add sorting implementation using thrust
Co-authored-by: Thomas Gibson <thomas.gibson@amd.com>
This commit is contained in:
parent
cf3c9bc921
commit
04013b9924
@ -38,6 +38,10 @@ VTKM_THIRDPARTY_POST_INCLUDE
|
||||
#define VTKM_VOLATILE volatile
|
||||
#endif
|
||||
|
||||
#if defined(VTKM_KOKKOS_HIP) || defined(VTKM_KOKKOS_CUDA)
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/sort.h>
|
||||
#endif
|
||||
|
||||
namespace vtkm
|
||||
{
|
||||
@ -771,6 +775,103 @@ public:
|
||||
SortImpl(values, comp, typename std::is_scalar<T>::type{});
|
||||
}
|
||||
|
||||
protected:
|
||||
// Used to define valid operators in Thrust (only used if thrust is enabled)
|
||||
template <typename Compare>
|
||||
struct ThrustSortByKeySupport : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
// Kokkos currently (11/10/2022) does not support a sort_by_key operator
|
||||
// so instead we are using thrust if and only if HIP or CUDA are the backends for Kokkos
|
||||
#if defined(VTKM_KOKKOS_HIP) || defined(VTKM_KOKKOS_CUDA)
|
||||
|
||||
// Valid thrust instantiations
|
||||
template <>
|
||||
struct ThrustSortByKeySupport<vtkm::SortLess> : std::true_type
|
||||
{
|
||||
template <typename T>
|
||||
using Operator = thrust::less<T>;
|
||||
};
|
||||
template <>
|
||||
struct ThrustSortByKeySupport<vtkm::SortGreater> : std::true_type
|
||||
{
|
||||
template <typename T>
|
||||
using Operator = thrust::greater<T>;
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename BinaryCompare>
|
||||
VTKM_CONT static void SortByKeyImpl(vtkm::cont::ArrayHandle<T>& keys,
|
||||
vtkm::cont::ArrayHandle<U>& values,
|
||||
BinaryCompare,
|
||||
std::true_type,
|
||||
std::true_type,
|
||||
std::true_type)
|
||||
{
|
||||
vtkm::cont::Token token;
|
||||
auto keys_portal = keys.PrepareForInPlace(vtkm::cont::DeviceAdapterTagKokkos{}, token);
|
||||
auto values_portal = values.PrepareForInPlace(vtkm::cont::DeviceAdapterTagKokkos{}, token);
|
||||
|
||||
kokkos::internal::KokkosViewExec<T> keys_view(keys_portal.GetArray(),
|
||||
keys_portal.GetNumberOfValues());
|
||||
kokkos::internal::KokkosViewExec<U> values_view(values_portal.GetArray(),
|
||||
values_portal.GetNumberOfValues());
|
||||
using ThrustOperator = typename ThrustSortByKeySupport<BinaryCompare>::template Operator<T>;
|
||||
|
||||
thrust::device_ptr<T> keys_begin(keys_view.data());
|
||||
thrust::device_ptr<T> keys_end(keys_view.data() + keys_view.size());
|
||||
thrust::device_ptr<U> values_begin(values_view.data());
|
||||
thrust::sort_by_key(keys_begin, keys_end, values_begin, ThrustOperator());
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <typename T,
|
||||
typename U,
|
||||
class StorageT,
|
||||
class StorageU,
|
||||
class BinaryCompare,
|
||||
typename ValidKeys,
|
||||
typename ValidValues,
|
||||
typename ValidCompare>
|
||||
VTKM_CONT static void SortByKeyImpl(vtkm::cont::ArrayHandle<T, StorageT>& keys,
|
||||
vtkm::cont::ArrayHandle<U, StorageU>& values,
|
||||
BinaryCompare binary_compare,
|
||||
ValidKeys,
|
||||
ValidValues,
|
||||
ValidCompare)
|
||||
{
|
||||
// Default to general algorithm
|
||||
Superclass::SortByKey(keys, values, binary_compare);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename T, typename U, class StorageT, class StorageU>
|
||||
VTKM_CONT static void SortByKey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
|
||||
vtkm::cont::ArrayHandle<U, StorageU>& values)
|
||||
{
|
||||
// Make sure not to use the general algorithm here since
|
||||
// it will use Sort algorithm instead of SortByKey
|
||||
SortByKey(keys, values, internal::DefaultCompareFunctor());
|
||||
}
|
||||
|
||||
template <typename T, typename U, class StorageT, class StorageU, class BinaryCompare>
|
||||
VTKM_CONT static void SortByKey(vtkm::cont::ArrayHandle<T, StorageT>& keys,
|
||||
vtkm::cont::ArrayHandle<U, StorageU>& values,
|
||||
BinaryCompare binary_compare)
|
||||
{
|
||||
// If T or U are not scalar types, or the BinaryCompare is not supported
|
||||
// then the general algorithm is called, otherwise we will run thrust
|
||||
SortByKeyImpl(keys,
|
||||
values,
|
||||
binary_compare,
|
||||
typename std::is_scalar<T>::type{},
|
||||
typename std::is_scalar<U>::type{},
|
||||
typename ThrustSortByKeySupport<BinaryCompare>::type{});
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
|
||||
VTKM_CONT static void Synchronize()
|
||||
{
|
||||
vtkm::cont::kokkos::internal::GetExecutionSpaceInstance().fence();
|
||||
|
Loading…
Reference in New Issue
Block a user