Change PointAverage to work on fields of any type

The previous version of the `PointAverage` filter used a float fallback
to handle most array types. The problem with this approach other than
converting field types perhaps unexpectantly is that it does not work
with every `Vec` size. This change uses the extract by component feature
of `UnknownArrayHandle` to handle every array type.

To implement this change the `PointAverage` worklet had to be changed to
handle recombined vecs. This change resulted in a feature degridation
where it can no longer be compiled for inputs of incompatible `Vec`
sizes. This feature dates back to when worklets like this were exposed
in the interface. This worklet class is now hidden away from the exposed
interface, so this degredation should not affect end users. There are
some unit tests that use this worklet to test other features, and these
had to be updated.
This commit is contained in:
Kenneth Moreland 2023-02-02 16:41:50 -05:00
parent 0c13917c1e
commit 634847ce20
5 changed files with 35 additions and 51 deletions

@ -18,3 +18,4 @@ updated.
* `ClipWithImplicitFunction` * `ClipWithImplicitFunction`
* `Contour` * `Contour`
* `MIRFilter` * `MIRFilter`
* `PointAverage`

@ -29,25 +29,23 @@ vtkm::cont::DataSet PointAverage::DoExecute(const vtkm::cont::DataSet& input)
} }
vtkm::cont::UnknownCellSet cellSet = input.GetCellSet(); vtkm::cont::UnknownCellSet cellSet = input.GetCellSet();
vtkm::cont::UnknownArrayHandle outArray; vtkm::cont::UnknownArrayHandle inArray = field.GetData();
vtkm::cont::UnknownArrayHandle outArray = inArray.NewInstanceBasic();
auto resolveType = [&](const auto& concrete) { auto resolveType = [&](const auto& concrete) {
using T = typename std::decay_t<decltype(concrete)>::ValueType; using T = typename std::decay_t<decltype(concrete)>::ValueType::ComponentType;
auto result = outArray.ExtractArrayFromComponents<T>();
using SupportedCellSets = using SupportedCellSets =
vtkm::ListAppend<vtkm::List<vtkm::cont::CellSetExtrude>, VTKM_DEFAULT_CELL_SET_LIST>; vtkm::ListAppend<vtkm::List<vtkm::cont::CellSetExtrude>, VTKM_DEFAULT_CELL_SET_LIST>;
vtkm::cont::ArrayHandle<T> result;
this->Invoke(vtkm::worklet::PointAverage{}, this->Invoke(vtkm::worklet::PointAverage{},
cellSet.ResetCellSetList<SupportedCellSets>(), cellSet.ResetCellSetList<SupportedCellSets>(),
concrete, concrete,
result); result);
outArray = result;
}; };
// TODO: Do we need to deal with XCG storage type explicitly? // TODO: Do we need to deal with XCG storage type (vtkm::cont::ArrayHandleXGCCoordinates)
// using AdditionalFieldStorage = vtkm::List<vtkm::cont::StorageTagXGCCoordinates>; // explicitly? Extracting from that is slow.
field.GetData() inArray.CastAndCallWithExtractedArray(resolveType);
.CastAndCallForTypesWithFloatFallback<vtkm::TypeListField, VTKM_DEFAULT_STORAGE_LIST>(
resolveType);
std::string outputName = this->GetOutputFieldName(); std::string outputName = this->GetOutputFieldName();
if (outputName.empty()) if (outputName.empty())

@ -37,48 +37,26 @@ public:
OutType& average) const OutType& average) const
{ {
using CellValueType = typename CellValueVecType::ComponentType; using CellValueType = typename CellValueVecType::ComponentType;
using InVecSize =
std::integral_constant<vtkm::IdComponent, vtkm::VecTraits<CellValueType>::NUM_COMPONENTS>;
using OutVecSize =
std::integral_constant<vtkm::IdComponent, vtkm::VecTraits<OutType>::NUM_COMPONENTS>;
using SameLengthVectors = typename std::is_same<InVecSize, OutVecSize>::type;
VTKM_ASSERT(vtkm::VecTraits<CellValueType>::GetNumberOfComponents(cellValues[0]) ==
vtkm::VecTraits<OutType>::GetNumberOfComponents(average));
average = vtkm::TypeTraits<OutType>::ZeroInitialization(); average = cellValues[0];
if (numCells != 0)
{
this->DoAverage(numCells, cellValues, average, SameLengthVectors());
}
}
private:
template <typename CellValueVecType, typename OutType>
VTKM_EXEC void DoAverage(const vtkm::IdComponent& numCells,
const CellValueVecType& cellValues,
OutType& average,
std::true_type) const
{
using OutComponentType = typename vtkm::VecTraits<OutType>::ComponentType;
OutType sum = OutType(cellValues[0]);
for (vtkm::IdComponent cellIndex = 1; cellIndex < numCells; ++cellIndex) for (vtkm::IdComponent cellIndex = 1; cellIndex < numCells; ++cellIndex)
{ {
// OutType constructor is for when OutType is a Vec. average += cellValues[cellIndex];
// static_cast is for when OutType is a small int that gets promoted to int32.
sum = static_cast<OutType>(sum + OutType(cellValues[cellIndex]));
} }
// OutType constructor is for when OutType is a Vec. using VTraits = vtkm::VecTraits<OutType>;
// static_cast is for when OutType is a small int that gets promoted to int32. using OutComponentType = typename vtkm::VecTraits<OutType>::ComponentType;
average = static_cast<OutType>(sum / OutType(static_cast<OutComponentType>(numCells))); const vtkm::IdComponent numComponents = VTraits::GetNumberOfComponents(average);
} for (vtkm::IdComponent compIndex = 0; compIndex < numComponents; ++compIndex)
{
template <typename CellValueVecType, typename OutType> VTraits::SetComponent(
VTKM_EXEC void DoAverage(const vtkm::IdComponent& vtkmNotUsed(numCells), average,
const CellValueVecType& vtkmNotUsed(cellValues), compIndex,
OutType& vtkmNotUsed(average), static_cast<OutComponentType>(VTraits::GetComponent(average, compIndex) / numCells));
std::false_type) const }
{
this->RaiseError("PointAverage called with mismatched Vec sizes for PointAverage.");
} }
}; };
} }

