diff --git a/vtkm/filter/entity_extraction/worklet/Threshold.h b/vtkm/filter/entity_extraction/worklet/Threshold.h index f940196e4..4df0e00d7 100644 --- a/vtkm/filter/entity_extraction/worklet/Threshold.h +++ b/vtkm/filter/entity_extraction/worklet/Threshold.h @@ -79,72 +79,18 @@ public: bool AllPointsMustPass; }; - template - class CombinePassFlagsWorklet : public vtkm::worklet::WorkletMapField + template + vtkm::cont::CellSetPermutation RunImpl( + const CellSetType& cellSet, + const vtkm::cont::ArrayHandle& field, + vtkm::cont::Field::Association fieldType, + const UnaryPredicate& predicate, + bool allPointsMustPass, + bool invert) { - public: - using ControlSignature = void(FieldInOut, FieldIn); - using ExecitionSignature = void(_1, _2); + using OutputType = vtkm::cont::CellSetPermutation; - VTKM_CONT - explicit CombinePassFlagsWorklet(const Operator& combine) - : Combine(combine) - { - } - - VTKM_EXEC void operator()(bool& combined, bool incoming) const - { - combined = this->Combine(combined, incoming); - } - - private: - Operator Combine; - }; - - template - void CombinePassFlags(const vtkm::cont::ArrayHandle& passFlagsIn, const Operator& combine) - { - if (this->PassFlags.GetNumberOfValues() == 0) // Is initialization needed? - { - this->PassFlags = passFlagsIn; - } - else - { - DispatcherMapField> dispatcher( - CombinePassFlagsWorklet{ combine }); - dispatcher.Invoke(this->PassFlags, passFlagsIn); - } - this->PassFlagsModified = true; - } - - // special no-op combine operator for combining `PassFlags` results of incremental runs - struct NoOp - { - }; - - void CombinePassFlags(const vtkm::cont::ArrayHandle& passFlagsIn, NoOp) - { - this->PassFlags = passFlagsIn; - this->PassFlagsModified = true; - } - - /// Incrementally run the worklet on the given parameters. Each run should get the - /// same `cellSet`. An array of pass/fail flags is maintained internally. The `passFlagsCombine` - /// operator is used to combine the current result to the incremental results. Finally, use - /// `GenerateResultCellSet` to get the thresholded cellset. - template - void RunIncremental(const vtkm::cont::UnknownCellSet& cellSet, - const vtkm::cont::ArrayHandle& field, - vtkm::cont::Field::Association fieldType, - const UnaryPredicate& predicate, - bool allPointsMustPass, // only considered when field association is `Points` - const PassFlagsCombineOp& passFlagsCombineOp) - { vtkm::cont::ArrayHandle passFlags; - switch (fieldType) { case vtkm::cont::Field::Association::Points: @@ -166,40 +112,16 @@ public: throw vtkm::cont::ErrorBadValue("Expecting point or cell field."); } - this->CombinePassFlags(passFlags, passFlagsCombineOp); - } - - vtkm::cont::ArrayHandle GetValidCellIds() const - { - if (this->PassFlagsModified) + if (invert) { - vtkm::cont::Algorithm::CopyIf( - vtkm::cont::ArrayHandleIndex(this->PassFlags.GetNumberOfValues()), - this->PassFlags, - this->ValidCellIds); - this->PassFlagsModified = false; + vtkm::cont::Algorithm::Copy( + vtkm::cont::make_ArrayHandleTransform(passFlags, vtkm::LogicalNot{}), passFlags); } - return this->ValidCellIds; - } - vtkm::cont::UnknownCellSet GenerateResultCellSet(const vtkm::cont::UnknownCellSet& cellSet) - { - vtkm::cont::UnknownCellSet output; + vtkm::cont::Algorithm::CopyIf( + vtkm::cont::ArrayHandleIndex(passFlags.GetNumberOfValues()), passFlags, this->ValidCellIds); - CastAndCall(cellSet, [&](auto concrete) { - output = vtkm::worklet::CellDeepCopy::Run( - vtkm::cont::make_CellSetPermutation(this->GetValidCellIds(), concrete)); - }); - - return output; - } - - // Invert the results stored in this worklet's state - void InvertResults() - { - vtkm::cont::Algorithm::Copy( - vtkm::cont::make_ArrayHandleTransform(this->PassFlags, vtkm::LogicalNot{}), this->PassFlags); - this->PassFlagsModified = true; + return OutputType(this->ValidCellIds, cellSet); } template @@ -211,19 +133,18 @@ public: bool allPointsMustPass = false, // only considered when field association is `Points` bool invert = false) { - this->RunIncremental(cellSet, field, fieldType, predicate, allPointsMustPass, NoOp{}); - if (invert) - { - this->InvertResults(); - } - return this->GenerateResultCellSet(cellSet); + vtkm::cont::UnknownCellSet output; + CastAndCall(cellSet, [&](auto concrete) { + output = vtkm::worklet::CellDeepCopy::Run( + this->RunImpl(concrete, field, fieldType, predicate, allPointsMustPass, invert)); + }); + return output; } -private: - vtkm::cont::ArrayHandle PassFlags; + vtkm::cont::ArrayHandle GetValidCellIds() const { return this->ValidCellIds; } - mutable bool PassFlagsModified = true; - mutable vtkm::cont::ArrayHandle ValidCellIds; +private: + vtkm::cont::ArrayHandle ValidCellIds; }; } } // namespace vtkm::worklet