Merge topic 'fix-threshold-all-comps'

d825d2450 Add a regression test for issue #804
e0c5500a2 Simplify threshold worklet
6e1de4fa8 Fix threshold for any-point-all-components case

Acked-by: Kitware Robot <kwrobot@kitware.com>
Acked-by: Kenneth Moreland <morelandkd@ornl.gov>
Merge-request: !3143
This commit is contained in:
Sujin Philip 2023-10-30 16:32:39 +00:00 committed by Kitware Robot
commit 4eb5da26fa
3 changed files with 103 additions and 126 deletions

@ -11,6 +11,8 @@
#include <vtkm/filter/entity_extraction/Threshold.h>
#include <vtkm/filter/entity_extraction/worklet/Threshold.h>
#include <vtkm/cont/Invoker.h>
#include <vtkm/BinaryPredicates.h>
#include <vtkm/Math.h>
@ -59,6 +61,37 @@ bool DoMapField(vtkm::cont::DataSet& result,
return false;
}
}
template <typename Operator>
class CombinePassFlagsWorklet : public vtkm::worklet::WorkletMapField
{
public:
using ControlSignature = void(FieldInOut, FieldIn);
using ExecitionSignature = void(_1, _2);
VTKM_CONT
explicit CombinePassFlagsWorklet(const Operator& combine)
: Combine(combine)
{
}
VTKM_EXEC void operator()(bool& combined, bool incoming) const
{
combined = this->Combine(combined, incoming);
}
private:
Operator Combine;
};
class ThresholdPassFlag
{
public:
VTKM_CONT ThresholdPassFlag() = default;
VTKM_EXEC bool operator()(bool value) const { return value; }
};
} // end anon namespace
namespace vtkm
@ -104,7 +137,7 @@ vtkm::cont::DataSet Threshold::DoExecute(const vtkm::cont::DataSet& input)
return;
}
if (this->ComponentMode == Component::Selected)
if (this->ComponentMode == Component::Selected || field.GetData().GetNumberOfComponents() == 1)
{
auto arrayComponent =
field.GetData().ExtractComponent<ComponentType>(this->SelectedComponent);
@ -113,35 +146,36 @@ vtkm::cont::DataSet Threshold::DoExecute(const vtkm::cont::DataSet& input)
}
else
{
for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i)
vtkm::cont::ArrayHandle<bool> passFlags;
if (this->ComponentMode == Component::Any)
{
auto arrayComponent = field.GetData().ExtractComponent<ComponentType>(i);
if (this->ComponentMode == Component::Any)
auto combineWorklet = CombinePassFlagsWorklet<vtkm::LogicalOr>(vtkm::LogicalOr{});
passFlags.AllocateAndFill(field.GetNumberOfValues(), false);
for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i)
{
worklet.RunIncremental(cells,
arrayComponent,
field.GetAssociation(),
predicate,
this->AllInRange,
vtkm::LogicalOr{});
auto arrayComponent = field.GetData().ExtractComponent<ComponentType>(i);
auto thresholded = vtkm::cont::make_ArrayHandleTransform(arrayComponent, predicate);
vtkm::cont::Invoker()(combineWorklet, passFlags, thresholded);
}
else // this->ComponentMode == Component::All
}
else // this->ComponentMode == Component::All
{
auto combineWorklet = CombinePassFlagsWorklet<vtkm::LogicalAnd>(vtkm::LogicalAnd{});
passFlags.AllocateAndFill(field.GetNumberOfValues(), true);
for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i)
{
worklet.RunIncremental(cells,
arrayComponent,
field.GetAssociation(),
predicate,
this->AllInRange,
vtkm::LogicalAnd{});
auto arrayComponent = field.GetData().ExtractComponent<ComponentType>(i);
auto thresholded = vtkm::cont::make_ArrayHandleTransform(arrayComponent, predicate);
vtkm::cont::Invoker()(combineWorklet, passFlags, thresholded);
}
}
if (this->Invert)
{
worklet.InvertResults();
}
cellOut = worklet.GenerateResultCellSet(cells);
cellOut = worklet.Run(cells,
passFlags,
field.GetAssociation(),
ThresholdPassFlag{},
this->AllInRange,
this->Invert);
}
};

@ -311,6 +311,27 @@ public:
VTKM_TEST_ASSERT(failures == 0, "Some combinations have failed");
}
// Regression test for issue #804
static void RegressionTest804()
{
std::cout << "Regression test for issue #804" << std::endl;
auto input = vtkm::cont::DataSetBuilderUniform::Create(vtkm::Id2{ 4, 2 });
static const vtkm::Vec2f pointvar[8] = { { 0.0f, 7.0f }, { 1.0f, 6.0f }, { 2.0f, 5.0f },
{ 3.0f, 4.0f }, { 4.0f, 3.0f }, { 5.0f, 2.0f },
{ 6.0f, 1.0f }, { 7.0f, 0.0f } };
input.AddPointField("pointvar", pointvar, 8);
vtkm::filter::entity_extraction::Threshold threshold;
threshold.SetActiveField("pointvar");
threshold.SetAllInRange(false);
threshold.SetThresholdBelow(4.0);
threshold.SetComponentToTestToAll();
auto output = threshold.Execute(input);
auto numOutputCells = output.GetNumberOfCells();
VTKM_TEST_ASSERT(numOutputCells == 2, "Wrong number of cells in the output");
}
void operator()() const
{
TestingThreshold::TestRegular2D(false);
@ -320,6 +341,7 @@ public:
TestingThreshold::TestExplicit3D();
TestingThreshold::TestExplicit3DZeroResults();
TestingThreshold::TestAllOptions();
TestingThreshold::RegressionTest804();
}
};
}

