Merge branch 'master' of https://gitlab.kitware.com/vtk/vtk-m
This commit is contained in:
commit
8586d9de5a
@ -10,6 +10,7 @@
|
||||
#ifndef vtk_m_cont_kokkos_internal_DeviceAdapterAlgorithmKokkos_h
|
||||
#define vtk_m_cont_kokkos_internal_DeviceAdapterAlgorithmKokkos_h
|
||||
|
||||
#include <vtkm/cont/ArrayHandleConstant.h>
|
||||
#include <vtkm/cont/ArrayHandleImplicit.h>
|
||||
#include <vtkm/cont/ArrayHandleIndex.h>
|
||||
#include <vtkm/cont/DeviceAdapterAlgorithm.h>
|
||||
@ -44,6 +45,7 @@ VTKM_THIRDPARTY_POST_INCLUDE
|
||||
|
||||
#if defined(VTKM_USE_KOKKOS_THRUST)
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/iterator/constant_iterator.h>
|
||||
#include <thrust/sort.h>
|
||||
#endif
|
||||
|
||||
@ -860,6 +862,143 @@ public:
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Reduce By Key
|
||||
|
||||
#ifdef VTKM_USE_KOKKOS_THRUST
|
||||
|
||||
protected:
|
||||
template <typename K, typename V, class BinaryFunctor>
|
||||
VTKM_CONT static void ReduceByKeyImpl(const vtkm::cont::ArrayHandle<K>& keys,
|
||||
const vtkm::cont::ArrayHandle<V>& values,
|
||||
vtkm::cont::ArrayHandle<K>& keys_output,
|
||||
vtkm::cont::ArrayHandle<V>& values_output,
|
||||
BinaryFunctor binary_functor)
|
||||
{
|
||||
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
|
||||
|
||||
const vtkm::Id numberOfKeys = keys.GetNumberOfValues();
|
||||
|
||||
vtkm::Id num_unique_keys;
|
||||
{
|
||||
vtkm::cont::Token token;
|
||||
|
||||
auto keys_portal = keys.PrepareForInput(vtkm::cont::DeviceAdapterTagKokkos{}, token);
|
||||
auto values_portal = values.PrepareForInput(vtkm::cont::DeviceAdapterTagKokkos{}, token);
|
||||
|
||||
auto keys_output_portal =
|
||||
keys_output.PrepareForOutput(numberOfKeys, vtkm::cont::DeviceAdapterTagKokkos{}, token);
|
||||
auto values_output_portal =
|
||||
values_output.PrepareForOutput(numberOfKeys, vtkm::cont::DeviceAdapterTagKokkos{}, token);
|
||||
|
||||
thrust::device_ptr<const K> keys_begin(keys_portal.GetArray());
|
||||
thrust::device_ptr<const K> keys_end(keys_portal.GetArray() + numberOfKeys);
|
||||
thrust::device_ptr<const V> values_begin(values_portal.GetArray());
|
||||
thrust::device_ptr<K> keys_output_begin(keys_output_portal.GetArray());
|
||||
thrust::device_ptr<V> values_output_begin(values_output_portal.GetArray());
|
||||
|
||||
auto ends = thrust::reduce_by_key(keys_begin,
|
||||
keys_end,
|
||||
values_begin,
|
||||
keys_output_begin,
|
||||
values_output_begin,
|
||||
thrust::equal_to<K>(),
|
||||
binary_functor);
|
||||
|
||||
num_unique_keys = ends.first - keys_output_begin;
|
||||
}
|
||||
|
||||
// Resize output (reduce allocation)
|
||||
keys_output.Allocate(num_unique_keys, CopyFlag::On);
|
||||
values_output.Allocate(num_unique_keys, CopyFlag::On);
|
||||
}
|
||||
|
||||
|
||||
template <typename K, typename V, class BinaryFunctor>
|
||||
VTKM_CONT static void ReduceByKeyImpl(
|
||||
const vtkm::cont::ArrayHandle<K>& keys,
|
||||
const vtkm::cont::ArrayHandle<V, vtkm::cont::StorageTagConstant>& values,
|
||||
vtkm::cont::ArrayHandle<K>& keys_output,
|
||||
vtkm::cont::ArrayHandle<V>& values_output,
|
||||
BinaryFunctor binary_functor)
|
||||
{
|
||||
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
|
||||
|
||||
const vtkm::Id numberOfKeys = keys.GetNumberOfValues();
|
||||
|
||||
vtkm::Id num_unique_keys;
|
||||
{
|
||||
vtkm::cont::Token token;
|
||||
|
||||
auto keys_portal = keys.PrepareForInput(vtkm::cont::DeviceAdapterTagKokkos{}, token);
|
||||
auto value = values.ReadPortal().Get(0);
|
||||
|
||||
auto keys_output_portal =
|
||||
keys_output.PrepareForOutput(numberOfKeys, vtkm::cont::DeviceAdapterTagKokkos{}, token);
|
||||
auto values_output_portal =
|
||||
values_output.PrepareForOutput(numberOfKeys, vtkm::cont::DeviceAdapterTagKokkos{}, token);
|
||||
|
||||
thrust::device_ptr<const K> keys_begin(keys_portal.GetArray());
|
||||
thrust::device_ptr<const K> keys_end(keys_portal.GetArray() + numberOfKeys);
|
||||
thrust::constant_iterator<const V> values_begin(value);
|
||||
thrust::device_ptr<K> keys_output_begin(keys_output_portal.GetArray());
|
||||
thrust::device_ptr<V> values_output_begin(values_output_portal.GetArray());
|
||||
|
||||
auto ends = thrust::reduce_by_key(keys_begin,
|
||||
keys_end,
|
||||
values_begin,
|
||||
keys_output_begin,
|
||||
values_output_begin,
|
||||
thrust::equal_to<K>(),
|
||||
binary_functor);
|
||||
|
||||
num_unique_keys = ends.first - keys_output_begin;
|
||||
}
|
||||
|
||||
// Resize output (reduce allocation)
|
||||
keys_output.Allocate(num_unique_keys, CopyFlag::On);
|
||||
values_output.Allocate(num_unique_keys, CopyFlag::On);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename U,
|
||||
class KIn,
|
||||
class VIn,
|
||||
class KOut,
|
||||
class VOut,
|
||||
class BinaryFunctor>
|
||||
VTKM_CONT static void ReduceByKeyImpl(const vtkm::cont::ArrayHandle<T, KIn>& keys,
|
||||
const vtkm::cont::ArrayHandle<U, VIn>& values,
|
||||
vtkm::cont::ArrayHandle<T, KOut>& keys_output,
|
||||
vtkm::cont::ArrayHandle<U, VOut>& values_output,
|
||||
BinaryFunctor binary_functor)
|
||||
{
|
||||
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
|
||||
|
||||
Superclass::ReduceByKey(keys, values, keys_output, values_output, binary_functor);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename T,
|
||||
typename U,
|
||||
class KIn,
|
||||
class VIn,
|
||||
class KOut,
|
||||
class VOut,
|
||||
class BinaryFunctor>
|
||||
VTKM_CONT static void ReduceByKey(const vtkm::cont::ArrayHandle<T, KIn>& keys,
|
||||
const vtkm::cont::ArrayHandle<U, VIn>& values,
|
||||
vtkm::cont::ArrayHandle<T, KOut>& keys_output,
|
||||
vtkm::cont::ArrayHandle<U, VOut>& values_output,
|
||||
BinaryFunctor binary_functor)
|
||||
{
|
||||
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
|
||||
|
||||
ReduceByKeyImpl(keys, values, keys_output, values_output, binary_functor);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
//--------------------------------------------------------------------------
|
||||
|
||||
VTKM_CONT static void Synchronize()
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user