diff --git a/vtkm/filter/entity_extraction/Threshold.cxx b/vtkm/filter/entity_extraction/Threshold.cxx index 1d3c5765a..eac4faf02 100644 --- a/vtkm/filter/entity_extraction/Threshold.cxx +++ b/vtkm/filter/entity_extraction/Threshold.cxx @@ -11,6 +11,8 @@ #include #include +#include + #include #include @@ -59,6 +61,37 @@ bool DoMapField(vtkm::cont::DataSet& result, return false; } } + +template +class CombinePassFlagsWorklet : public vtkm::worklet::WorkletMapField +{ +public: + using ControlSignature = void(FieldInOut, FieldIn); + using ExecitionSignature = void(_1, _2); + + VTKM_CONT + explicit CombinePassFlagsWorklet(const Operator& combine) + : Combine(combine) + { + } + + VTKM_EXEC void operator()(bool& combined, bool incoming) const + { + combined = this->Combine(combined, incoming); + } + +private: + Operator Combine; +}; + +class ThresholdPassFlag +{ +public: + VTKM_CONT ThresholdPassFlag() = default; + + VTKM_EXEC bool operator()(bool value) const { return value; } +}; + } // end anon namespace namespace vtkm @@ -104,7 +137,7 @@ vtkm::cont::DataSet Threshold::DoExecute(const vtkm::cont::DataSet& input) return; } - if (this->ComponentMode == Component::Selected) + if (this->ComponentMode == Component::Selected || field.GetData().GetNumberOfComponents() == 1) { auto arrayComponent = field.GetData().ExtractComponent(this->SelectedComponent); @@ -113,35 +146,36 @@ vtkm::cont::DataSet Threshold::DoExecute(const vtkm::cont::DataSet& input) } else { - for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i) + vtkm::cont::ArrayHandle passFlags; + if (this->ComponentMode == Component::Any) { - auto arrayComponent = field.GetData().ExtractComponent(i); - if (this->ComponentMode == Component::Any) + auto combineWorklet = CombinePassFlagsWorklet(vtkm::LogicalOr{}); + passFlags.AllocateAndFill(field.GetNumberOfValues(), false); + for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i) { - worklet.RunIncremental(cells, - arrayComponent, - field.GetAssociation(), - predicate, - this->AllInRange, - vtkm::LogicalOr{}); + auto arrayComponent = field.GetData().ExtractComponent(i); + auto thresholded = vtkm::cont::make_ArrayHandleTransform(arrayComponent, predicate); + vtkm::cont::Invoker()(combineWorklet, passFlags, thresholded); } - else // this->ComponentMode == Component::All + } + else // this->ComponentMode == Component::All + { + auto combineWorklet = CombinePassFlagsWorklet(vtkm::LogicalAnd{}); + passFlags.AllocateAndFill(field.GetNumberOfValues(), true); + for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i) { - worklet.RunIncremental(cells, - arrayComponent, - field.GetAssociation(), - predicate, - this->AllInRange, - vtkm::LogicalAnd{}); + auto arrayComponent = field.GetData().ExtractComponent(i); + auto thresholded = vtkm::cont::make_ArrayHandleTransform(arrayComponent, predicate); + vtkm::cont::Invoker()(combineWorklet, passFlags, thresholded); } } - if (this->Invert) - { - worklet.InvertResults(); - } - - cellOut = worklet.GenerateResultCellSet(cells); + cellOut = worklet.Run(cells, + passFlags, + field.GetAssociation(), + ThresholdPassFlag{}, + this->AllInRange, + this->Invert); } };