Rewrite sorting specialization using std::enable_if_t

This commit is contained in:
Thomas Gibson 2022-12-08 21:14:29 -06:00
parent 04013b9924
commit 5a72275ed8

@ -776,37 +776,18 @@ public:
} }
protected: protected:
// Used to define valid operators in Thrust (only used if thrust is enabled)
template <typename Compare>
struct ThrustSortByKeySupport : std::false_type
{
};
// Kokkos currently (11/10/2022) does not support a sort_by_key operator // Kokkos currently (11/10/2022) does not support a sort_by_key operator
// so instead we are using thrust if and only if HIP or CUDA are the backends for Kokkos // so instead we are using thrust if and only if HIP or CUDA are the backends for Kokkos
#if defined(VTKM_KOKKOS_HIP) || defined(VTKM_KOKKOS_CUDA) #if defined(VTKM_KOKKOS_HIP) || defined(VTKM_KOKKOS_CUDA)
// Valid thrust instantiations
template <>
struct ThrustSortByKeySupport<vtkm::SortLess> : std::true_type
{
template <typename T>
using Operator = thrust::less<T>;
};
template <>
struct ThrustSortByKeySupport<vtkm::SortGreater> : std::true_type
{
template <typename T>
using Operator = thrust::greater<T>;
};
template <typename T, typename U, typename BinaryCompare> template <typename T, typename U, typename BinaryCompare>
VTKM_CONT static void SortByKeyImpl(vtkm::cont::ArrayHandle<T>& keys, VTKM_CONT static std::enable_if_t<(std::is_same<BinaryCompare, vtkm::SortLess>::value ||
vtkm::cont::ArrayHandle<U>& values, std::is_same<BinaryCompare, vtkm::SortGreater>::value)>
BinaryCompare, SortByKeyImpl(vtkm::cont::ArrayHandle<T>& keys,
std::true_type, vtkm::cont::ArrayHandle<U>& values,
std::true_type, BinaryCompare,
std::true_type) std::true_type,
std::true_type)
{ {
vtkm::cont::Token token; vtkm::cont::Token token;
auto keys_portal = keys.PrepareForInPlace(vtkm::cont::DeviceAdapterTagKokkos{}, token); auto keys_portal = keys.PrepareForInPlace(vtkm::cont::DeviceAdapterTagKokkos{}, token);
@ -816,12 +797,19 @@ protected:
keys_portal.GetNumberOfValues()); keys_portal.GetNumberOfValues());
kokkos::internal::KokkosViewExec<U> values_view(values_portal.GetArray(), kokkos::internal::KokkosViewExec<U> values_view(values_portal.GetArray(),
values_portal.GetNumberOfValues()); values_portal.GetNumberOfValues());
using ThrustOperator = typename ThrustSortByKeySupport<BinaryCompare>::template Operator<T>;
thrust::device_ptr<T> keys_begin(keys_view.data()); thrust::device_ptr<T> keys_begin(keys_view.data());
thrust::device_ptr<T> keys_end(keys_view.data() + keys_view.size()); thrust::device_ptr<T> keys_end(keys_view.data() + keys_view.size());
thrust::device_ptr<U> values_begin(values_view.data()); thrust::device_ptr<U> values_begin(values_view.data());
thrust::sort_by_key(keys_begin, keys_end, values_begin, ThrustOperator());
if (std::is_same<BinaryCompare, vtkm::SortLess>::value)
{
thrust::sort_by_key(keys_begin, keys_end, values_begin, thrust::less<T>());
}
else
{
thrust::sort_by_key(keys_begin, keys_end, values_begin, thrust::greater<T>());
}
} }
#endif #endif
@ -832,14 +820,12 @@ protected:
class StorageU, class StorageU,
class BinaryCompare, class BinaryCompare,
typename ValidKeys, typename ValidKeys,
typename ValidValues, typename ValidValues>
typename ValidCompare>
VTKM_CONT static void SortByKeyImpl(vtkm::cont::ArrayHandle<T, StorageT>& keys, VTKM_CONT static void SortByKeyImpl(vtkm::cont::ArrayHandle<T, StorageT>& keys,
vtkm::cont::ArrayHandle<U, StorageU>& values, vtkm::cont::ArrayHandle<U, StorageU>& values,
BinaryCompare binary_compare, BinaryCompare binary_compare,
ValidKeys, ValidKeys,
ValidValues, ValidValues)
ValidCompare)
{ {
// Default to general algorithm // Default to general algorithm
Superclass::SortByKey(keys, values, binary_compare); Superclass::SortByKey(keys, values, binary_compare);
@ -866,8 +852,7 @@ public:
values, values,
binary_compare, binary_compare,
typename std::is_scalar<T>::type{}, typename std::is_scalar<T>::type{},
typename std::is_scalar<U>::type{}, typename std::is_scalar<U>::type{});
typename ThrustSortByKeySupport<BinaryCompare>::type{});
} }
//---------------------------------------------------------------------------- //----------------------------------------------------------------------------