@ -79,72 +79,18 @@ public:
bool AllPointsMustPass;
};
template <typename Operator>
class CombinePassFlagsWorklet : public vtkm::worklet::WorkletMapField
template <typename CellSetType, typename ValueType, typename StorageType, typename UnaryPredicate>
vtkm::cont::CellSetPermutation<CellSetType> RunImpl(
const CellSetType& cellSet,
const vtkm::cont::ArrayHandle<ValueType, StorageType>& field,
vtkm::cont::Field::Association fieldType,
const UnaryPredicate& predicate,
bool allPointsMustPass,
bool invert)
{
public:
using ControlSignature = void(FieldInOut, FieldIn);
using ExecitionSignature = void(_1, _2);
using OutputType = vtkm::cont::CellSetPermutation<CellSetType>;
VTKM_CONT
explicit CombinePassFlagsWorklet(const Operator& combine)
: Combine(combine)
{
}
VTKM_EXEC void operator()(bool& combined, bool incoming) const
{
combined = this->Combine(combined, incoming);
}
private:
Operator Combine;
};
template <typename Operator>
void CombinePassFlags(const vtkm::cont::ArrayHandle<bool>& passFlagsIn, const Operator& combine)
{
if (this->PassFlags.GetNumberOfValues() == 0) // Is initialization needed?
{
this->PassFlags = passFlagsIn;
}
else
{
DispatcherMapField<CombinePassFlagsWorklet<Operator>> dispatcher(
CombinePassFlagsWorklet<Operator>{ combine });
dispatcher.Invoke(this->PassFlags, passFlagsIn);
}
this->PassFlagsModified = true;
}
// special no-op combine operator for combining `PassFlags` results of incremental runs
struct NoOp
{
};
void CombinePassFlags(const vtkm::cont::ArrayHandle<bool>& passFlagsIn, NoOp)
{
this->PassFlags = passFlagsIn;
this->PassFlagsModified = true;
}
/// Incrementally run the worklet on the given parameters. Each run should get the
/// same `cellSet`. An array of pass/fail flags is maintained internally. The `passFlagsCombine`
/// operator is used to combine the current result to the incremental results. Finally, use
/// `GenerateResultCellSet` to get the thresholded cellset.
template <typename ValueType,
typename StorageType,
typename UnaryPredicate,
typename PassFlagsCombineOp>
void RunIncremental(const vtkm::cont::UnknownCellSet& cellSet,
const vtkm::cont::ArrayHandle<ValueType, StorageType>& field,
vtkm::cont::Field::Association fieldType,
const UnaryPredicate& predicate,
bool allPointsMustPass, // only considered when field association is `Points`
const PassFlagsCombineOp& passFlagsCombineOp)
{
vtkm::cont::ArrayHandle<bool> passFlags;
switch (fieldType)
{
case vtkm::cont::Field::Association::Points:
@ -166,40 +112,16 @@ public:
throw vtkm::cont::ErrorBadValue("Expecting point or cell field.");
}
this->CombinePassFlags(passFlags, passFlagsCombineOp);
}
vtkm::cont::ArrayHandle<vtkm::Id> GetValidCellIds() const
{
if (this->PassFlagsModified)
if (invert)
{
vtkm::cont::Algorithm::CopyIf(
vtkm::cont::ArrayHandleIndex(this->PassFlags.GetNumberOfValues()),
this->PassFlags,
this->ValidCellIds);
this->PassFlagsModified = false;
vtkm::cont::Algorithm::Copy(
vtkm::cont::make_ArrayHandleTransform(passFlags, vtkm::LogicalNot{}), passFlags);
}
return this->ValidCellIds;
}
vtkm::cont::UnknownCellSet GenerateResultCellSet(const vtkm::cont::UnknownCellSet& cellSet)
{
vtkm::cont::UnknownCellSet output;
vtkm::cont::Algorithm::CopyIf(
vtkm::cont::ArrayHandleIndex(passFlags.GetNumberOfValues()), passFlags, this->ValidCellIds);
CastAndCall(cellSet, [&](auto concrete) {
output = vtkm::worklet::CellDeepCopy::Run(
vtkm::cont::make_CellSetPermutation(this->GetValidCellIds(), concrete));
});
return output;
}
// Invert the results stored in this worklet's state
void InvertResults()
{
vtkm::cont::Algorithm::Copy(
vtkm::cont::make_ArrayHandleTransform(this->PassFlags, vtkm::LogicalNot{}), this->PassFlags);
this->PassFlagsModified = true;
return OutputType(this->ValidCellIds, cellSet);
}
template <typename ValueType, typename StorageType, typename UnaryPredicate>
@ -211,19 +133,18 @@ public:
bool allPointsMustPass = false, // only considered when field association is `Points`
bool invert = false)
{
this->RunIncremental(cellSet, field, fieldType, predicate, allPointsMustPass, NoOp{});
if (invert)
{
this->InvertResults();
}
return this->GenerateResultCellSet(cellSet);
vtkm::cont::UnknownCellSet output;
CastAndCall(cellSet, [&](auto concrete) {
output = vtkm::worklet::CellDeepCopy::Run(
this->RunImpl(concrete, field, fieldType, predicate, allPointsMustPass, invert));
});
return output;
}
private:
vtkm::cont::ArrayHandle<bool> PassFlags;
vtkm::cont::ArrayHandle<vtkm::Id> GetValidCellIds() const { return this->ValidCellIds; }
mutable bool PassFlagsModified = true;
mutable vtkm::cont::ArrayHandle<vtkm::Id> ValidCellIds;
private:
vtkm::cont::ArrayHandle<vtkm::Id> ValidCellIds;
};
}
} // namespace vtkm::worklet