mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-09-16 17:22:55 +00:00
Change CellAverage to work on fields of any type
The previous version of the `CellAverage` filter used a float fallback to handle most array types. The problem with this approach other than converting field types perhaps unexpectantly is that it does not work with every `Vec` size. This change uses the extract by component feature of `UnknownArrayHandle` to handle every array type. To implement this change the `CellAverage` worklet had to be changed to handle recombined vecs. This change resulted in a feature degridation where it can no longer be compiled for inputs of incompatible `Vec` sizes. This feature dates back to when worklets like this were exposed in the interface. This worklet class is now hidden away from the exposed interface, so this degredation should not affect end users. There are some unit tests that use this worklet to test other features, and these had to be updated.
This commit is contained in:
parent
5173f65176
commit
0c13917c1e
@ -13,6 +13,7 @@ to get data from any type of array. The following filters have been
|
|||||||
updated.
|
updated.
|
||||||
|
|
||||||
* `CleanGrid`
|
* `CleanGrid`
|
||||||
|
* `CellAverage`
|
||||||
* `ClipWithField`
|
* `ClipWithField`
|
||||||
* `ClipWithImplicitFunction`
|
* `ClipWithImplicitFunction`
|
||||||
* `Contour`
|
* `Contour`
|
||||||
|
@ -105,7 +105,10 @@ void TestDataSet_Explicit()
|
|||||||
//run a basic for-each topology algorithm on this
|
//run a basic for-each topology algorithm on this
|
||||||
vtkm::cont::ArrayHandle<vtkm::Float32> result;
|
vtkm::cont::ArrayHandle<vtkm::Float32> result;
|
||||||
vtkm::worklet::DispatcherMapTopology<vtkm::worklet::CellAverage> dispatcher;
|
vtkm::worklet::DispatcherMapTopology<vtkm::worklet::CellAverage> dispatcher;
|
||||||
dispatcher.Invoke(subset, dataSet.GetField("pointvar"), result);
|
dispatcher.Invoke(
|
||||||
|
subset,
|
||||||
|
dataSet.GetField("pointvar").GetData().AsArrayHandle<vtkm::cont::ArrayHandle<vtkm::Float32>>(),
|
||||||
|
result);
|
||||||
|
|
||||||
//iterate same cell 4 times
|
//iterate same cell 4 times
|
||||||
vtkm::Float32 expected[4] = { 30.1667f, 30.1667f, 30.1667f, 30.1667f };
|
vtkm::Float32 expected[4] = { 30.1667f, 30.1667f, 30.1667f, 30.1667f };
|
||||||
@ -139,7 +142,10 @@ void TestDataSet_Structured2D()
|
|||||||
//run a basic for-each topology algorithm on this
|
//run a basic for-each topology algorithm on this
|
||||||
vtkm::cont::ArrayHandle<vtkm::Float32> result;
|
vtkm::cont::ArrayHandle<vtkm::Float32> result;
|
||||||
vtkm::worklet::DispatcherMapTopology<vtkm::worklet::CellAverage> dispatcher;
|
vtkm::worklet::DispatcherMapTopology<vtkm::worklet::CellAverage> dispatcher;
|
||||||
dispatcher.Invoke(subset, dataSet.GetField("pointvar"), result);
|
dispatcher.Invoke(
|
||||||
|
subset,
|
||||||
|
dataSet.GetField("pointvar").GetData().AsArrayHandle<vtkm::cont::ArrayHandle<vtkm::Float32>>(),
|
||||||
|
result);
|
||||||
|
|
||||||
vtkm::Float32 expected[4] = { 40.1f, 40.1f, 40.1f, 40.1f };
|
vtkm::Float32 expected[4] = { 40.1f, 40.1f, 40.1f, 40.1f };
|
||||||
auto resultPortal = result.ReadPortal();
|
auto resultPortal = result.ReadPortal();
|
||||||
@ -172,7 +178,10 @@ void TestDataSet_Structured3D()
|
|||||||
//run a basic for-each topology algorithm on this
|
//run a basic for-each topology algorithm on this
|
||||||
vtkm::cont::ArrayHandle<vtkm::Float32> result;
|
vtkm::cont::ArrayHandle<vtkm::Float32> result;
|
||||||
vtkm::worklet::DispatcherMapTopology<vtkm::worklet::CellAverage> dispatcher;
|
vtkm::worklet::DispatcherMapTopology<vtkm::worklet::CellAverage> dispatcher;
|
||||||
dispatcher.Invoke(subset, dataSet.GetField("pointvar"), result);
|
dispatcher.Invoke(
|
||||||
|
subset,
|
||||||
|
dataSet.GetField("pointvar").GetData().AsArrayHandle<vtkm::cont::ArrayHandle<vtkm::Float32>>(),
|
||||||
|
result);
|
||||||
|
|
||||||
vtkm::Float32 expected[4] = { 70.2125f, 70.2125f, 70.2125f, 70.2125f };
|
vtkm::Float32 expected[4] = { 70.2125f, 70.2125f, 70.2125f, 70.2125f };
|
||||||
auto resultPortal = result.ReadPortal();
|
auto resultPortal = result.ReadPortal();
|
||||||
|
@ -28,17 +28,15 @@ vtkm::cont::DataSet CellAverage::DoExecute(const vtkm::cont::DataSet& input)
|
|||||||
}
|
}
|
||||||
|
|
||||||
vtkm::cont::UnknownCellSet inputCellSet = input.GetCellSet();
|
vtkm::cont::UnknownCellSet inputCellSet = input.GetCellSet();
|
||||||
vtkm::cont::UnknownArrayHandle outArray;
|
vtkm::cont::UnknownArrayHandle inArray = field.GetData();
|
||||||
|
vtkm::cont::UnknownArrayHandle outArray = inArray.NewInstanceBasic();
|
||||||
|
|
||||||
auto resolveType = [&](const auto& concrete) {
|
auto resolveType = [&](const auto& concrete) {
|
||||||
using T = typename std::decay_t<decltype(concrete)>::ValueType;
|
using T = typename std::decay_t<decltype(concrete)>::ValueType::ComponentType;
|
||||||
vtkm::cont::ArrayHandle<T> result;
|
auto result = outArray.ExtractArrayFromComponents<T>();
|
||||||
this->Invoke(vtkm::worklet::CellAverage{}, inputCellSet, concrete, result);
|
this->Invoke(vtkm::worklet::CellAverage{}, inputCellSet, concrete, result);
|
||||||
outArray = result;
|
|
||||||
};
|
};
|
||||||
field.GetData()
|
inArray.CastAndCallWithExtractedArray(resolveType);
|
||||||
.CastAndCallForTypesWithFloatFallback<vtkm::TypeListField, VTKM_DEFAULT_STORAGE_LIST>(
|
|
||||||
resolveType);
|
|
||||||
|
|
||||||
std::string outputName = this->GetOutputFieldName();
|
std::string outputName = this->GetOutputFieldName();
|
||||||
if (outputName.empty())
|
if (outputName.empty())
|
||||||
|
@ -35,43 +35,25 @@ public:
|
|||||||
{
|
{
|
||||||
using PointValueType = typename PointValueVecType::ComponentType;
|
using PointValueType = typename PointValueVecType::ComponentType;
|
||||||
|
|
||||||
using InVecSize =
|
VTKM_ASSERT(vtkm::VecTraits<PointValueType>::GetNumberOfComponents(pointValues[0]) ==
|
||||||
std::integral_constant<vtkm::IdComponent, vtkm::VecTraits<PointValueType>::NUM_COMPONENTS>;
|
vtkm::VecTraits<OutType>::GetNumberOfComponents(average));
|
||||||
using OutVecSize =
|
|
||||||
std::integral_constant<vtkm::IdComponent, vtkm::VecTraits<OutType>::NUM_COMPONENTS>;
|
|
||||||
using SameLengthVectors = typename std::is_same<InVecSize, OutVecSize>::type;
|
|
||||||
|
|
||||||
this->DoAverage(numPoints, pointValues, average, SameLengthVectors());
|
average = pointValues[0];
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
template <typename PointValueVecType, typename OutType>
|
|
||||||
VTKM_EXEC void DoAverage(const vtkm::IdComponent& numPoints,
|
|
||||||
const PointValueVecType& pointValues,
|
|
||||||
OutType& average,
|
|
||||||
std::true_type) const
|
|
||||||
{
|
|
||||||
using OutComponentType = typename vtkm::VecTraits<OutType>::ComponentType;
|
|
||||||
OutType sum = OutType(pointValues[0]);
|
|
||||||
for (vtkm::IdComponent pointIndex = 1; pointIndex < numPoints; ++pointIndex)
|
for (vtkm::IdComponent pointIndex = 1; pointIndex < numPoints; ++pointIndex)
|
||||||
{
|
{
|
||||||
// OutType constructor is for when OutType is a Vec.
|
average += pointValues[pointIndex];
|
||||||
// static_cast is for when OutType is a small int that gets promoted to int32.
|
|
||||||
sum = static_cast<OutType>(sum + OutType(pointValues[pointIndex]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// OutType constructor is for when OutType is a Vec.
|
using VTraits = vtkm::VecTraits<OutType>;
|
||||||
// static_cast is for when OutType is a small int that gets promoted to int32.
|
using OutComponentType = typename VTraits::ComponentType;
|
||||||
average = static_cast<OutType>(sum / OutType(static_cast<OutComponentType>(numPoints)));
|
const vtkm::IdComponent numComponents = VTraits::GetNumberOfComponents(average);
|
||||||
}
|
for (vtkm::IdComponent cIndex = 0; cIndex < numComponents; ++cIndex)
|
||||||
|
|
||||||
template <typename PointValueVecType, typename OutType>
|
|
||||||
VTKM_EXEC void DoAverage(const vtkm::IdComponent& vtkmNotUsed(numPoints),
|
|
||||||
const PointValueVecType& vtkmNotUsed(pointValues),
|
|
||||||
OutType& vtkmNotUsed(average),
|
|
||||||
std::false_type) const
|
|
||||||
{
|
{
|
||||||
this->RaiseError("CellAverage called with mismatched Vec sizes for CellAverage.");
|
VTraits::SetComponent(
|
||||||
|
average,
|
||||||
|
cIndex,
|
||||||
|
static_cast<OutComponentType>(VTraits::GetComponent(average, cIndex) / numPoints));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -114,7 +114,10 @@ static void TestAvgPointToCell()
|
|||||||
vtkm::cont::ArrayHandle<vtkm::Float32> result;
|
vtkm::cont::ArrayHandle<vtkm::Float32> result;
|
||||||
|
|
||||||
vtkm::worklet::DispatcherMapTopology<vtkm::worklet::CellAverage> dispatcher;
|
vtkm::worklet::DispatcherMapTopology<vtkm::worklet::CellAverage> dispatcher;
|
||||||
dispatcher.Invoke(&cellset, dataSet.GetField("pointvar"), &result);
|
dispatcher.Invoke(
|
||||||
|
&cellset,
|
||||||
|
dataSet.GetField("pointvar").GetData().AsArrayHandle<vtkm::cont::ArrayHandle<vtkm::Float32>>(),
|
||||||
|
&result);
|
||||||
|
|
||||||
std::cout << "Make sure we got the right answer." << std::endl;
|
std::cout << "Make sure we got the right answer." << std::endl;
|
||||||
VTKM_TEST_ASSERT(test_equal(result.ReadPortal().Get(0), 20.1333f),
|
VTKM_TEST_ASSERT(test_equal(result.ReadPortal().Get(0), 20.1333f),
|
||||||
@ -126,8 +129,11 @@ static void TestAvgPointToCell()
|
|||||||
bool exceptionThrown = false;
|
bool exceptionThrown = false;
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
dispatcher.Invoke(dataSet.GetCellSet(),
|
dispatcher.Invoke(
|
||||||
dataSet.GetField("cellvar"), // should be pointvar
|
dataSet.GetCellSet(),
|
||||||
|
dataSet.GetField("cellvar")
|
||||||
|
.GetData()
|
||||||
|
.AsArrayHandle<vtkm::cont::ArrayHandle<vtkm::Float32>>(), // should be pointvar
|
||||||
result);
|
result);
|
||||||
}
|
}
|
||||||
catch (vtkm::cont::ErrorBadValue& error)
|
catch (vtkm::cont::ErrorBadValue& error)
|
||||||
|
@ -147,7 +147,7 @@ static void TestAvgPointToCell()
|
|||||||
// of the way we get cell indices. We need to make that
|
// of the way we get cell indices. We need to make that
|
||||||
// part more flexible.
|
// part more flexible.
|
||||||
&cellset,
|
&cellset,
|
||||||
dataSet.GetField("pointvar"),
|
dataSet.GetField("pointvar").GetData().AsArrayHandle<vtkm::cont::ArrayHandle<vtkm::Float32>>(),
|
||||||
result);
|
result);
|
||||||
|
|
||||||
std::cout << "Make sure we got the right answer." << std::endl;
|
std::cout << "Make sure we got the right answer." << std::endl;
|
||||||
@ -165,7 +165,9 @@ static void TestAvgPointToCell()
|
|||||||
// of the way we get cell indices. We need to make that
|
// of the way we get cell indices. We need to make that
|
||||||
// part more flexible.
|
// part more flexible.
|
||||||
dataSet.GetCellSet().ResetCellSetList(vtkm::cont::CellSetListStructured2D()),
|
dataSet.GetCellSet().ResetCellSetList(vtkm::cont::CellSetListStructured2D()),
|
||||||
dataSet.GetField("cellvar"), // should be pointvar
|
dataSet.GetField("cellvar")
|
||||||
|
.GetData()
|
||||||
|
.AsArrayHandle<vtkm::cont::ArrayHandle<vtkm::Float32>>(), // should be pointvar
|
||||||
result);
|
result);
|
||||||
}
|
}
|
||||||
catch (vtkm::cont::ErrorBadValue& error)
|
catch (vtkm::cont::ErrorBadValue& error)
|
||||||
|
Loading…
Reference in New Issue
Block a user