From 6e1de4fa8f673eb6c070b4b16bf54f42cc5904f1 Mon Sep 17 00:00:00 2001 From: Sujin Philip Date: Mon, 23 Oct 2023 10:49:25 -0400 Subject: [PATCH] Fix threshold for any-point-all-components case Threshold was producing wrong results with options `SetAllInRange(false)` and `SetComponentToTestToAll` because the logic of running `worklet::Threshold::RunIncremental` on individual components of the input field and combining the results is incorrect for this case. With this fix, component modes 'Any' and 'All' are handled by applying the threshold criteria to each component of each value of the field, combining the results, and running the threshold worklet on the result array. --- vtkm/filter/entity_extraction/Threshold.cxx | 80 +++++++++++++++------ 1 file changed, 57 insertions(+), 23 deletions(-) diff --git a/vtkm/filter/entity_extraction/Threshold.cxx b/vtkm/filter/entity_extraction/Threshold.cxx index 1d3c5765a..eac4faf02 100644 --- a/vtkm/filter/entity_extraction/Threshold.cxx +++ b/vtkm/filter/entity_extraction/Threshold.cxx @@ -11,6 +11,8 @@ #include #include +#include + #include #include @@ -59,6 +61,37 @@ bool DoMapField(vtkm::cont::DataSet& result, return false; } } + +template +class CombinePassFlagsWorklet : public vtkm::worklet::WorkletMapField +{ +public: + using ControlSignature = void(FieldInOut, FieldIn); + using ExecitionSignature = void(_1, _2); + + VTKM_CONT + explicit CombinePassFlagsWorklet(const Operator& combine) + : Combine(combine) + { + } + + VTKM_EXEC void operator()(bool& combined, bool incoming) const + { + combined = this->Combine(combined, incoming); + } + +private: + Operator Combine; +}; + +class ThresholdPassFlag +{ +public: + VTKM_CONT ThresholdPassFlag() = default; + + VTKM_EXEC bool operator()(bool value) const { return value; } +}; + } // end anon namespace namespace vtkm @@ -104,7 +137,7 @@ vtkm::cont::DataSet Threshold::DoExecute(const vtkm::cont::DataSet& input) return; } - if (this->ComponentMode == Component::Selected) + if (this->ComponentMode == Component::Selected || field.GetData().GetNumberOfComponents() == 1) { auto arrayComponent = field.GetData().ExtractComponent(this->SelectedComponent); @@ -113,35 +146,36 @@ vtkm::cont::DataSet Threshold::DoExecute(const vtkm::cont::DataSet& input) } else { - for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i) + vtkm::cont::ArrayHandle passFlags; + if (this->ComponentMode == Component::Any) { - auto arrayComponent = field.GetData().ExtractComponent(i); - if (this->ComponentMode == Component::Any) + auto combineWorklet = CombinePassFlagsWorklet(vtkm::LogicalOr{}); + passFlags.AllocateAndFill(field.GetNumberOfValues(), false); + for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i) { - worklet.RunIncremental(cells, - arrayComponent, - field.GetAssociation(), - predicate, - this->AllInRange, - vtkm::LogicalOr{}); + auto arrayComponent = field.GetData().ExtractComponent(i); + auto thresholded = vtkm::cont::make_ArrayHandleTransform(arrayComponent, predicate); + vtkm::cont::Invoker()(combineWorklet, passFlags, thresholded); } - else // this->ComponentMode == Component::All + } + else // this->ComponentMode == Component::All + { + auto combineWorklet = CombinePassFlagsWorklet(vtkm::LogicalAnd{}); + passFlags.AllocateAndFill(field.GetNumberOfValues(), true); + for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i) { - worklet.RunIncremental(cells, - arrayComponent, - field.GetAssociation(), - predicate, - this->AllInRange, - vtkm::LogicalAnd{}); + auto arrayComponent = field.GetData().ExtractComponent(i); + auto thresholded = vtkm::cont::make_ArrayHandleTransform(arrayComponent, predicate); + vtkm::cont::Invoker()(combineWorklet, passFlags, thresholded); } } - if (this->Invert) - { - worklet.InvertResults(); - } - - cellOut = worklet.GenerateResultCellSet(cells); + cellOut = worklet.Run(cells, + passFlags, + field.GetAssociation(), + ThresholdPassFlag{}, + this->AllInRange, + this->Invert); } };