From b1f288aaeaa21e6db1ce5b2765e0829966f28658 Mon Sep 17 00:00:00 2001 From: Kenneth Moreland Date: Mon, 3 Feb 2020 23:00:22 -0600 Subject: [PATCH] Add non-templated base class to Keys class. The only reason Keys has a template is so that it can hold a UniqueKeys array and provide the key for each group. If that is not needed and you want to implement a library function that takes a keys object, you can now grab the Keys superclass KeysBase. KeysBase is not templated, so you can pass it to a standard method in a library. --- vtkm/cont/arg/TypeCheckTagKeys.h | 10 +- vtkm/exec/arg/ThreadIndicesReduceByKey.h | 4 +- vtkm/exec/internal/ReduceByKeyLookup.h | 43 ++++-- vtkm/worklet/AverageByKey.h | 11 +- vtkm/worklet/Keys.h | 133 ++++++++++++++---- .../testing/UnitTestWorkletReduceByKey.cxx | 3 + 6 files changed, 146 insertions(+), 58 deletions(-) diff --git a/vtkm/cont/arg/TypeCheckTagKeys.h b/vtkm/cont/arg/TypeCheckTagKeys.h index 39c0f1c34..07b984d4e 100644 --- a/vtkm/cont/arg/TypeCheckTagKeys.h +++ b/vtkm/cont/arg/TypeCheckTagKeys.h @@ -25,14 +25,8 @@ struct TypeCheckTagKeys { }; -// A more specific specialization that actually checks for Keys types is -// implemented in vtkm/worklet/Keys.h. That class is not accessible from here -// due to VTK-m package dependencies. -template -struct TypeCheck -{ - static constexpr bool value = false; -}; +// The specialization that actually checks for Keys types is implemented in vtkm/worklet/Keys.h. +// That class is not accessible from here due to VTK-m package dependencies. } } } // namespace vtkm::cont::arg diff --git a/vtkm/exec/arg/ThreadIndicesReduceByKey.h b/vtkm/exec/arg/ThreadIndicesReduceByKey.h index 2af9bd598..cedee477e 100644 --- a/vtkm/exec/arg/ThreadIndicesReduceByKey.h +++ b/vtkm/exec/arg/ThreadIndicesReduceByKey.h @@ -33,13 +33,13 @@ class ThreadIndicesReduceByKey : public vtkm::exec::arg::ThreadIndicesBasic using Superclass = vtkm::exec::arg::ThreadIndicesBasic; public: - template + template VTKM_EXEC ThreadIndicesReduceByKey( vtkm::Id threadIndex, vtkm::Id inIndex, vtkm::IdComponent visitIndex, vtkm::Id outIndex, - const vtkm::exec::internal::ReduceByKeyLookup& keyLookup) + const vtkm::exec::internal::ReduceByKeyLookupBase& keyLookup) : Superclass(threadIndex, inIndex, visitIndex, outIndex) , ValueOffset(keyLookup.Offsets.Get(inIndex)) , NumberOfValues(keyLookup.Counts.Get(inIndex)) diff --git a/vtkm/exec/internal/ReduceByKeyLookup.h b/vtkm/exec/internal/ReduceByKeyLookup.h index 0e687bc36..7111b71fb 100644 --- a/vtkm/exec/internal/ReduceByKeyLookup.h +++ b/vtkm/exec/internal/ReduceByKeyLookup.h @@ -24,6 +24,34 @@ namespace exec namespace internal { +/// A superclass of `ReduceBykeyLookup` that can be used when no key values are provided. +/// +template +struct ReduceByKeyLookupBase +{ + VTKM_STATIC_ASSERT((std::is_same::value)); + VTKM_STATIC_ASSERT( + (std::is_same::value)); + + IdPortalType SortedValuesMap; + IdPortalType Offsets; + IdComponentPortalType Counts; + + VTKM_EXEC_CONT + ReduceByKeyLookupBase(const IdPortalType& sortedValuesMap, + const IdPortalType& offsets, + const IdComponentPortalType& counts) + : SortedValuesMap(sortedValuesMap) + , Offsets(offsets) + , Counts(counts) + { + } + + VTKM_SUPPRESS_EXEC_WARNINGS + VTKM_EXEC_CONT + ReduceByKeyLookupBase() {} +}; + /// \brief Execution object holding lookup info for reduce by key. /// /// A WorkletReduceByKey needs several arrays to map the current output object @@ -31,28 +59,19 @@ namespace internal /// state. /// template -struct ReduceByKeyLookup : vtkm::cont::ExecutionObjectBase +struct ReduceByKeyLookup : ReduceByKeyLookupBase { using KeyType = typename KeyPortalType::ValueType; - VTKM_STATIC_ASSERT((std::is_same::value)); - VTKM_STATIC_ASSERT( - (std::is_same::value)); - KeyPortalType UniqueKeys; - IdPortalType SortedValuesMap; - IdPortalType Offsets; - IdComponentPortalType Counts; VTKM_EXEC_CONT ReduceByKeyLookup(const KeyPortalType& uniqueKeys, const IdPortalType& sortedValuesMap, const IdPortalType& offsets, const IdComponentPortalType& counts) - : UniqueKeys(uniqueKeys) - , SortedValuesMap(sortedValuesMap) - , Offsets(offsets) - , Counts(counts) + : ReduceByKeyLookupBase(sortedValuesMap, offsets, counts) + , UniqueKeys(uniqueKeys) { } diff --git a/vtkm/worklet/AverageByKey.h b/vtkm/worklet/AverageByKey.h index 7f27e3e2c..4b1fbb510 100644 --- a/vtkm/worklet/AverageByKey.h +++ b/vtkm/worklet/AverageByKey.h @@ -72,11 +72,8 @@ struct AverageByKey /// This method uses an existing \c Keys object to collected values by those keys and find /// the average of those groups. /// - template - VTKM_CONT static void Run(const vtkm::worklet::Keys& keys, + template + VTKM_CONT static void Run(const vtkm::worklet::internal::KeysBase& keys, const vtkm::cont::ArrayHandle& inValues, vtkm::cont::ArrayHandle& outAverages) { @@ -90,9 +87,9 @@ struct AverageByKey /// This method uses an existing \c Keys object to collected values by those keys and find /// the average of those groups. /// - template + template VTKM_CONT static vtkm::cont::ArrayHandle Run( - const vtkm::worklet::Keys& keys, + const vtkm::worklet::internal::KeysBase& keys, const vtkm::cont::ArrayHandle& inValues) { diff --git a/vtkm/worklet/Keys.h b/vtkm/worklet/Keys.h index 86058168c..ab9b7841a 100644 --- a/vtkm/worklet/Keys.h +++ b/vtkm/worklet/Keys.h @@ -30,6 +30,8 @@ #include #include +#include + #include #include @@ -40,6 +42,83 @@ namespace vtkm namespace worklet { +namespace internal +{ + +class VTKM_WORKLET_EXPORT KeysBase +{ +public: + KeysBase(const KeysBase&) = default; + KeysBase& operator=(const KeysBase&) = default; + ~KeysBase() = default; + + VTKM_CONT + vtkm::Id GetInputRange() const { return this->Counts.GetNumberOfValues(); } + + VTKM_CONT + vtkm::cont::ArrayHandle GetSortedValuesMap() const { return this->SortedValuesMap; } + + VTKM_CONT + vtkm::cont::ArrayHandle GetOffsets() const { return this->Offsets; } + + VTKM_CONT + vtkm::cont::ArrayHandle GetCounts() const { return this->Counts; } + + VTKM_CONT + vtkm::Id GetNumberOfValues() const { return this->SortedValuesMap.GetNumberOfValues(); } + + template + struct ExecutionTypes + { + using IdPortal = + typename vtkm::cont::ArrayHandle::template ExecutionTypes::PortalConst; + using IdComponentPortal = typename vtkm::cont::ArrayHandle< + vtkm::IdComponent>::template ExecutionTypes::PortalConst; + + using Lookup = vtkm::exec::internal::ReduceByKeyLookupBase; + }; + + template + VTKM_CONT typename ExecutionTypes::Lookup PrepareForInput(Device device, + vtkm::cont::Token& token) const + { + return + typename ExecutionTypes::Lookup(this->SortedValuesMap.PrepareForInput(device, token), + this->Offsets.PrepareForInput(device, token), + this->Counts.PrepareForInput(device, token)); + } + + template + VTKM_CONT VTKM_DEPRECATED(1.6, "PrepareForInput now requires a vtkm::cont::Token object.") + typename ExecutionTypes::Lookup PrepareForInput(Device device) const + { + vtkm::cont::Token token; + return this->PrepareForInput(device, token); + } + + VTKM_CONT + bool operator==(const vtkm::worklet::internal::KeysBase& other) const + { + return ((this->SortedValuesMap == other.SortedValuesMap) && (this->Offsets == other.Offsets) && + (this->Counts == other.Counts)); + } + + VTKM_CONT + bool operator!=(const vtkm::worklet::internal::KeysBase& other) const + { + return !(*this == other); + } + +protected: + KeysBase() = default; + + vtkm::cont::ArrayHandle SortedValuesMap; + vtkm::cont::ArrayHandle Offsets; + vtkm::cont::ArrayHandle Counts; +}; + +} // namespace internal + /// Select the type of sort for BuildArrays calls. Unstable sorting is faster /// but will not produce consistent ordering for equal keys. Stable sorting /// is slower, but keeps equal keys in their original order. @@ -67,7 +146,7 @@ enum class KeysSortType /// creating a different \c Keys structure for each \c Invoke. /// template -class VTKM_ALWAYS_EXPORT Keys +class VTKM_ALWAYS_EXPORT Keys : public internal::KeysBase { public: using KeyType = T; @@ -110,24 +189,9 @@ public: KeysSortType sort, vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny()); - VTKM_CONT - vtkm::Id GetInputRange() const { return this->UniqueKeys.GetNumberOfValues(); } - VTKM_CONT KeyArrayHandleType GetUniqueKeys() const { return this->UniqueKeys; } - VTKM_CONT - vtkm::cont::ArrayHandle GetSortedValuesMap() const { return this->SortedValuesMap; } - - VTKM_CONT - vtkm::cont::ArrayHandle GetOffsets() const { return this->Offsets; } - - VTKM_CONT - vtkm::cont::ArrayHandle GetCounts() const { return this->Counts; } - - VTKM_CONT - vtkm::Id GetNumberOfValues() const { return this->SortedValuesMap.GetNumberOfValues(); } - template struct ExecutionTypes { @@ -173,9 +237,6 @@ public: private: /// @cond NONE KeyArrayHandleType UniqueKeys; - vtkm::cont::ArrayHandle SortedValuesMap; - vtkm::cont::ArrayHandle Offsets; - vtkm::cont::ArrayHandle Counts; template VTKM_CONT void BuildArraysInternal(KeyArrayType& keys, vtkm::cont::DeviceAdapterId device); @@ -189,6 +250,9 @@ private: template VTKM_CONT Keys::Keys() = default; +namespace internal +{ + template inline auto SchedulingRange(const vtkm::worklet::Keys& inputDomain) -> decltype(inputDomain.GetInputRange()) @@ -202,6 +266,19 @@ inline auto SchedulingRange(const vtkm::worklet::Keys* const inputDomai { return inputDomain->GetInputRange(); } + +inline auto SchedulingRange(const vtkm::worklet::internal::KeysBase& inputDomain) + -> decltype(inputDomain.GetInputRange()) +{ + return inputDomain.GetInputRange(); +} + +inline auto SchedulingRange(const vtkm::worklet::internal::KeysBase* const inputDomain) + -> decltype(inputDomain->GetInputRange()) +{ + return inputDomain->GetInputRange(); +} +} // namespace internal } } // namespace vtkm::worklet @@ -218,15 +295,16 @@ namespace arg { template -struct TypeCheck> +struct TypeCheck { - static constexpr bool value = true; + static constexpr bool value = + std::is_base_of::type>::value; }; template -struct Transport, Device> +struct Transport { - using ContObjectType = vtkm::worklet::Keys; + using ContObjectType = KeyType; using ExecObjectType = typename ContObjectType::template ExecutionTypes::Lookup; VTKM_CONT @@ -264,9 +342,8 @@ struct Transport::PortalConst; - template VTKM_CONT ExecObjectType operator()(const ContObjectType& object, - const vtkm::worklet::Keys& keys, + const vtkm::worklet::internal::KeysBase& keys, vtkm::Id, vtkm::Id, vtkm::cont::Token& token) const @@ -300,9 +377,8 @@ struct Transport::Portal; - template VTKM_CONT ExecObjectType operator()(ContObjectType object, - const vtkm::worklet::Keys& keys, + const vtkm::worklet::internal::KeysBase& keys, vtkm::Id, vtkm::Id, vtkm::cont::Token& token) const @@ -336,9 +412,8 @@ struct Transport::Portal; - template VTKM_CONT ExecObjectType operator()(ContObjectType object, - const vtkm::worklet::Keys& keys, + const vtkm::worklet::internal::KeysBase& keys, vtkm::Id, vtkm::Id, vtkm::cont::Token& token) const diff --git a/vtkm/worklet/testing/UnitTestWorkletReduceByKey.cxx b/vtkm/worklet/testing/UnitTestWorkletReduceByKey.cxx index d8a4e7187..98a82743d 100644 --- a/vtkm/worklet/testing/UnitTestWorkletReduceByKey.cxx +++ b/vtkm/worklet/testing/UnitTestWorkletReduceByKey.cxx @@ -128,6 +128,9 @@ void TryKeyType(KeyType) vtkm::cont::ArrayCopy(keyArray, sortedKeys); vtkm::worklet::Keys keys(sortedKeys); + vtkm::cont::printSummary_ArrayHandle(keys.GetUniqueKeys(), std::cout); + vtkm::cont::printSummary_ArrayHandle(keys.GetOffsets(), std::cout); + vtkm::cont::printSummary_ArrayHandle(keys.GetCounts(), std::cout); vtkm::cont::ArrayHandle valuesToModify; valuesToModify.Allocate(ARRAY_SIZE);