Add non-templated base class to Keys class.

The only reason Keys has a template is so that it can hold a UniqueKeys
array and provide the key for each group. If that is not needed and you
want to implement a library function that takes a keys object, you can
now grab the Keys superclass KeysBase. KeysBase is not templated, so you
can pass it to a standard method in a library.
This commit is contained in:
Kenneth Moreland 2020-02-03 23:00:22 -06:00
parent 934732bb64
commit b1f288aaea
6 changed files with 146 additions and 58 deletions

@ -25,14 +25,8 @@ struct TypeCheckTagKeys
{
};
// A more specific specialization that actually checks for Keys types is
// implemented in vtkm/worklet/Keys.h. That class is not accessible from here
// due to VTK-m package dependencies.
template <typename Type>
struct TypeCheck<TypeCheckTagKeys, Type>
{
static constexpr bool value = false;
};
// The specialization that actually checks for Keys types is implemented in vtkm/worklet/Keys.h.
// That class is not accessible from here due to VTK-m package dependencies.
}
}
} // namespace vtkm::cont::arg

@ -33,13 +33,13 @@ class ThreadIndicesReduceByKey : public vtkm::exec::arg::ThreadIndicesBasic
using Superclass = vtkm::exec::arg::ThreadIndicesBasic;
public:
template <typename P1, typename P2, typename P3>
template <typename P1, typename P2>
VTKM_EXEC ThreadIndicesReduceByKey(
vtkm::Id threadIndex,
vtkm::Id inIndex,
vtkm::IdComponent visitIndex,
vtkm::Id outIndex,
const vtkm::exec::internal::ReduceByKeyLookup<P1, P2, P3>& keyLookup)
const vtkm::exec::internal::ReduceByKeyLookupBase<P1, P2>& keyLookup)
: Superclass(threadIndex, inIndex, visitIndex, outIndex)
, ValueOffset(keyLookup.Offsets.Get(inIndex))
, NumberOfValues(keyLookup.Counts.Get(inIndex))

