diff --git a/vtkm/cont/kokkos/internal/DeviceAdapterAlgorithmKokkos.h b/vtkm/cont/kokkos/internal/DeviceAdapterAlgorithmKokkos.h index b5b6e6a26..2c9e1cf07 100644 --- a/vtkm/cont/kokkos/internal/DeviceAdapterAlgorithmKokkos.h +++ b/vtkm/cont/kokkos/internal/DeviceAdapterAlgorithmKokkos.h @@ -38,6 +38,10 @@ VTKM_THIRDPARTY_POST_INCLUDE #define VTKM_VOLATILE volatile #endif +#if defined(VTKM_KOKKOS_HIP) || defined(VTKM_KOKKOS_CUDA) +#include +#include +#endif namespace vtkm { @@ -771,6 +775,103 @@ public: SortImpl(values, comp, typename std::is_scalar::type{}); } +protected: + // Used to define valid operators in Thrust (only used if thrust is enabled) + template + struct ThrustSortByKeySupport : std::false_type + { + }; + + // Kokkos currently (11/10/2022) does not support a sort_by_key operator + // so instead we are using thrust if and only if HIP or CUDA are the backends for Kokkos +#if defined(VTKM_KOKKOS_HIP) || defined(VTKM_KOKKOS_CUDA) + + // Valid thrust instantiations + template <> + struct ThrustSortByKeySupport : std::true_type + { + template + using Operator = thrust::less; + }; + template <> + struct ThrustSortByKeySupport : std::true_type + { + template + using Operator = thrust::greater; + }; + + template + VTKM_CONT static void SortByKeyImpl(vtkm::cont::ArrayHandle& keys, + vtkm::cont::ArrayHandle& values, + BinaryCompare, + std::true_type, + std::true_type, + std::true_type) + { + vtkm::cont::Token token; + auto keys_portal = keys.PrepareForInPlace(vtkm::cont::DeviceAdapterTagKokkos{}, token); + auto values_portal = values.PrepareForInPlace(vtkm::cont::DeviceAdapterTagKokkos{}, token); + + kokkos::internal::KokkosViewExec keys_view(keys_portal.GetArray(), + keys_portal.GetNumberOfValues()); + kokkos::internal::KokkosViewExec values_view(values_portal.GetArray(), + values_portal.GetNumberOfValues()); + using ThrustOperator = typename ThrustSortByKeySupport::template Operator; + + thrust::device_ptr keys_begin(keys_view.data()); + thrust::device_ptr keys_end(keys_view.data() + keys_view.size()); + thrust::device_ptr values_begin(values_view.data()); + thrust::sort_by_key(keys_begin, keys_end, values_begin, ThrustOperator()); + } + +#endif + + template + VTKM_CONT static void SortByKeyImpl(vtkm::cont::ArrayHandle& keys, + vtkm::cont::ArrayHandle& values, + BinaryCompare binary_compare, + ValidKeys, + ValidValues, + ValidCompare) + { + // Default to general algorithm + Superclass::SortByKey(keys, values, binary_compare); + } + +public: + template + VTKM_CONT static void SortByKey(vtkm::cont::ArrayHandle& keys, + vtkm::cont::ArrayHandle& values) + { + // Make sure not to use the general algorithm here since + // it will use Sort algorithm instead of SortByKey + SortByKey(keys, values, internal::DefaultCompareFunctor()); + } + + template + VTKM_CONT static void SortByKey(vtkm::cont::ArrayHandle& keys, + vtkm::cont::ArrayHandle& values, + BinaryCompare binary_compare) + { + // If T or U are not scalar types, or the BinaryCompare is not supported + // then the general algorithm is called, otherwise we will run thrust + SortByKeyImpl(keys, + values, + binary_compare, + typename std::is_scalar::type{}, + typename std::is_scalar::type{}, + typename ThrustSortByKeySupport::type{}); + } + + //---------------------------------------------------------------------------- + VTKM_CONT static void Synchronize() { vtkm::cont::kokkos::internal::GetExecutionSpaceInstance().fence();