Merge topic 'thrust-reduce-by-key'

1ac12b784 Adding thrust option for ReduceByKey for Kokkos+HIP/CUDA.

Acked-by: Kitware Robot <kwrobot@kitware.com>
Merge-request: !3027
This commit is contained in:
Thomas H. Gibson 2023-04-18 21:35:15 +00:00 committed by Kitware Robot
commit 141d0e70ef

@ -10,6 +10,7 @@
#ifndef vtk_m_cont_kokkos_internal_DeviceAdapterAlgorithmKokkos_h
#define vtk_m_cont_kokkos_internal_DeviceAdapterAlgorithmKokkos_h
#include <vtkm/cont/ArrayHandleConstant.h>
#include <vtkm/cont/ArrayHandleImplicit.h>
#include <vtkm/cont/ArrayHandleIndex.h>
#include <vtkm/cont/DeviceAdapterAlgorithm.h>
@ -44,6 +45,7 @@ VTKM_THIRDPARTY_POST_INCLUDE
#if defined(VTKM_USE_KOKKOS_THRUST)
#include <thrust/device_ptr.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/sort.h>
#endif
@ -860,6 +862,143 @@ public:
}
//----------------------------------------------------------------------------
// Reduce By Key
#ifdef VTKM_USE_KOKKOS_THRUST
protected:
template <typename K, typename V, class BinaryFunctor>
VTKM_CONT static void ReduceByKeyImpl(const vtkm::cont::ArrayHandle<K>& keys,
const vtkm::cont::ArrayHandle<V>& values,
vtkm::cont::ArrayHandle<K>& keys_output,
vtkm::cont::ArrayHandle<V>& values_output,
BinaryFunctor binary_functor)
{
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
const vtkm::Id numberOfKeys = keys.GetNumberOfValues();
vtkm::Id num_unique_keys;
{
vtkm::cont::Token token;
auto keys_portal = keys.PrepareForInput(vtkm::cont::DeviceAdapterTagKokkos{}, token);
auto values_portal = values.PrepareForInput(vtkm::cont::DeviceAdapterTagKokkos{}, token);
auto keys_output_portal =
keys_output.PrepareForOutput(numberOfKeys, vtkm::cont::DeviceAdapterTagKokkos{}, token);
auto values_output_portal =
values_output.PrepareForOutput(numberOfKeys, vtkm::cont::DeviceAdapterTagKokkos{}, token);
thrust::device_ptr<const K> keys_begin(keys_portal.GetArray());
thrust::device_ptr<const K> keys_end(keys_portal.GetArray() + numberOfKeys);
thrust::device_ptr<const V> values_begin(values_portal.GetArray());
thrust::device_ptr<K> keys_output_begin(keys_output_portal.GetArray());
thrust::device_ptr<V> values_output_begin(values_output_portal.GetArray());
auto ends = thrust::reduce_by_key(keys_begin,
keys_end,
values_begin,
keys_output_begin,
values_output_begin,
thrust::equal_to<K>(),
binary_functor);
num_unique_keys = ends.first - keys_output_begin;
}
// Resize output (reduce allocation)
keys_output.Allocate(num_unique_keys, CopyFlag::On);
values_output.Allocate(num_unique_keys, CopyFlag::On);
}
template <typename K, typename V, class BinaryFunctor>
VTKM_CONT static void ReduceByKeyImpl(
const vtkm::cont::ArrayHandle<K>& keys,
const vtkm::cont::ArrayHandle<V, vtkm::cont::StorageTagConstant>& values,
vtkm::cont::ArrayHandle<K>& keys_output,
vtkm::cont::ArrayHandle<V>& values_output,
BinaryFunctor binary_functor)
{
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
const vtkm::Id numberOfKeys = keys.GetNumberOfValues();
vtkm::Id num_unique_keys;
{
vtkm::cont::Token token;
auto keys_portal = keys.PrepareForInput(vtkm::cont::DeviceAdapterTagKokkos{}, token);
auto value = values.ReadPortal().Get(0);
auto keys_output_portal =
keys_output.PrepareForOutput(numberOfKeys, vtkm::cont::DeviceAdapterTagKokkos{}, token);
auto values_output_portal =
values_output.PrepareForOutput(numberOfKeys, vtkm::cont::DeviceAdapterTagKokkos{}, token);
thrust::device_ptr<const K> keys_begin(keys_portal.GetArray());
thrust::device_ptr<const K> keys_end(keys_portal.GetArray() + numberOfKeys);
thrust::constant_iterator<const V> values_begin(value);
thrust::device_ptr<K> keys_output_begin(keys_output_portal.GetArray());
thrust::device_ptr<V> values_output_begin(values_output_portal.GetArray());
auto ends = thrust::reduce_by_key(keys_begin,
keys_end,
values_begin,
keys_output_begin,
values_output_begin,
thrust::equal_to<K>(),
binary_functor);
num_unique_keys = ends.first - keys_output_begin;
}
// Resize output (reduce allocation)
keys_output.Allocate(num_unique_keys, CopyFlag::On);
values_output.Allocate(num_unique_keys, CopyFlag::On);
}
template <typename T,
typename U,
class KIn,
class VIn,
class KOut,
class VOut,
class BinaryFunctor>
VTKM_CONT static void ReduceByKeyImpl(const vtkm::cont::ArrayHandle<T, KIn>& keys,
const vtkm::cont::ArrayHandle<U, VIn>& values,
vtkm::cont::ArrayHandle<T, KOut>& keys_output,
vtkm::cont::ArrayHandle<U, VOut>& values_output,
BinaryFunctor binary_functor)
{
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
Superclass::ReduceByKey(keys, values, keys_output, values_output, binary_functor);
}
public:
template <typename T,
typename U,
class KIn,
class VIn,
class KOut,
class VOut,
class BinaryFunctor>
VTKM_CONT static void ReduceByKey(const vtkm::cont::ArrayHandle<T, KIn>& keys,
const vtkm::cont::ArrayHandle<U, VIn>& values,
vtkm::cont::ArrayHandle<T, KOut>& keys_output,
vtkm::cont::ArrayHandle<U, VOut>& values_output,
BinaryFunctor binary_functor)
{
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
ReduceByKeyImpl(keys, values, keys_output, values_output, binary_functor);
}
#endif
//--------------------------------------------------------------------------
VTKM_CONT static void Synchronize()
{