From 392d78135990182e694ef442ea51ba8fb01b081c Mon Sep 17 00:00:00 2001 From: Kenneth Moreland Date: Fri, 14 Jan 2022 14:04:18 -0700 Subject: [PATCH] Add ArrayCopy specialization for Counting and Permutation array This required adding another source file. --- vtkm/cont/ArrayCopy.cxx | 51 ++++++++++ vtkm/cont/ArrayCopy.h | 51 +++++++++- vtkm/cont/ArrayHandleCounting.h | 4 + vtkm/cont/CMakeLists.txt | 1 + vtkm/cont/internal/MapArrayPermutation.cxx | 16 ++- vtkm/cont/internal/MapArrayPermutation.h | 2 +- vtkm/cont/testing/UnitTestArrayCopy.cxx | 107 ++++++++++++++++----- 7 files changed, 201 insertions(+), 31 deletions(-) create mode 100644 vtkm/cont/ArrayCopy.cxx diff --git a/vtkm/cont/ArrayCopy.cxx b/vtkm/cont/ArrayCopy.cxx new file mode 100644 index 000000000..1ae0dcd9a --- /dev/null +++ b/vtkm/cont/ArrayCopy.cxx @@ -0,0 +1,51 @@ +//============================================================================ +// Copyright (c) Kitware, Inc. +// All rights reserved. +// See LICENSE.txt for details. +// +// This software is distributed WITHOUT ANY WARRANTY; without even +// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +// PURPOSE. See the above copyright notice for more information. +//============================================================================ + +#include +#include + +namespace vtkm +{ +namespace cont +{ +namespace detail +{ + +void ArrayCopyConcreteSrc::CopyCountingFloat( + vtkm::FloatDefault start, + vtkm::FloatDefault step, + vtkm::Id size, + const vtkm::cont::UnknownArrayHandle& result) const +{ + if (result.IsBaseComponentType()) + { + auto outArray = result.ExtractComponent(0); + vtkm::cont::ArrayCopyDevice(vtkm::cont::make_ArrayHandleCounting(start, step, size), outArray); + } + else + { + vtkm::cont::ArrayHandle outArray; + outArray.Allocate(size); + CopyCountingFloat(start, step, size, outArray); + result.DeepCopyFrom(outArray); + } +} + +vtkm::cont::ArrayHandle ArrayCopyConcreteSrc::CopyCountingId( + const vtkm::cont::ArrayHandleCounting& source) const +{ + vtkm::cont::ArrayHandle destination; + vtkm::cont::ArrayCopyDevice(source, destination); + return destination; +} + +} +} +} // namespace vtkm::cont::detail diff --git a/vtkm/cont/ArrayCopy.h b/vtkm/cont/ArrayCopy.h index f19c29bf9..612a7e4f4 100644 --- a/vtkm/cont/ArrayCopy.h +++ b/vtkm/cont/ArrayCopy.h @@ -12,11 +12,14 @@ #include #include +#include #include #include #include #include +#include + #include #include @@ -258,6 +261,48 @@ struct ArrayCopyConcreteSrc } }; +// Special case for ArrayHandleCounting to be efficient. +template <> +struct VTKM_CONT_EXPORT ArrayCopyConcreteSrc +{ + template + void operator()(const vtkm::cont::ArrayHandle& source, + vtkm::cont::ArrayHandle& destination) const + { + vtkm::cont::ArrayHandleCounting countingSource = source; + T1 start = countingSource.GetStart(); + T1 step = countingSource.GetStep(); + vtkm::Id size = countingSource.GetNumberOfValues(); + destination.Allocate(size); + vtkm::cont::UnknownArrayHandle unknownDest = destination; + + using VTraits1 = vtkm::VecTraits; + using VTraits2 = vtkm::VecTraits; + for (vtkm::IdComponent comp = 0; comp < VTraits1::GetNumberOfComponents(start); ++comp) + { + this->CopyCountingFloat( + static_cast(VTraits1::GetComponent(start, comp)), + static_cast(VTraits1::GetComponent(step, comp)), + size, + unknownDest.ExtractComponent(comp)); + } + } + + void operator()(const vtkm::cont::ArrayHandle& source, + vtkm::cont::ArrayHandle& destination) const + { + destination = this->CopyCountingId(source); + } + +private: + void CopyCountingFloat(vtkm::FloatDefault start, + vtkm::FloatDefault step, + vtkm::Id size, + const vtkm::cont::UnknownArrayHandle& result) const; + vtkm::cont::ArrayHandle CopyCountingId( + const vtkm::cont::ArrayHandleCounting& source) const; +}; + // Special case for ArrayHandleConcatenate to be efficient template struct ArrayCopyConcreteSrc> @@ -280,10 +325,10 @@ struct ArrayCopyConcreteSrc> }; // Special case for ArrayHandlePermutation to be efficient -template -struct ArrayCopyConcreteSrc> +template +struct ArrayCopyConcreteSrc> { - using SourceStorageTag = vtkm::cont::StorageTagPermutation; + using SourceStorageTag = vtkm::cont::StorageTagPermutation; template void operator()(const vtkm::cont::ArrayHandle& source, vtkm::cont::ArrayHandle& destination) const diff --git a/vtkm/cont/ArrayHandleCounting.h b/vtkm/cont/ArrayHandleCounting.h index a5cc3a39d..7f6f08e0d 100644 --- a/vtkm/cont/ArrayHandleCounting.h +++ b/vtkm/cont/ArrayHandleCounting.h @@ -138,6 +138,10 @@ public: internal::ArrayPortalCounting(start, step, length))) { } + + VTKM_CONT CountingValueType GetStart() const { return this->ReadPortal().GetStart(); } + + VTKM_CONT CountingValueType GetStep() const { return this->ReadPortal().GetStep(); } }; /// A convenience function for creating an ArrayHandleCounting. It takes the diff --git a/vtkm/cont/CMakeLists.txt b/vtkm/cont/CMakeLists.txt index 99dc38c8f..e5f97f795 100644 --- a/vtkm/cont/CMakeLists.txt +++ b/vtkm/cont/CMakeLists.txt @@ -183,6 +183,7 @@ set(sources # This list of sources has code that uses devices and so might need to be # compiled with a device-specific compiler (like CUDA). set(device_sources + ArrayCopy.cxx ArrayGetValues.cxx ArrayRangeCompute.cxx CellLocatorBoundingIntervalHierarchy.cxx diff --git a/vtkm/cont/internal/MapArrayPermutation.cxx b/vtkm/cont/internal/MapArrayPermutation.cxx index d72000708..cb3512df0 100644 --- a/vtkm/cont/internal/MapArrayPermutation.cxx +++ b/vtkm/cont/internal/MapArrayPermutation.cxx @@ -11,6 +11,8 @@ #include #include +#include + #include @@ -48,9 +50,9 @@ struct MapPermutationWorklet : vtkm::worklet::WorkletMapField struct DoMapFieldPermutation { - template + template void operator()(const InputArrayType& input, - const vtkm::cont::ArrayHandle& permutation, + const PermutationArrayType& permutation, vtkm::cont::UnknownArrayHandle& output, vtkm::Float64 invalidValue) const { @@ -77,13 +79,19 @@ namespace internal vtkm::cont::UnknownArrayHandle MapArrayPermutation( const vtkm::cont::UnknownArrayHandle& inputArray, - const vtkm::cont::ArrayHandle& permutation, + const vtkm::cont::UnknownArrayHandle& permutation, vtkm::Float64 invalidValue) { + if (!permutation.IsBaseComponentType()) + { + throw vtkm::cont::ErrorBadType("Permutation array input to MapArrayPermutation must have " + "values of vtkm::Id. Reported type is " + + permutation.GetBaseComponentTypeName()); + } vtkm::cont::UnknownArrayHandle outputArray = inputArray.NewInstanceBasic(); outputArray.Allocate(permutation.GetNumberOfValues()); inputArray.CastAndCallWithExtractedArray( - DoMapFieldPermutation{}, permutation, outputArray, invalidValue); + DoMapFieldPermutation{}, permutation.ExtractComponent(0), outputArray, invalidValue); return outputArray; } diff --git a/vtkm/cont/internal/MapArrayPermutation.h b/vtkm/cont/internal/MapArrayPermutation.h index d571a3a99..393b7c5b1 100644 --- a/vtkm/cont/internal/MapArrayPermutation.h +++ b/vtkm/cont/internal/MapArrayPermutation.h @@ -26,7 +26,7 @@ namespace internal /// VTKM_CONT_EXPORT vtkm::cont::UnknownArrayHandle MapArrayPermutation( const vtkm::cont::UnknownArrayHandle& inputArray, - const vtkm::cont::ArrayHandle& permutation, + const vtkm::cont::UnknownArrayHandle& permutation, vtkm::Float64 invalidValue = vtkm::Nan64()); /// Used to map a permutation array. diff --git a/vtkm/cont/testing/UnitTestArrayCopy.cxx b/vtkm/cont/testing/UnitTestArrayCopy.cxx index daad8efe6..7a62ced82 100644 --- a/vtkm/cont/testing/UnitTestArrayCopy.cxx +++ b/vtkm/cont/testing/UnitTestArrayCopy.cxx @@ -9,12 +9,14 @@ //============================================================================ #include +#include #include #include #include #include #include +#include #include @@ -23,13 +25,43 @@ namespace static constexpr vtkm::Id ARRAY_SIZE = 10; -template -void TestValues(const RefArrayType& refArray, const TestArrayType& testArray) +vtkm::cont::UnknownArrayHandle MakeComparable(const vtkm::cont::UnknownArrayHandle& array, + std::false_type) +{ + return array; +} + +template +vtkm::cont::UnknownArrayHandle MakeComparable(const vtkm::cont::ArrayHandle& array, + std::true_type) +{ + return array; +} + +template +vtkm::cont::UnknownArrayHandle MakeComparable(const ArrayType& array, std::true_type) +{ + vtkm::cont::ArrayHandle simpleArray; + vtkm::cont::ArrayCopyDevice(array, simpleArray); + return simpleArray; +} + +void TestValuesImpl(const vtkm::cont::UnknownArrayHandle& refArray, + const vtkm::cont::UnknownArrayHandle& testArray) { auto result = test_equal_ArrayHandles(refArray, testArray); VTKM_TEST_ASSERT(result, result.GetMergedMessage()); } +template +void TestValues(const RefArrayType& refArray, const TestArrayType& testArray) +{ + TestValuesImpl( + MakeComparable(refArray, typename vtkm::cont::internal::ArrayHandleCheck::type{}), + MakeComparable(testArray, + typename vtkm::cont::internal::ArrayHandleCheck::type{})); +} + template vtkm::cont::ArrayHandle MakeInputArray() { @@ -44,18 +76,20 @@ void TryCopy() { VTKM_LOG_S(vtkm::cont::LogLevel::Info, "Trying type: " << vtkm::testing::TypeName::Name()); + using VTraits = vtkm::VecTraits; { std::cout << "implicit -> basic" << std::endl; vtkm::cont::ArrayHandleIndex input(ARRAY_SIZE); - vtkm::cont::ArrayHandle output; + vtkm::cont::ArrayHandle output; vtkm::cont::ArrayCopy(input, output); TestValues(input, output); } { std::cout << "basic -> basic" << std::endl; - vtkm::cont::ArrayHandle input = MakeInputArray(); + using SourceType = typename VTraits::template ReplaceComponentType; + vtkm::cont::ArrayHandle input = MakeInputArray(); vtkm::cont::ArrayHandle output; vtkm::cont::ArrayCopy(input, output); TestValues(input, output); @@ -90,30 +124,53 @@ void TryCopy() TestValues(input, output); } - using TypeList = vtkm::ListAppend>; - using StorageList = VTKM_DEFAULT_STORAGE_LIST; - using UnknownArray = vtkm::cont::UnknownArrayHandle; - using UncertainArray = vtkm::cont::UncertainArrayHandle; - { - std::cout << "unknown -> unknown" << std::endl; - UnknownArray input = MakeInputArray(); - UnknownArray output; - vtkm::cont::ArrayCopy(input, output); - TestValues(input, output); - } - - { - std::cout << "uncertain -> basic (same type)" << std::endl; - UncertainArray input = MakeInputArray(); + std::cout << "constant -> basic" << std::endl; + vtkm::cont::ArrayHandleConstant input(TestValue(2, ValueType{}), ARRAY_SIZE); vtkm::cont::ArrayHandle output; vtkm::cont::ArrayCopy(input, output); TestValues(input, output); } { - std::cout << "uncertain -> basic (different type)" << std::endl; - UncertainArray input = MakeInputArray(); + std::cout << "counting -> basic" << std::endl; + vtkm::cont::ArrayHandleCounting input(ValueType(-4), ValueType(3), ARRAY_SIZE); + vtkm::cont::ArrayHandle output; + vtkm::cont::ArrayCopy(input, output); + TestValues(input, output); + } + + { + std::cout << "permutation -> basic" << std::endl; + vtkm::cont::ArrayHandle indices; + vtkm::cont::ArrayCopy(vtkm::cont::make_ArrayHandleCounting(0, 2, ARRAY_SIZE / 2), + indices); + auto input = vtkm::cont::make_ArrayHandlePermutation(indices, MakeInputArray()); + vtkm::cont::ArrayHandle output; + vtkm::cont::ArrayCopy(input, output); + TestValues(input, output); + } + + { + std::cout << "unknown -> unknown" << std::endl; + vtkm::cont::UnknownArrayHandle input = MakeInputArray(); + vtkm::cont::UnknownArrayHandle output; + vtkm::cont::ArrayCopy(input, output); + TestValues(input, output); + } + + { + std::cout << "unknown -> basic (same type)" << std::endl; + vtkm::cont::UnknownArrayHandle input = MakeInputArray(); + vtkm::cont::ArrayHandle output; + vtkm::cont::ArrayCopy(input, output); + TestValues(input, output); + } + + { + std::cout << "unknown -> basic (different type)" << std::endl; + using SourceType = typename VTraits::template ReplaceComponentType; + vtkm::cont::UnknownArrayHandle input = MakeInputArray(); vtkm::cont::ArrayHandle output; vtkm::cont::ArrayCopy(input, output); TestValues(input, output); @@ -139,7 +196,8 @@ void TryCopy() { std::cout << "unknown.DeepCopyFrom(different type)" << std::endl; - vtkm::cont::ArrayHandle input = MakeInputArray(); + using SourceType = typename VTraits::template ReplaceComponentType; + vtkm::cont::ArrayHandle input = MakeInputArray(); vtkm::cont::ArrayHandle outputArray; vtkm::cont::UnknownArrayHandle(outputArray).DeepCopyFrom(input); TestValues(input, outputArray); @@ -166,7 +224,8 @@ void TryCopy() { std::cout << "unknown.CopyShallowIfPossible(different type)" << std::endl; - vtkm::cont::ArrayHandle input = MakeInputArray(); + using SourceType = typename VTraits::template ReplaceComponentType; + vtkm::cont::ArrayHandle input = MakeInputArray(); vtkm::cont::ArrayHandle outputArray; vtkm::cont::UnknownArrayHandle(outputArray).CopyShallowIfPossible(input); TestValues(input, outputArray); @@ -203,6 +262,8 @@ void TestArrayCopy() TryCopy(); TryCopy(); TryCopy(); + TryCopy(); + TryCopy(); TryArrayCopyShallowIfPossible(); }