From 2a41428fe411c3489958296effcde8e0c8c7a8ab Mon Sep 17 00:00:00 2001 From: Kenneth Moreland Date: Tue, 16 Feb 2021 14:22:55 -0700 Subject: [PATCH] Add implementation of ArrayRangeCompute for UnknownArrayHandle This allows you to easily compute the range of any ArrayHandle, even if you don't know the type. A unit test for ArrayRangeCompute was also added. --- docs/changelog/array-range-compute-unknown.md | 18 ++ vtkm/TypeTraits.h | 1 + vtkm/cont/ArrayRangeCompute.cxx | 130 +++++++++- vtkm/cont/ArrayRangeCompute.h | 31 +-- vtkm/cont/testing/CMakeLists.txt | 1 + .../testing/UnitTestArrayRangeCompute.cxx | 236 ++++++++++++++++++ 6 files changed, 395 insertions(+), 22 deletions(-) create mode 100644 docs/changelog/array-range-compute-unknown.md create mode 100644 vtkm/cont/testing/UnitTestArrayRangeCompute.cxx diff --git a/docs/changelog/array-range-compute-unknown.md b/docs/changelog/array-range-compute-unknown.md new file mode 100644 index 000000000..de8c5a2fa --- /dev/null +++ b/docs/changelog/array-range-compute-unknown.md @@ -0,0 +1,18 @@ +# `ArrayRangeCompute` works on any array type without compiling device code + +Originally, `ArrayRangeCompute` required you to know specifically the +`ArrayHandle` type (value type and storage type) and to compile using any +device compiler. The method is changed to include only overloads that have +precompiled versions of `ArrayRangeCompute`. + +Additionally, an `ArrayRangeCompute` overload that takes an +`UnknownArrayHandle` has been added. In addition to allowing you to compute +the range of arrays of unknown types, this implementation of +`ArrayRangeCompute` serves as a fallback for `ArrayHandle` types that are +not otherwise explicitly supported. + +If you really want to make sure that you compute the range directly on an +`ArrayHandle` of a particular type, you can include +`ArrayRangeComputeTemplate.h`, which contains a templated overload of +`ArrayRangeCompute` that directly computes the range of an `ArrayHandle`. +Including this header requires compiling for device code. diff --git a/vtkm/TypeTraits.h b/vtkm/TypeTraits.h index 84f7b66f9..f321d3bf9 100644 --- a/vtkm/TypeTraits.h +++ b/vtkm/TypeTraits.h @@ -110,6 +110,7 @@ struct TypeTraits : TypeTraits VTKM_BASIC_REAL_TYPE(float) VTKM_BASIC_REAL_TYPE(double) +VTKM_BASIC_INTEGER_TYPE(bool) VTKM_BASIC_INTEGER_TYPE(char) VTKM_BASIC_INTEGER_TYPE(signed char) VTKM_BASIC_INTEGER_TYPE(unsigned char) diff --git a/vtkm/cont/ArrayRangeCompute.cxx b/vtkm/cont/ArrayRangeCompute.cxx index ba6b6a5d7..073b1477d 100644 --- a/vtkm/cont/ArrayRangeCompute.cxx +++ b/vtkm/cont/ArrayRangeCompute.cxx @@ -10,6 +10,8 @@ #include +#include + namespace vtkm { namespace cont @@ -78,6 +80,8 @@ VTKM_ARRAY_RANGE_COMPUTE_IMPL_ALL_VEC(2, vtkm::cont::StorageTagSOA); VTKM_ARRAY_RANGE_COMPUTE_IMPL_ALL_VEC(3, vtkm::cont::StorageTagSOA); VTKM_ARRAY_RANGE_COMPUTE_IMPL_ALL_VEC(4, vtkm::cont::StorageTagSOA); +VTKM_ARRAY_RANGE_COMPUTE_IMPL_ALL_SCALAR_T(vtkm::cont::StorageTagStride); + VTKM_ARRAY_RANGE_COMPUTE_IMPL_VEC(vtkm::Float32, 3, vtkm::cont::StorageTagXGCCoordinates); VTKM_ARRAY_RANGE_COMPUTE_IMPL_VEC(vtkm::Float64, 3, vtkm::cont::StorageTagXGCCoordinates); @@ -111,7 +115,7 @@ vtkm::cont::ArrayHandle ArrayRangeCompute( return rangeArray; } -VTKM_CONT vtkm::cont::ArrayHandle ArrayRangeCompute( +vtkm::cont::ArrayHandle ArrayRangeCompute( const vtkm::cont::ArrayHandle& input, vtkm::cont::DeviceAdapterId) { @@ -120,5 +124,129 @@ VTKM_CONT vtkm::cont::ArrayHandle ArrayRangeCompute( result.WritePortal().Set(0, vtkm::Range(0, input.GetNumberOfValues() - 1)); return result; } + +namespace +{ + +using AllScalars = vtkm::TypeListBaseC; + +template +struct VecTransform +{ + template + using type = vtkm::Vec; +}; + +template +using AllVecOfSize = vtkm::ListTransform::template type>; + +using AllVec = vtkm::ListAppend, AllVecOfSize<3>, AllVecOfSize<4>>; + +using AllTypes = vtkm::ListAppend; + +struct ComputeRangeFunctor +{ + // Used with UnknownArrayHandle::CastAndCallForTypes + template + void operator()(const vtkm::cont::ArrayHandle& array, + vtkm::cont::DeviceAdapterId device, + vtkm::cont::ArrayHandle& ranges) const + { + ranges = vtkm::cont::ArrayRangeCompute(array, device); + } + + // Used with vtkm::ListForEach to get components + template + void operator()(T, + const vtkm::cont::UnknownArrayHandle& array, + vtkm::cont::DeviceAdapterId device, + vtkm::cont::ArrayHandle& ranges, + bool& success) const + { + if (!success && array.IsBaseComponentType()) + { + vtkm::IdComponent numComponents = array.GetNumberOfComponentsFlat(); + ranges.Allocate(numComponents); + auto rangePortal = ranges.WritePortal(); + for (vtkm::IdComponent componentI = 0; componentI < numComponents; ++componentI) + { + vtkm::cont::ArrayHandleStride componentArray = array.ExtractComponent(componentI); + vtkm::cont::ArrayHandle componentRange = + vtkm::cont::ArrayRangeCompute(componentArray, device); + rangePortal.Set(componentI, componentRange.ReadPortal().Get(0)); + } + success = true; + } + } +}; + +template +vtkm::cont::ArrayHandle ComputeForStorage(const vtkm::cont::UnknownArrayHandle& array, + vtkm::cont::DeviceAdapterId device) +{ + vtkm::cont::ArrayHandle ranges; + array.CastAndCallForTypes>(ComputeRangeFunctor{}, device, ranges); + return ranges; +} + +} // anonymous namespace + +vtkm::cont::ArrayHandle ArrayRangeCompute(const vtkm::cont::UnknownArrayHandle& array, + vtkm::cont::DeviceAdapterId device) +{ + // First, try fast-paths of precompiled array types common(ish) in fields. + try + { + if (array.IsStorageType()) + { + return ComputeForStorage(array, device); + } + if (array.IsStorageType()) + { + return ComputeForStorage(array, device); + } + if (array.IsStorageType()) + { + return ComputeForStorage( + array, device); + } + if (array.IsStorageType()) + { + vtkm::cont::ArrayHandleUniformPointCoordinates uniformPoints; + array.AsArrayHandle(uniformPoints); + return vtkm::cont::ArrayRangeCompute(uniformPoints, device); + } + using CartesianProductStorage = + vtkm::cont::StorageTagCartesianProduct; + if (array.IsStorageType()) + { + return ComputeForStorage(array, device); + } + if (array.IsStorageType()) + { + return ComputeForStorage(array, device); + } + if (array.IsStorageType()) + { + return ComputeForStorage(array, device); + } + if (array.IsStorageType()) + { + return ArrayRangeCompute(array.AsArrayHandle(), device); + } + } + catch (vtkm::cont::ErrorBadValue&) + { + // If a cast/call failed, try falling back to a more general implementation. + } + + vtkm::cont::ArrayHandle ranges; + bool success = false; + vtkm::ListForEach(ComputeRangeFunctor{}, AllScalars{}, array, device, ranges, success); + return ranges; +} + } } // namespace vtkm::cont diff --git a/vtkm/cont/ArrayRangeCompute.h b/vtkm/cont/ArrayRangeCompute.h index b0389d7c9..dfbc01f80 100644 --- a/vtkm/cont/ArrayRangeCompute.h +++ b/vtkm/cont/ArrayRangeCompute.h @@ -20,9 +20,11 @@ #include #include #include +#include #include #include #include +#include namespace vtkm { @@ -51,6 +53,10 @@ namespace cont /// that will compile for any `ArrayHandle` type not already handled. /// +VTKM_CONT_EXPORT vtkm::cont::ArrayHandle ArrayRangeCompute( + const vtkm::cont::UnknownArrayHandle& array, + vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny{}); + #define VTK_M_ARRAY_RANGE_COMPUTE_EXPORT_T(T, Storage) \ VTKM_CONT_EXPORT \ VTKM_CONT \ @@ -104,6 +110,8 @@ VTK_M_ARRAY_RANGE_COMPUTE_EXPORT_ALL_VEC(2, vtkm::cont::StorageTagSOA); VTK_M_ARRAY_RANGE_COMPUTE_EXPORT_ALL_VEC(3, vtkm::cont::StorageTagSOA); VTK_M_ARRAY_RANGE_COMPUTE_EXPORT_ALL_VEC(4, vtkm::cont::StorageTagSOA); +VTK_M_ARRAY_RANGE_COMPUTE_EXPORT_ALL_SCALAR_T(vtkm::cont::StorageTagStride); + VTK_M_ARRAY_RANGE_COMPUTE_EXPORT_VEC(vtkm::Float32, 3, vtkm::cont::StorageTagXGCCoordinates); VTK_M_ARRAY_RANGE_COMPUTE_EXPORT_VEC(vtkm::Float64, 3, vtkm::cont::StorageTagXGCCoordinates); @@ -117,25 +125,6 @@ VTKM_CONT_EXPORT VTKM_CONT vtkm::cont::ArrayHandle ArrayRangeComput vtkm::cont::ArrayHandleUniformPointCoordinates::StorageTag>& array, vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny()); -// Implementation of composite vectors -VTKM_CONT_EXPORT -VTKM_CONT -vtkm::cont::ArrayHandle ArrayRangeCompute( - const vtkm::cont::ArrayHandle, - vtkm::cont::ArrayHandle, - vtkm::cont::ArrayHandle>::StorageTag>& input, - vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny()); - -VTKM_CONT_EXPORT VTKM_CONT vtkm::cont::ArrayHandle ArrayRangeCompute( - const vtkm::cont::ArrayHandle, - vtkm::cont::ArrayHandle, - vtkm::cont::ArrayHandle>::StorageTag>& input, - vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny()); - // Implementation of cartesian products template VTKM_CONT inline vtkm::cont::ArrayHandle ArrayRangeCompute( @@ -204,7 +193,7 @@ VTKM_CONT inline vtkm::cont::ArrayHandle ArrayRangeCompute( if (portal.GetNumberOfValues() > 0) { T first = input.ReadPortal().Get(0); - T last = input.ReadPortal().Get(portal.GetNumberOfValues() - 1); + T last = input.ReadPortal().Get(input.GetNumberOfValues() - 1); for (vtkm::IdComponent cIndex = 0; cIndex < Traits::NUM_COMPONENTS; ++cIndex) { auto firstComponent = Traits::GetComponent(first, cIndex); @@ -226,7 +215,7 @@ VTKM_CONT inline vtkm::cont::ArrayHandle ArrayRangeCompute( } // Implementation of index arrays -VTKM_CONT vtkm::cont::ArrayHandle ArrayRangeCompute( +VTKM_CONT_EXPORT vtkm::cont::ArrayHandle ArrayRangeCompute( const vtkm::cont::ArrayHandle& input, vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny{}); ///@} diff --git a/vtkm/cont/testing/CMakeLists.txt b/vtkm/cont/testing/CMakeLists.txt index 411e25441..b70b8d9b4 100644 --- a/vtkm/cont/testing/CMakeLists.txt +++ b/vtkm/cont/testing/CMakeLists.txt @@ -55,6 +55,7 @@ set(unit_tests UnitTestArrayHandleVirtual.cxx UnitTestArrayHandleXGCCoordinates.cxx UnitTestArrayPortalToIterators.cxx + UnitTestArrayRangeCompute.cxx UnitTestCellLocatorChooser.cxx UnitTestCellLocatorGeneral.cxx UnitTestCellSet.cxx diff --git a/vtkm/cont/testing/UnitTestArrayRangeCompute.cxx b/vtkm/cont/testing/UnitTestArrayRangeCompute.cxx new file mode 100644 index 000000000..c5d858d27 --- /dev/null +++ b/vtkm/cont/testing/UnitTestArrayRangeCompute.cxx @@ -0,0 +1,236 @@ +//============================================================================ +// Copyright (c) Kitware, Inc. +// All rights reserved. +// See LICENSE.txt for details. +// +// This software is distributed WITHOUT ANY WARRANTY; without even +// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +// PURPOSE. See the above copyright notice for more information. +//============================================================================ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace +{ + +constexpr vtkm::Id ARRAY_SIZE = 20; + +template +void CheckRange(const vtkm::cont::ArrayHandle& array, bool checkUnknown = true) +{ + using Traits = vtkm::VecTraits; + vtkm::IdComponent numComponents = Traits::NUM_COMPONENTS; + + vtkm::cont::ArrayHandle computedRangeArray = vtkm::cont::ArrayRangeCompute(array); + VTKM_TEST_ASSERT(computedRangeArray.GetNumberOfValues() == numComponents); + auto computedRangePortal = computedRangeArray.ReadPortal(); + + auto portal = array.ReadPortal(); + for (vtkm::IdComponent component = 0; component < numComponents; ++component) + { + vtkm::Range computedRange = computedRangePortal.Get(component); + vtkm::Range expectedRange; + for (vtkm::Id index = 0; index < portal.GetNumberOfValues(); ++index) + { + T value = portal.Get(index); + expectedRange.Include(Traits::GetComponent(value, component)); + } + VTKM_TEST_ASSERT(!vtkm::IsNan(computedRange.Min)); + VTKM_TEST_ASSERT(!vtkm::IsNan(computedRange.Max)); + VTKM_TEST_ASSERT(test_equal(expectedRange, computedRange)); + } + + if (checkUnknown) + { + computedRangeArray = vtkm::cont::ArrayRangeCompute(vtkm::cont::UnknownArrayHandle{ array }); + VTKM_TEST_ASSERT(computedRangeArray.GetNumberOfValues() == numComponents); + computedRangePortal = computedRangeArray.ReadPortal(); + + portal = array.ReadPortal(); + for (vtkm::IdComponent component = 0; component < numComponents; ++component) + { + vtkm::Range computedRange = computedRangePortal.Get(component); + vtkm::Range expectedRange; + for (vtkm::Id index = 0; index < portal.GetNumberOfValues(); ++index) + { + T value = portal.Get(index); + expectedRange.Include(Traits::GetComponent(value, component)); + } + VTKM_TEST_ASSERT(!vtkm::IsNan(computedRange.Min)); + VTKM_TEST_ASSERT(!vtkm::IsNan(computedRange.Max)); + VTKM_TEST_ASSERT(test_equal(expectedRange, computedRange)); + } + } +} + +template +void FillArray(vtkm::cont::ArrayHandle& array) +{ + using Traits = vtkm::VecTraits; + vtkm::IdComponent numComponents = Traits::NUM_COMPONENTS; + + vtkm::cont::ArrayCopy(vtkm::cont::make_ArrayHandleConstant(T{}, ARRAY_SIZE), array); + + for (vtkm::IdComponent component = 0; component < numComponents; ++component) + { + vtkm::cont::ArrayHandleRandomUniformReal randomArray(ARRAY_SIZE); + auto dest = vtkm::cont::make_ArrayHandleExtractComponent(array, component); + vtkm::cont::ArrayCopy(randomArray, dest); + } +} + +template +void TestBasicArray() +{ + std::cout << "Checking basic array" << std::endl; + vtkm::cont::ArrayHandleBasic array; + FillArray(array); + CheckRange(array); +} + +template +void TestSOAArray(vtkm::TypeTraitsVectorTag) +{ + std::cout << "Checking SOA array" << std::endl; + vtkm::cont::ArrayHandleSOA array; + FillArray(array); + CheckRange(array); +} + +template +void TestSOAArray(vtkm::TypeTraitsScalarTag) +{ + // Skip test. +} + +template +void TestStrideArray(vtkm::TypeTraitsScalarTag) +{ + std::cout << "Checking stride array" << std::endl; + vtkm::cont::ArrayHandleBasic array; + FillArray(array); + CheckRange(vtkm::cont::ArrayHandleStride(array, ARRAY_SIZE / 2, 2, 1)); +} + +template +void TestCartesianProduct(vtkm::TypeTraitsScalarTag) +{ + std::cout << "Checking Cartesian product" << std::endl; + + vtkm::cont::ArrayHandleBasic array0; + FillArray(array0); + vtkm::cont::ArrayHandleBasic array1; + FillArray(array1); + vtkm::cont::ArrayHandleBasic array2; + FillArray(array2); + + CheckRange(vtkm::cont::make_ArrayHandleCartesianProduct(array0, array1, array2)); +} + +template +void TestCartesianProduct(vtkm::TypeTraitsVectorTag) +{ + // Skip test. +} + +template +void TestComposite(vtkm::TypeTraitsScalarTag) +{ + std::cout << "Checking composite vector" << std::endl; + + vtkm::cont::ArrayHandleBasic array0; + FillArray(array0); + vtkm::cont::ArrayHandleBasic array1; + FillArray(array1); + vtkm::cont::ArrayHandleBasic array2; + FillArray(array2); + + CheckRange(vtkm::cont::make_ArrayHandleCompositeVector(array0, array1, array2)); +} + +template +void TestComposite(vtkm::TypeTraitsVectorTag) +{ + // Skip test. +} + +template +void TestConstant() +{ + std::cout << "Checking constant array" << std::endl; + CheckRange(vtkm::cont::make_ArrayHandleConstant(TestValue(10, T{}), ARRAY_SIZE)); +} + +template +void TestCounting(std::true_type vtkmNotUsed(is_signed)) +{ + std::cout << "Checking counting array" << std::endl; + CheckRange(vtkm::cont::make_ArrayHandleCounting(TestValue(10, T{}), T{ 1 }, ARRAY_SIZE)); + + std::cout << "Checking counting backward array" << std::endl; + CheckRange(vtkm::cont::make_ArrayHandleCounting(TestValue(10, T{}), T{ -1 }, ARRAY_SIZE)); +} + +template +void TestCounting(std::false_type vtkmNotUsed(is_signed)) +{ + // Skip test +} + +void TestIndex() +{ + std::cout << "Checking index array" << std::endl; + CheckRange(vtkm::cont::make_ArrayHandleIndex(ARRAY_SIZE)); +} + +void TestUniformPointCoords() +{ + std::cout << "Checking uniform point coordinates" << std::endl; + CheckRange( + vtkm::cont::ArrayHandleUniformPointCoordinates(vtkm::Id3(ARRAY_SIZE, ARRAY_SIZE, ARRAY_SIZE))); +} + +struct DoTestFunctor +{ + template + void operator()(T) const + { + TestBasicArray(); + TestSOAArray(typename vtkm::TypeTraits::DimensionalityTag{}); + TestCartesianProduct(typename vtkm::TypeTraits::DimensionalityTag{}); + TestComposite(typename vtkm::TypeTraits::DimensionalityTag{}); + TestConstant(); + TestCounting(typename std::is_signed::ComponentType>::type{}); + } +}; + +void DoTest() +{ + vtkm::testing::Testing::TryTypes(DoTestFunctor{}); + + std::cout << "*** Specific arrays *****************" << std::endl; + TestIndex(); + TestUniformPointCoords(); +} + +} // anonymous namespace + +int UnitTestArrayRangeCompute(int argc, char* argv[]) +{ + return vtkm::cont::testing::Testing::Run(DoTest, argc, argv); +}