Fix threshold for any-point-all-components case

Threshold was producing wrong results with options `SetAllInRange(false)` and
`SetComponentToTestToAll` because the logic of running
`worklet::Threshold::RunIncremental` on individual components of the input
field and combining the results is incorrect for this case.

With this fix, component modes 'Any' and 'All' are handled by applying
the threshold criteria to each component of each value of the field,
combining the results, and running the threshold worklet on the result
array.
This commit is contained in:
Sujin Philip 2023-10-23 10:49:25 -04:00
parent adf146a96e
commit 6e1de4fa8f

@ -11,6 +11,8 @@
#include <vtkm/filter/entity_extraction/Threshold.h>
#include <vtkm/filter/entity_extraction/worklet/Threshold.h>
#include <vtkm/cont/Invoker.h>
#include <vtkm/BinaryPredicates.h>
#include <vtkm/Math.h>
@ -59,6 +61,37 @@ bool DoMapField(vtkm::cont::DataSet& result,
return false;
}
}
template <typename Operator>
class CombinePassFlagsWorklet : public vtkm::worklet::WorkletMapField
{
public:
using ControlSignature = void(FieldInOut, FieldIn);
using ExecitionSignature = void(_1, _2);
VTKM_CONT
explicit CombinePassFlagsWorklet(const Operator& combine)
: Combine(combine)
{
}
VTKM_EXEC void operator()(bool& combined, bool incoming) const
{
combined = this->Combine(combined, incoming);
}
private:
Operator Combine;
};
class ThresholdPassFlag
{
public:
VTKM_CONT ThresholdPassFlag() = default;
VTKM_EXEC bool operator()(bool value) const { return value; }
};
} // end anon namespace
namespace vtkm
@ -104,7 +137,7 @@ vtkm::cont::DataSet Threshold::DoExecute(const vtkm::cont::DataSet& input)
return;
}
if (this->ComponentMode == Component::Selected)
if (this->ComponentMode == Component::Selected || field.GetData().GetNumberOfComponents() == 1)
{
auto arrayComponent =
field.GetData().ExtractComponent<ComponentType>(this->SelectedComponent);
@ -113,35 +146,36 @@ vtkm::cont::DataSet Threshold::DoExecute(const vtkm::cont::DataSet& input)
}
else
{
for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i)
vtkm::cont::ArrayHandle<bool> passFlags;
if (this->ComponentMode == Component::Any)
{
auto arrayComponent = field.GetData().ExtractComponent<ComponentType>(i);
if (this->ComponentMode == Component::Any)
auto combineWorklet = CombinePassFlagsWorklet<vtkm::LogicalOr>(vtkm::LogicalOr{});
passFlags.AllocateAndFill(field.GetNumberOfValues(), false);
for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i)
{
worklet.RunIncremental(cells,
arrayComponent,
field.GetAssociation(),
predicate,
this->AllInRange,
vtkm::LogicalOr{});
auto arrayComponent = field.GetData().ExtractComponent<ComponentType>(i);
auto thresholded = vtkm::cont::make_ArrayHandleTransform(arrayComponent, predicate);
vtkm::cont::Invoker()(combineWorklet, passFlags, thresholded);
}
else // this->ComponentMode == Component::All
}
else // this->ComponentMode == Component::All
{
auto combineWorklet = CombinePassFlagsWorklet<vtkm::LogicalAnd>(vtkm::LogicalAnd{});
passFlags.AllocateAndFill(field.GetNumberOfValues(), true);
for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i)
{
worklet.RunIncremental(cells,
arrayComponent,
field.GetAssociation(),
predicate,
this->AllInRange,
vtkm::LogicalAnd{});
auto arrayComponent = field.GetData().ExtractComponent<ComponentType>(i);
auto thresholded = vtkm::cont::make_ArrayHandleTransform(arrayComponent, predicate);
vtkm::cont::Invoker()(combineWorklet, passFlags, thresholded);
}
}
if (this->Invert)
{
worklet.InvertResults();
}
cellOut = worklet.GenerateResultCellSet(cells);
cellOut = worklet.Run(cells,
passFlags,
field.GetAssociation(),
ThresholdPassFlag{},
this->AllInRange,
this->Invert);
}
};