mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-10-05 01:49:02 +00:00
Merge topic 'fix-threshold-all-comps' into release-2.1
d825d2450 Add a regression test for issue #804 e0c5500a2 Simplify threshold worklet 6e1de4fa8 Fix threshold for any-point-all-components case Acked-by: Kitware Robot <kwrobot@kitware.com> Acked-by: Kenneth Moreland <morelandkd@ornl.gov> Merge-request: !3143
This commit is contained in:
commit
709961393f
@ -11,6 +11,8 @@
|
||||
#include <vtkm/filter/entity_extraction/Threshold.h>
|
||||
#include <vtkm/filter/entity_extraction/worklet/Threshold.h>
|
||||
|
||||
#include <vtkm/cont/Invoker.h>
|
||||
|
||||
#include <vtkm/BinaryPredicates.h>
|
||||
#include <vtkm/Math.h>
|
||||
|
||||
@ -59,6 +61,37 @@ bool DoMapField(vtkm::cont::DataSet& result,
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Operator>
|
||||
class CombinePassFlagsWorklet : public vtkm::worklet::WorkletMapField
|
||||
{
|
||||
public:
|
||||
using ControlSignature = void(FieldInOut, FieldIn);
|
||||
using ExecitionSignature = void(_1, _2);
|
||||
|
||||
VTKM_CONT
|
||||
explicit CombinePassFlagsWorklet(const Operator& combine)
|
||||
: Combine(combine)
|
||||
{
|
||||
}
|
||||
|
||||
VTKM_EXEC void operator()(bool& combined, bool incoming) const
|
||||
{
|
||||
combined = this->Combine(combined, incoming);
|
||||
}
|
||||
|
||||
private:
|
||||
Operator Combine;
|
||||
};
|
||||
|
||||
class ThresholdPassFlag
|
||||
{
|
||||
public:
|
||||
VTKM_CONT ThresholdPassFlag() = default;
|
||||
|
||||
VTKM_EXEC bool operator()(bool value) const { return value; }
|
||||
};
|
||||
|
||||
} // end anon namespace
|
||||
|
||||
namespace vtkm
|
||||
@ -104,7 +137,7 @@ vtkm::cont::DataSet Threshold::DoExecute(const vtkm::cont::DataSet& input)
|
||||
return;
|
||||
}
|
||||
|
||||
if (this->ComponentMode == Component::Selected)
|
||||
if (this->ComponentMode == Component::Selected || field.GetData().GetNumberOfComponents() == 1)
|
||||
{
|
||||
auto arrayComponent =
|
||||
field.GetData().ExtractComponent<ComponentType>(this->SelectedComponent);
|
||||
@ -113,35 +146,36 @@ vtkm::cont::DataSet Threshold::DoExecute(const vtkm::cont::DataSet& input)
|
||||
}
|
||||
else
|
||||
{
|
||||
for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i)
|
||||
vtkm::cont::ArrayHandle<bool> passFlags;
|
||||
if (this->ComponentMode == Component::Any)
|
||||
{
|
||||
auto arrayComponent = field.GetData().ExtractComponent<ComponentType>(i);
|
||||
if (this->ComponentMode == Component::Any)
|
||||
auto combineWorklet = CombinePassFlagsWorklet<vtkm::LogicalOr>(vtkm::LogicalOr{});
|
||||
passFlags.AllocateAndFill(field.GetNumberOfValues(), false);
|
||||
for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i)
|
||||
{
|
||||
worklet.RunIncremental(cells,
|
||||
arrayComponent,
|
||||
field.GetAssociation(),
|
||||
predicate,
|
||||
this->AllInRange,
|
||||
vtkm::LogicalOr{});
|
||||
auto arrayComponent = field.GetData().ExtractComponent<ComponentType>(i);
|
||||
auto thresholded = vtkm::cont::make_ArrayHandleTransform(arrayComponent, predicate);
|
||||
vtkm::cont::Invoker()(combineWorklet, passFlags, thresholded);
|
||||
}
|
||||
else // this->ComponentMode == Component::All
|
||||
}
|
||||
else // this->ComponentMode == Component::All
|
||||
{
|
||||
auto combineWorklet = CombinePassFlagsWorklet<vtkm::LogicalAnd>(vtkm::LogicalAnd{});
|
||||
passFlags.AllocateAndFill(field.GetNumberOfValues(), true);
|
||||
for (vtkm::IdComponent i = 0; i < field.GetData().GetNumberOfComponents(); ++i)
|
||||
{
|
||||
worklet.RunIncremental(cells,
|
||||
arrayComponent,
|
||||
field.GetAssociation(),
|
||||
predicate,
|
||||
this->AllInRange,
|
||||
vtkm::LogicalAnd{});
|
||||
auto arrayComponent = field.GetData().ExtractComponent<ComponentType>(i);
|
||||
auto thresholded = vtkm::cont::make_ArrayHandleTransform(arrayComponent, predicate);
|
||||
vtkm::cont::Invoker()(combineWorklet, passFlags, thresholded);
|
||||
}
|
||||
}
|
||||
|
||||
if (this->Invert)
|
||||
{
|
||||
worklet.InvertResults();
|
||||
}
|
||||
|
||||
cellOut = worklet.GenerateResultCellSet(cells);
|
||||
cellOut = worklet.Run(cells,
|
||||
passFlags,
|
||||
field.GetAssociation(),
|
||||
ThresholdPassFlag{},
|
||||
this->AllInRange,
|
||||
this->Invert);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -311,6 +311,27 @@ public:
|
||||
VTKM_TEST_ASSERT(failures == 0, "Some combinations have failed");
|
||||
}
|
||||
|
||||
// Regression test for issue #804
|
||||
static void RegressionTest804()
|
||||
{
|
||||
std::cout << "Regression test for issue #804" << std::endl;
|
||||
|
||||
auto input = vtkm::cont::DataSetBuilderUniform::Create(vtkm::Id2{ 4, 2 });
|
||||
static const vtkm::Vec2f pointvar[8] = { { 0.0f, 7.0f }, { 1.0f, 6.0f }, { 2.0f, 5.0f },
|
||||
{ 3.0f, 4.0f }, { 4.0f, 3.0f }, { 5.0f, 2.0f },
|
||||
{ 6.0f, 1.0f }, { 7.0f, 0.0f } };
|
||||
input.AddPointField("pointvar", pointvar, 8);
|
||||
|
||||
vtkm::filter::entity_extraction::Threshold threshold;
|
||||
threshold.SetActiveField("pointvar");
|
||||
threshold.SetAllInRange(false);
|
||||
threshold.SetThresholdBelow(4.0);
|
||||
threshold.SetComponentToTestToAll();
|
||||
auto output = threshold.Execute(input);
|
||||
auto numOutputCells = output.GetNumberOfCells();
|
||||
VTKM_TEST_ASSERT(numOutputCells == 2, "Wrong number of cells in the output");
|
||||
}
|
||||
|
||||
void operator()() const
|
||||
{
|
||||
TestingThreshold::TestRegular2D(false);
|
||||
@ -320,6 +341,7 @@ public:
|
||||
TestingThreshold::TestExplicit3D();
|
||||
TestingThreshold::TestExplicit3DZeroResults();
|
||||
TestingThreshold::TestAllOptions();
|
||||
TestingThreshold::RegressionTest804();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -79,72 +79,18 @@ public:
|
||||
bool AllPointsMustPass;
|
||||
};
|
||||
|
||||
template <typename Operator>
|
||||
class CombinePassFlagsWorklet : public vtkm::worklet::WorkletMapField
|
||||
template <typename CellSetType, typename ValueType, typename StorageType, typename UnaryPredicate>
|
||||
vtkm::cont::CellSetPermutation<CellSetType> RunImpl(
|
||||
const CellSetType& cellSet,
|
||||
const vtkm::cont::ArrayHandle<ValueType, StorageType>& field,
|
||||
vtkm::cont::Field::Association fieldType,
|
||||
const UnaryPredicate& predicate,
|
||||
bool allPointsMustPass,
|
||||
bool invert)
|
||||
{
|
||||
public:
|
||||
using ControlSignature = void(FieldInOut, FieldIn);
|
||||
using ExecitionSignature = void(_1, _2);
|
||||
using OutputType = vtkm::cont::CellSetPermutation<CellSetType>;
|
||||
|
||||
VTKM_CONT
|
||||
explicit CombinePassFlagsWorklet(const Operator& combine)
|
||||
: Combine(combine)
|
||||
{
|
||||
}
|
||||
|
||||
VTKM_EXEC void operator()(bool& combined, bool incoming) const
|
||||
{
|
||||
combined = this->Combine(combined, incoming);
|
||||
}
|
||||
|
||||
private:
|
||||
Operator Combine;
|
||||
};
|
||||
|
||||
template <typename Operator>
|
||||
void CombinePassFlags(const vtkm::cont::ArrayHandle<bool>& passFlagsIn, const Operator& combine)
|
||||
{
|
||||
if (this->PassFlags.GetNumberOfValues() == 0) // Is initialization needed?
|
||||
{
|
||||
this->PassFlags = passFlagsIn;
|
||||
}
|
||||
else
|
||||
{
|
||||
DispatcherMapField<CombinePassFlagsWorklet<Operator>> dispatcher(
|
||||
CombinePassFlagsWorklet<Operator>{ combine });
|
||||
dispatcher.Invoke(this->PassFlags, passFlagsIn);
|
||||
}
|
||||
this->PassFlagsModified = true;
|
||||
}
|
||||
|
||||
// special no-op combine operator for combining `PassFlags` results of incremental runs
|
||||
struct NoOp
|
||||
{
|
||||
};
|
||||
|
||||
void CombinePassFlags(const vtkm::cont::ArrayHandle<bool>& passFlagsIn, NoOp)
|
||||
{
|
||||
this->PassFlags = passFlagsIn;
|
||||
this->PassFlagsModified = true;
|
||||
}
|
||||
|
||||
/// Incrementally run the worklet on the given parameters. Each run should get the
|
||||
/// same `cellSet`. An array of pass/fail flags is maintained internally. The `passFlagsCombine`
|
||||
/// operator is used to combine the current result to the incremental results. Finally, use
|
||||
/// `GenerateResultCellSet` to get the thresholded cellset.
|
||||
template <typename ValueType,
|
||||
typename StorageType,
|
||||
typename UnaryPredicate,
|
||||
typename PassFlagsCombineOp>
|
||||
void RunIncremental(const vtkm::cont::UnknownCellSet& cellSet,
|
||||
const vtkm::cont::ArrayHandle<ValueType, StorageType>& field,
|
||||
vtkm::cont::Field::Association fieldType,
|
||||
const UnaryPredicate& predicate,
|
||||
bool allPointsMustPass, // only considered when field association is `Points`
|
||||
const PassFlagsCombineOp& passFlagsCombineOp)
|
||||
{
|
||||
vtkm::cont::ArrayHandle<bool> passFlags;
|
||||
|
||||
switch (fieldType)
|
||||
{
|
||||
case vtkm::cont::Field::Association::Points:
|
||||
@ -166,40 +112,16 @@ public:
|
||||
throw vtkm::cont::ErrorBadValue("Expecting point or cell field.");
|
||||
}
|
||||
|
||||
this->CombinePassFlags(passFlags, passFlagsCombineOp);
|
||||
}
|
||||
|
||||
vtkm::cont::ArrayHandle<vtkm::Id> GetValidCellIds() const
|
||||
{
|
||||
if (this->PassFlagsModified)
|
||||
if (invert)
|
||||
{
|
||||
vtkm::cont::Algorithm::CopyIf(
|
||||
vtkm::cont::ArrayHandleIndex(this->PassFlags.GetNumberOfValues()),
|
||||
this->PassFlags,
|
||||
this->ValidCellIds);
|
||||
this->PassFlagsModified = false;
|
||||
vtkm::cont::Algorithm::Copy(
|
||||
vtkm::cont::make_ArrayHandleTransform(passFlags, vtkm::LogicalNot{}), passFlags);
|
||||
}
|
||||
return this->ValidCellIds;
|
||||
}
|
||||
|
||||
vtkm::cont::UnknownCellSet GenerateResultCellSet(const vtkm::cont::UnknownCellSet& cellSet)
|
||||
{
|
||||
vtkm::cont::UnknownCellSet output;
|
||||
vtkm::cont::Algorithm::CopyIf(
|
||||
vtkm::cont::ArrayHandleIndex(passFlags.GetNumberOfValues()), passFlags, this->ValidCellIds);
|
||||
|
||||
CastAndCall(cellSet, [&](auto concrete) {
|
||||
output = vtkm::worklet::CellDeepCopy::Run(
|
||||
vtkm::cont::make_CellSetPermutation(this->GetValidCellIds(), concrete));
|
||||
});
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
// Invert the results stored in this worklet's state
|
||||
void InvertResults()
|
||||
{
|
||||
vtkm::cont::Algorithm::Copy(
|
||||
vtkm::cont::make_ArrayHandleTransform(this->PassFlags, vtkm::LogicalNot{}), this->PassFlags);
|
||||
this->PassFlagsModified = true;
|
||||
return OutputType(this->ValidCellIds, cellSet);
|
||||
}
|
||||
|
||||
template <typename ValueType, typename StorageType, typename UnaryPredicate>
|
||||
@ -211,19 +133,18 @@ public:
|
||||
bool allPointsMustPass = false, // only considered when field association is `Points`
|
||||
bool invert = false)
|
||||
{
|
||||
this->RunIncremental(cellSet, field, fieldType, predicate, allPointsMustPass, NoOp{});
|
||||
if (invert)
|
||||
{
|
||||
this->InvertResults();
|
||||
}
|
||||
return this->GenerateResultCellSet(cellSet);
|
||||
vtkm::cont::UnknownCellSet output;
|
||||
CastAndCall(cellSet, [&](auto concrete) {
|
||||
output = vtkm::worklet::CellDeepCopy::Run(
|
||||
this->RunImpl(concrete, field, fieldType, predicate, allPointsMustPass, invert));
|
||||
});
|
||||
return output;
|
||||
}
|
||||
|
||||
private:
|
||||
vtkm::cont::ArrayHandle<bool> PassFlags;
|
||||
vtkm::cont::ArrayHandle<vtkm::Id> GetValidCellIds() const { return this->ValidCellIds; }
|
||||
|
||||
mutable bool PassFlagsModified = true;
|
||||
mutable vtkm::cont::ArrayHandle<vtkm::Id> ValidCellIds;
|
||||
private:
|
||||
vtkm::cont::ArrayHandle<vtkm::Id> ValidCellIds;
|
||||
};
|
||||
}
|
||||
} // namespace vtkm::worklet
|
||||
|
Loading…
Reference in New Issue
Block a user