@ -24,6 +24,34 @@ namespace exec
namespace internal
{
/// A superclass of `ReduceBykeyLookup` that can be used when no key values are provided.
///
template <typename IdPortalType, typename IdComponentPortalType>
struct ReduceByKeyLookupBase
{
VTKM_STATIC_ASSERT((std::is_same<typename IdPortalType::ValueType, vtkm::Id>::value));
VTKM_STATIC_ASSERT(
(std::is_same<typename IdComponentPortalType::ValueType, vtkm::IdComponent>::value));
IdPortalType SortedValuesMap;
IdPortalType Offsets;
IdComponentPortalType Counts;
VTKM_EXEC_CONT
ReduceByKeyLookupBase(const IdPortalType& sortedValuesMap,
const IdPortalType& offsets,
const IdComponentPortalType& counts)
: SortedValuesMap(sortedValuesMap)
, Offsets(offsets)
, Counts(counts)
{
}
VTKM_SUPPRESS_EXEC_WARNINGS
VTKM_EXEC_CONT
ReduceByKeyLookupBase() {}
};
/// \brief Execution object holding lookup info for reduce by key.
///
/// A WorkletReduceByKey needs several arrays to map the current output object
@ -31,28 +59,19 @@ namespace internal
/// state.
///
template <typename KeyPortalType, typename IdPortalType, typename IdComponentPortalType>
struct ReduceByKeyLookup : vtkm::cont::ExecutionObjectBase
struct ReduceByKeyLookup : ReduceByKeyLookupBase<IdPortalType, IdComponentPortalType>
{
using KeyType = typename KeyPortalType::ValueType;
VTKM_STATIC_ASSERT((std::is_same<typename IdPortalType::ValueType, vtkm::Id>::value));
VTKM_STATIC_ASSERT(
(std::is_same<typename IdComponentPortalType::ValueType, vtkm::IdComponent>::value));
KeyPortalType UniqueKeys;
IdPortalType SortedValuesMap;
IdPortalType Offsets;
IdComponentPortalType Counts;
VTKM_EXEC_CONT
ReduceByKeyLookup(const KeyPortalType& uniqueKeys,
const IdPortalType& sortedValuesMap,
const IdPortalType& offsets,
const IdComponentPortalType& counts)
: UniqueKeys(uniqueKeys)
, SortedValuesMap(sortedValuesMap)
, Offsets(offsets)
, Counts(counts)
: ReduceByKeyLookupBase<IdPortalType, IdComponentPortalType>(sortedValuesMap, offsets, counts)
, UniqueKeys(uniqueKeys)
{
}

@ -72,11 +72,8 @@ struct AverageByKey
/// This method uses an existing \c Keys object to collected values by those keys and find
/// the average of those groups.
///
template <typename KeyType,
typename ValueType,
typename InValuesStorage,
typename OutAveragesStorage>
VTKM_CONT static void Run(const vtkm::worklet::Keys<KeyType>& keys,
template <typename ValueType, typename InValuesStorage, typename OutAveragesStorage>
VTKM_CONT static void Run(const vtkm::worklet::internal::KeysBase& keys,
const vtkm::cont::ArrayHandle<ValueType, InValuesStorage>& inValues,
vtkm::cont::ArrayHandle<ValueType, OutAveragesStorage>& outAverages)
{
@ -90,9 +87,9 @@ struct AverageByKey
/// This method uses an existing \c Keys object to collected values by those keys and find
/// the average of those groups.
///
template <typename KeyType, typename ValueType, typename InValuesStorage>
template <typename ValueType, typename InValuesStorage>
VTKM_CONT static vtkm::cont::ArrayHandle<ValueType> Run(
const vtkm::worklet::Keys<KeyType>& keys,
const vtkm::worklet::internal::KeysBase& keys,
const vtkm::cont::ArrayHandle<ValueType, InValuesStorage>& inValues)
{

@ -30,6 +30,8 @@
#include <vtkm/cont/arg/TransportTagKeysIn.h>
#include <vtkm/cont/arg/TypeCheckTagKeys.h>
#include <vtkm/worklet/internal/DispatcherBase.h>
#include <vtkm/worklet/StableSortIndices.h>
#include <vtkm/worklet/vtkm_worklet_export.h>
@ -40,6 +42,83 @@ namespace vtkm
namespace worklet
{
namespace internal
{
class VTKM_WORKLET_EXPORT KeysBase
{
public:
KeysBase(const KeysBase&) = default;
KeysBase& operator=(const KeysBase&) = default;
~KeysBase() = default;
VTKM_CONT
vtkm::Id GetInputRange() const { return this->Counts.GetNumberOfValues(); }
VTKM_CONT
vtkm::cont::ArrayHandle<vtkm::Id> GetSortedValuesMap() const { return this->SortedValuesMap; }
VTKM_CONT
vtkm::cont::ArrayHandle<vtkm::Id> GetOffsets() const { return this->Offsets; }
VTKM_CONT
vtkm::cont::ArrayHandle<vtkm::IdComponent> GetCounts() const { return this->Counts; }
VTKM_CONT
vtkm::Id GetNumberOfValues() const { return this->SortedValuesMap.GetNumberOfValues(); }
template <typename Device>
struct ExecutionTypes
{
using IdPortal =
typename vtkm::cont::ArrayHandle<vtkm::Id>::template ExecutionTypes<Device>::PortalConst;
using IdComponentPortal = typename vtkm::cont::ArrayHandle<
vtkm::IdComponent>::template ExecutionTypes<Device>::PortalConst;
using Lookup = vtkm::exec::internal::ReduceByKeyLookupBase<IdPortal, IdComponentPortal>;
};
template <typename Device>
VTKM_CONT typename ExecutionTypes<Device>::Lookup PrepareForInput(Device device,
vtkm::cont::Token& token) const
{
return
typename ExecutionTypes<Device>::Lookup(this->SortedValuesMap.PrepareForInput(device, token),
this->Offsets.PrepareForInput(device, token),
this->Counts.PrepareForInput(device, token));
}
template <typename Device>
VTKM_CONT VTKM_DEPRECATED(1.6, "PrepareForInput now requires a vtkm::cont::Token object.")
typename ExecutionTypes<Device>::Lookup PrepareForInput(Device device) const
{
vtkm::cont::Token token;
return this->PrepareForInput(device, token);
}
VTKM_CONT
bool operator==(const vtkm::worklet::internal::KeysBase& other) const
{
return ((this->SortedValuesMap == other.SortedValuesMap) && (this->Offsets == other.Offsets) &&
(this->Counts == other.Counts));
}
VTKM_CONT
bool operator!=(const vtkm::worklet::internal::KeysBase& other) const
{
return !(*this == other);
}
protected:
KeysBase() = default;
vtkm::cont::ArrayHandle<vtkm::Id> SortedValuesMap;
vtkm::cont::ArrayHandle<vtkm::Id> Offsets;
vtkm::cont::ArrayHandle<vtkm::IdComponent> Counts;
};
} // namespace internal
/// Select the type of sort for BuildArrays calls. Unstable sorting is faster
/// but will not produce consistent ordering for equal keys. Stable sorting
/// is slower, but keeps equal keys in their original order.
@ -67,7 +146,7 @@ enum class KeysSortType
/// creating a different \c Keys structure for each \c Invoke.
///
template <typename T>
class VTKM_ALWAYS_EXPORT Keys
class VTKM_ALWAYS_EXPORT Keys : public internal::KeysBase
{
public:
using KeyType = T;
@ -110,24 +189,9 @@ public:
KeysSortType sort,
vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny());
VTKM_CONT
vtkm::Id GetInputRange() const { return this->UniqueKeys.GetNumberOfValues(); }
VTKM_CONT
KeyArrayHandleType GetUniqueKeys() const { return this->UniqueKeys; }
VTKM_CONT
vtkm::cont::ArrayHandle<vtkm::Id> GetSortedValuesMap() const { return this->SortedValuesMap; }
VTKM_CONT
vtkm::cont::ArrayHandle<vtkm::Id> GetOffsets() const { return this->Offsets; }
VTKM_CONT
vtkm::cont::ArrayHandle<vtkm::IdComponent> GetCounts() const { return this->Counts; }
VTKM_CONT
vtkm::Id GetNumberOfValues() const { return this->SortedValuesMap.GetNumberOfValues(); }
template <typename Device>
struct ExecutionTypes
{
@ -173,9 +237,6 @@ public:
private:
/// @cond NONE
KeyArrayHandleType UniqueKeys;
vtkm::cont::ArrayHandle<vtkm::Id> SortedValuesMap;
vtkm::cont::ArrayHandle<vtkm::Id> Offsets;
vtkm::cont::ArrayHandle<vtkm::IdComponent> Counts;
template <typename KeyArrayType>
VTKM_CONT void BuildArraysInternal(KeyArrayType& keys, vtkm::cont::DeviceAdapterId device);
@ -189,6 +250,9 @@ private:
template <typename T>
VTKM_CONT Keys<T>::Keys() = default;
namespace internal
{
template <typename KeyType>
inline auto SchedulingRange(const vtkm::worklet::Keys<KeyType>& inputDomain)
-> decltype(inputDomain.GetInputRange())
@ -202,6 +266,19 @@ inline auto SchedulingRange(const vtkm::worklet::Keys<KeyType>* const inputDomai
{
return inputDomain->GetInputRange();
}
inline auto SchedulingRange(const vtkm::worklet::internal::KeysBase& inputDomain)
-> decltype(inputDomain.GetInputRange())
{
return inputDomain.GetInputRange();
}
inline auto SchedulingRange(const vtkm::worklet::internal::KeysBase* const inputDomain)
-> decltype(inputDomain->GetInputRange())
{
return inputDomain->GetInputRange();
}
} // namespace internal
}
} // namespace vtkm::worklet
@ -218,15 +295,16 @@ namespace arg
{
template <typename KeyType>
struct TypeCheck<vtkm::cont::arg::TypeCheckTagKeys, vtkm::worklet::Keys<KeyType>>
struct TypeCheck<vtkm::cont::arg::TypeCheckTagKeys, KeyType>
{
static constexpr bool value = true;
static constexpr bool value =
std::is_base_of<vtkm::worklet::internal::KeysBase, typename std::decay<KeyType>::type>::value;
};
template <typename KeyType, typename Device>
struct Transport<vtkm::cont::arg::TransportTagKeysIn, vtkm::worklet::Keys<KeyType>, Device>
struct Transport<vtkm::cont::arg::TransportTagKeysIn, KeyType, Device>
{
using ContObjectType = vtkm::worklet::Keys<KeyType>;
using ContObjectType = KeyType;
using ExecObjectType = typename ContObjectType::template ExecutionTypes<Device>::Lookup;
VTKM_CONT
@ -264,9 +342,8 @@ struct Transport<vtkm::cont::arg::TransportTagKeyedValuesIn, ArrayHandleType, De
using ExecObjectType = typename GroupedArrayType::template ExecutionTypes<Device>::PortalConst;
template <typename KeyType>
VTKM_CONT ExecObjectType operator()(const ContObjectType& object,
const vtkm::worklet::Keys<KeyType>& keys,
const vtkm::worklet::internal::KeysBase& keys,
vtkm::Id,
vtkm::Id,
vtkm::cont::Token& token) const
@ -300,9 +377,8 @@ struct Transport<vtkm::cont::arg::TransportTagKeyedValuesInOut, ArrayHandleType,
using ExecObjectType = typename GroupedArrayType::template ExecutionTypes<Device>::Portal;
template <typename KeyType>
VTKM_CONT ExecObjectType operator()(ContObjectType object,
const vtkm::worklet::Keys<KeyType>& keys,
const vtkm::worklet::internal::KeysBase& keys,
vtkm::Id,
vtkm::Id,
vtkm::cont::Token& token) const
@ -336,9 +412,8 @@ struct Transport<vtkm::cont::arg::TransportTagKeyedValuesOut, ArrayHandleType, D
using ExecObjectType = typename GroupedArrayType::template ExecutionTypes<Device>::Portal;
template <typename KeyType>
VTKM_CONT ExecObjectType operator()(ContObjectType object,
const vtkm::worklet::Keys<KeyType>& keys,
const vtkm::worklet::internal::KeysBase& keys,
vtkm::Id,
vtkm::Id,
vtkm::cont::Token& token) const

@ -128,6 +128,9 @@ void TryKeyType(KeyType)
vtkm::cont::ArrayCopy(keyArray, sortedKeys);
vtkm::worklet::Keys<KeyType> keys(sortedKeys);
vtkm::cont::printSummary_ArrayHandle(keys.GetUniqueKeys(), std::cout);
vtkm::cont::printSummary_ArrayHandle(keys.GetOffsets(), std::cout);
vtkm::cont::printSummary_ArrayHandle(keys.GetCounts(), std::cout);
vtkm::cont::ArrayHandle<KeyType> valuesToModify;
valuesToModify.Allocate(ARRAY_SIZE);