diff --git a/vtkm/filter/entity_extraction/Threshold.cxx b/vtkm/filter/entity_extraction/Threshold.cxx index 1d3c5765a..eac4faf02 100644 --- a/vtkm/filter/entity_extraction/Threshold.cxx +++ b/vtkm/filter/entity_extraction/Threshold.cxx @@ -11,6 +11,8 @@ #include #include +#include + #include #include @@ -59,6 +61,37 @@ bool DoMapField(vtkm::cont::DataSet& result, return false; } } + +template +class CombinePassFlagsWorklet : public vtkm::worklet::WorkletMapField +{ +public: + using ControlSignature = void(FieldInOut, FieldIn); + using ExecitionSignature = void(_1, _2); + + VTKM_CONT + explicit CombinePassFlagsWorklet(const Operator& combine) + : Combine(combine) + { + } + + VTKM_EXEC void operator()(bool& combined, bool incoming) const + { + combined = this->Combine(combined, incoming); + } + +private: + Operator Combine; +}; + +class ThresholdPassFlag +{ +public: + VTKM_CONT ThresholdPassFlag() = default; + + VTKM_EXEC bool operator()(bool value) const { return value; } +}; + } // end anon namespace namespace vtkm @@ -104,7 +137,7 @@ vtkm::cont::DataSet Threshold::DoExecute(const vtkm::cont::DataSet& input) return; } - if (this->ComponentMode == Component::Selected) + if (this->ComponentMode == Component::Selected || field.GetData().GetNumberOfComponents() == 1) { auto arrayComponent = field.GetData().ExtractComponent(this->SelectedComponent); @@ -113,35 +146,36 @@ vtkm::cont::DataSet Threshold::DoExecute(const vtkm::cont::DataSet& input) } else { - for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i) + vtkm::cont::ArrayHandle passFlags; + if (this->ComponentMode == Component::Any) { - auto arrayComponent = field.GetData().ExtractComponent(i); - if (this->ComponentMode == Component::Any) + auto combineWorklet = CombinePassFlagsWorklet(vtkm::LogicalOr{}); + passFlags.AllocateAndFill(field.GetNumberOfValues(), false); + for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i) { - worklet.RunIncremental(cells, - arrayComponent, - field.GetAssociation(), - predicate, - this->AllInRange, - vtkm::LogicalOr{}); + auto arrayComponent = field.GetData().ExtractComponent(i); + auto thresholded = vtkm::cont::make_ArrayHandleTransform(arrayComponent, predicate); + vtkm::cont::Invoker()(combineWorklet, passFlags, thresholded); } - else // this->ComponentMode == Component::All + } + else // this->ComponentMode == Component::All + { + auto combineWorklet = CombinePassFlagsWorklet(vtkm::LogicalAnd{}); + passFlags.AllocateAndFill(field.GetNumberOfValues(), true); + for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i) { - worklet.RunIncremental(cells, - arrayComponent, - field.GetAssociation(), - predicate, - this->AllInRange, - vtkm::LogicalAnd{}); + auto arrayComponent = field.GetData().ExtractComponent(i); + auto thresholded = vtkm::cont::make_ArrayHandleTransform(arrayComponent, predicate); + vtkm::cont::Invoker()(combineWorklet, passFlags, thresholded); } } - if (this->Invert) - { - worklet.InvertResults(); - } - - cellOut = worklet.GenerateResultCellSet(cells); + cellOut = worklet.Run(cells, + passFlags, + field.GetAssociation(), + ThresholdPassFlag{}, + this->AllInRange, + this->Invert); } }; diff --git a/vtkm/filter/entity_extraction/testing/UnitTestThresholdFilter.cxx b/vtkm/filter/entity_extraction/testing/UnitTestThresholdFilter.cxx index 9ade1efb0..d9445e1ba 100644 --- a/vtkm/filter/entity_extraction/testing/UnitTestThresholdFilter.cxx +++ b/vtkm/filter/entity_extraction/testing/UnitTestThresholdFilter.cxx @@ -311,6 +311,27 @@ public: VTKM_TEST_ASSERT(failures == 0, "Some combinations have failed"); } + // Regression test for issue #804 + static void RegressionTest804() + { + std::cout << "Regression test for issue #804" << std::endl; + + auto input = vtkm::cont::DataSetBuilderUniform::Create(vtkm::Id2{ 4, 2 }); + static const vtkm::Vec2f pointvar[8] = { { 0.0f, 7.0f }, { 1.0f, 6.0f }, { 2.0f, 5.0f }, + { 3.0f, 4.0f }, { 4.0f, 3.0f }, { 5.0f, 2.0f }, + { 6.0f, 1.0f }, { 7.0f, 0.0f } }; + input.AddPointField("pointvar", pointvar, 8); + + vtkm::filter::entity_extraction::Threshold threshold; + threshold.SetActiveField("pointvar"); + threshold.SetAllInRange(false); + threshold.SetThresholdBelow(4.0); + threshold.SetComponentToTestToAll(); + auto output = threshold.Execute(input); + auto numOutputCells = output.GetNumberOfCells(); + VTKM_TEST_ASSERT(numOutputCells == 2, "Wrong number of cells in the output"); + } + void operator()() const { TestingThreshold::TestRegular2D(false); @@ -320,6 +341,7 @@ public: TestingThreshold::TestExplicit3D(); TestingThreshold::TestExplicit3DZeroResults(); TestingThreshold::TestAllOptions(); + TestingThreshold::RegressionTest804(); } }; } diff --git a/vtkm/filter/entity_extraction/worklet/Threshold.h b/vtkm/filter/entity_extraction/worklet/Threshold.h index f940196e4..4df0e00d7 100644 --- a/vtkm/filter/entity_extraction/worklet/Threshold.h +++ b/vtkm/filter/entity_extraction/worklet/Threshold.h @@ -79,72 +79,18 @@ public: bool AllPointsMustPass; }; - template - class CombinePassFlagsWorklet : public vtkm::worklet::WorkletMapField + template + vtkm::cont::CellSetPermutation RunImpl( + const CellSetType& cellSet, + const vtkm::cont::ArrayHandle& field, + vtkm::cont::Field::Association fieldType, + const UnaryPredicate& predicate, + bool allPointsMustPass, + bool invert) { - public: - using ControlSignature = void(FieldInOut, FieldIn); - using ExecitionSignature = void(_1, _2); + using OutputType = vtkm::cont::CellSetPermutation; - VTKM_CONT - explicit CombinePassFlagsWorklet(const Operator& combine) - : Combine(combine) - { - } - - VTKM_EXEC void operator()(bool& combined, bool incoming) const - { - combined = this->Combine(combined, incoming); - } - - private: - Operator Combine; - }; - - template - void CombinePassFlags(const vtkm::cont::ArrayHandle& passFlagsIn, const Operator& combine) - { - if (this->PassFlags.GetNumberOfValues() == 0) // Is initialization needed? - { - this->PassFlags = passFlagsIn; - } - else - { - DispatcherMapField> dispatcher( - CombinePassFlagsWorklet{ combine }); - dispatcher.Invoke(this->PassFlags, passFlagsIn); - } - this->PassFlagsModified = true; - } - - // special no-op combine operator for combining `PassFlags` results of incremental runs - struct NoOp - { - }; - - void CombinePassFlags(const vtkm::cont::ArrayHandle& passFlagsIn, NoOp) - { - this->PassFlags = passFlagsIn; - this->PassFlagsModified = true; - } - - /// Incrementally run the worklet on the given parameters. Each run should get the - /// same `cellSet`. An array of pass/fail flags is maintained internally. The `passFlagsCombine` - /// operator is used to combine the current result to the incremental results. Finally, use - /// `GenerateResultCellSet` to get the thresholded cellset. - template - void RunIncremental(const vtkm::cont::UnknownCellSet& cellSet, - const vtkm::cont::ArrayHandle& field, - vtkm::cont::Field::Association fieldType, - const UnaryPredicate& predicate, - bool allPointsMustPass, // only considered when field association is `Points` - const PassFlagsCombineOp& passFlagsCombineOp) - { vtkm::cont::ArrayHandle passFlags; - switch (fieldType) { case vtkm::cont::Field::Association::Points: @@ -166,40 +112,16 @@ public: throw vtkm::cont::ErrorBadValue("Expecting point or cell field."); } - this->CombinePassFlags(passFlags, passFlagsCombineOp); - } - - vtkm::cont::ArrayHandle GetValidCellIds() const - { - if (this->PassFlagsModified) + if (invert) { - vtkm::cont::Algorithm::CopyIf( - vtkm::cont::ArrayHandleIndex(this->PassFlags.GetNumberOfValues()), - this->PassFlags, - this->ValidCellIds); - this->PassFlagsModified = false; + vtkm::cont::Algorithm::Copy( + vtkm::cont::make_ArrayHandleTransform(passFlags, vtkm::LogicalNot{}), passFlags); } - return this->ValidCellIds; - } - vtkm::cont::UnknownCellSet GenerateResultCellSet(const vtkm::cont::UnknownCellSet& cellSet) - { - vtkm::cont::UnknownCellSet output; + vtkm::cont::Algorithm::CopyIf( + vtkm::cont::ArrayHandleIndex(passFlags.GetNumberOfValues()), passFlags, this->ValidCellIds); - CastAndCall(cellSet, [&](auto concrete) { - output = vtkm::worklet::CellDeepCopy::Run( - vtkm::cont::make_CellSetPermutation(this->GetValidCellIds(), concrete)); - }); - - return output; - } - - // Invert the results stored in this worklet's state - void InvertResults() - { - vtkm::cont::Algorithm::Copy( - vtkm::cont::make_ArrayHandleTransform(this->PassFlags, vtkm::LogicalNot{}), this->PassFlags); - this->PassFlagsModified = true; + return OutputType(this->ValidCellIds, cellSet); } template @@ -211,19 +133,18 @@ public: bool allPointsMustPass = false, // only considered when field association is `Points` bool invert = false) { - this->RunIncremental(cellSet, field, fieldType, predicate, allPointsMustPass, NoOp{}); - if (invert) - { - this->InvertResults(); - } - return this->GenerateResultCellSet(cellSet); + vtkm::cont::UnknownCellSet output; + CastAndCall(cellSet, [&](auto concrete) { + output = vtkm::worklet::CellDeepCopy::Run( + this->RunImpl(concrete, field, fieldType, predicate, allPointsMustPass, invert)); + }); + return output; } -private: - vtkm::cont::ArrayHandle PassFlags; + vtkm::cont::ArrayHandle GetValidCellIds() const { return this->ValidCellIds; } - mutable bool PassFlagsModified = true; - mutable vtkm::cont::ArrayHandle ValidCellIds; +private: + vtkm::cont::ArrayHandle ValidCellIds; }; } } // namespace vtkm::worklet