Rewrite sorting specialization using std::enable_if_t
This commit is contained in:
parent
04013b9924
commit
5a72275ed8
@ -776,37 +776,18 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Used to define valid operators in Thrust (only used if thrust is enabled)
|
|
||||||
template <typename Compare>
|
|
||||||
struct ThrustSortByKeySupport : std::false_type
|
|
||||||
{
|
|
||||||
};
|
|
||||||
|
|
||||||
// Kokkos currently (11/10/2022) does not support a sort_by_key operator
|
// Kokkos currently (11/10/2022) does not support a sort_by_key operator
|
||||||
// so instead we are using thrust if and only if HIP or CUDA are the backends for Kokkos
|
// so instead we are using thrust if and only if HIP or CUDA are the backends for Kokkos
|
||||||
#if defined(VTKM_KOKKOS_HIP) || defined(VTKM_KOKKOS_CUDA)
|
#if defined(VTKM_KOKKOS_HIP) || defined(VTKM_KOKKOS_CUDA)
|
||||||
|
|
||||||
// Valid thrust instantiations
|
|
||||||
template <>
|
|
||||||
struct ThrustSortByKeySupport<vtkm::SortLess> : std::true_type
|
|
||||||
{
|
|
||||||
template <typename T>
|
|
||||||
using Operator = thrust::less<T>;
|
|
||||||
};
|
|
||||||
template <>
|
|
||||||
struct ThrustSortByKeySupport<vtkm::SortGreater> : std::true_type
|
|
||||||
{
|
|
||||||
template <typename T>
|
|
||||||
using Operator = thrust::greater<T>;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T, typename U, typename BinaryCompare>
|
template <typename T, typename U, typename BinaryCompare>
|
||||||
VTKM_CONT static void SortByKeyImpl(vtkm::cont::ArrayHandle<T>& keys,
|
VTKM_CONT static std::enable_if_t<(std::is_same<BinaryCompare, vtkm::SortLess>::value ||
|
||||||
vtkm::cont::ArrayHandle<U>& values,
|
std::is_same<BinaryCompare, vtkm::SortGreater>::value)>
|
||||||
BinaryCompare,
|
SortByKeyImpl(vtkm::cont::ArrayHandle<T>& keys,
|
||||||
std::true_type,
|
vtkm::cont::ArrayHandle<U>& values,
|
||||||
std::true_type,
|
BinaryCompare,
|
||||||
std::true_type)
|
std::true_type,
|
||||||
|
std::true_type)
|
||||||
{
|
{
|
||||||
vtkm::cont::Token token;
|
vtkm::cont::Token token;
|
||||||
auto keys_portal = keys.PrepareForInPlace(vtkm::cont::DeviceAdapterTagKokkos{}, token);
|
auto keys_portal = keys.PrepareForInPlace(vtkm::cont::DeviceAdapterTagKokkos{}, token);
|
||||||
@ -816,12 +797,19 @@ protected:
|
|||||||
keys_portal.GetNumberOfValues());
|
keys_portal.GetNumberOfValues());
|
||||||
kokkos::internal::KokkosViewExec<U> values_view(values_portal.GetArray(),
|
kokkos::internal::KokkosViewExec<U> values_view(values_portal.GetArray(),
|
||||||
values_portal.GetNumberOfValues());
|
values_portal.GetNumberOfValues());
|
||||||
using ThrustOperator = typename ThrustSortByKeySupport<BinaryCompare>::template Operator<T>;
|
|
||||||
|
|
||||||
thrust::device_ptr<T> keys_begin(keys_view.data());
|
thrust::device_ptr<T> keys_begin(keys_view.data());
|
||||||
thrust::device_ptr<T> keys_end(keys_view.data() + keys_view.size());
|
thrust::device_ptr<T> keys_end(keys_view.data() + keys_view.size());
|
||||||
thrust::device_ptr<U> values_begin(values_view.data());
|
thrust::device_ptr<U> values_begin(values_view.data());
|
||||||
thrust::sort_by_key(keys_begin, keys_end, values_begin, ThrustOperator());
|
|
||||||
|
if (std::is_same<BinaryCompare, vtkm::SortLess>::value)
|
||||||
|
{
|
||||||
|
thrust::sort_by_key(keys_begin, keys_end, values_begin, thrust::less<T>());
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
thrust::sort_by_key(keys_begin, keys_end, values_begin, thrust::greater<T>());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
@ -832,14 +820,12 @@ protected:
|
|||||||
class StorageU,
|
class StorageU,
|
||||||
class BinaryCompare,
|
class BinaryCompare,
|
||||||
typename ValidKeys,
|
typename ValidKeys,
|
||||||
typename ValidValues,
|
typename ValidValues>
|
||||||
typename ValidCompare>
|
|
||||||
VTKM_CONT static void SortByKeyImpl(vtkm::cont::ArrayHandle<T, StorageT>& keys,
|
VTKM_CONT static void SortByKeyImpl(vtkm::cont::ArrayHandle<T, StorageT>& keys,
|
||||||
vtkm::cont::ArrayHandle<U, StorageU>& values,
|
vtkm::cont::ArrayHandle<U, StorageU>& values,
|
||||||
BinaryCompare binary_compare,
|
BinaryCompare binary_compare,
|
||||||
ValidKeys,
|
ValidKeys,
|
||||||
ValidValues,
|
ValidValues)
|
||||||
ValidCompare)
|
|
||||||
{
|
{
|
||||||
// Default to general algorithm
|
// Default to general algorithm
|
||||||
Superclass::SortByKey(keys, values, binary_compare);
|
Superclass::SortByKey(keys, values, binary_compare);
|
||||||
@ -866,8 +852,7 @@ public:
|
|||||||
values,
|
values,
|
||||||
binary_compare,
|
binary_compare,
|
||||||
typename std::is_scalar<T>::type{},
|
typename std::is_scalar<T>::type{},
|
||||||
typename std::is_scalar<U>::type{},
|
typename std::is_scalar<U>::type{});
|
||||||
typename ThrustSortByKeySupport<BinaryCompare>::type{});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//----------------------------------------------------------------------------
|
//----------------------------------------------------------------------------
|
||||||
|
Loading…
Reference in New Issue
Block a user