@ -151,11 +151,13 @@ static void TestAvgCellToPoint()
vtkm::cont::testing::MakeTestDataSet testDataSet; vtkm::cont::testing::MakeTestDataSet testDataSet;
vtkm::cont::DataSet dataSet = testDataSet.Make3DExplicitDataSet1(); vtkm::cont::DataSet dataSet = testDataSet.Make3DExplicitDataSet1();
auto field = dataSet.GetField("cellvar"); auto field = dataSet.GetField("cellvar");
vtkm::cont::ArrayHandle<vtkm::Float32> inArray;
field.GetData().AsArrayHandle(inArray);
vtkm::cont::ArrayHandle<vtkm::Float32> result; vtkm::cont::ArrayHandle<vtkm::Float32> result;
vtkm::worklet::DispatcherMapTopology<vtkm::worklet::PointAverage> dispatcher; vtkm::worklet::DispatcherMapTopology<vtkm::worklet::PointAverage> dispatcher;
dispatcher.Invoke(dataSet.GetCellSet(), &field, result); dispatcher.Invoke(dataSet.GetCellSet(), &inArray, result);
std::cout << "Make sure we got the right answer." << std::endl; std::cout << "Make sure we got the right answer." << std::endl;
VTKM_TEST_ASSERT(test_equal(result.ReadPortal().Get(0), 100.1f), VTKM_TEST_ASSERT(test_equal(result.ReadPortal().Get(0), 100.1f),
@ -167,9 +169,12 @@ static void TestAvgCellToPoint()
bool exceptionThrown = false; bool exceptionThrown = false;
try try
{ {
dispatcher.Invoke(dataSet.GetCellSet(), dispatcher.Invoke(
dataSet.GetField("pointvar"), // should be cellvar dataSet.GetCellSet(),
result); dataSet.GetField("pointvar")
.GetData()
.AsArrayHandle<vtkm::cont::ArrayHandle<vtkm::Float32>>(), // should be cellvar
result);
} }
catch (vtkm::cont::ErrorBadValue& error) catch (vtkm::cont::ErrorBadValue& error)
{ {

@ -194,7 +194,7 @@ static void TestAvgCellToPoint()
// of the way we get cell indices. We need to make that // of the way we get cell indices. We need to make that
// part more flexible. // part more flexible.
dataSet.GetCellSet().ResetCellSetList(vtkm::cont::CellSetListStructured2D()), dataSet.GetCellSet().ResetCellSetList(vtkm::cont::CellSetListStructured2D()),
dataSet.GetField("cellvar"), dataSet.GetField("cellvar").GetData().AsArrayHandle<vtkm::cont::ArrayHandle<vtkm::Float32>>(),
result); result);
std::cout << "Make sure we got the right answer." << std::endl; std::cout << "Make sure we got the right answer." << std::endl;
@ -212,7 +212,9 @@ static void TestAvgCellToPoint()
// of the way we get cell indices. We need to make that // of the way we get cell indices. We need to make that
// part more flexible. // part more flexible.
dataSet.GetCellSet().ResetCellSetList(vtkm::cont::CellSetListStructured2D()), dataSet.GetCellSet().ResetCellSetList(vtkm::cont::CellSetListStructured2D()),
dataSet.GetField("pointvar"), // should be cellvar dataSet.GetField("pointvar")
.GetData()
.AsArrayHandle<vtkm::cont::ArrayHandle<vtkm::Float32>>(), // should be cellvar
result); result);
} }
catch (vtkm::cont::ErrorBadValue& error) catch (vtkm::cont::ErrorBadValue& error)