From 07c59fcf725a313870458c086c3fe0ce2289329b Mon Sep 17 00:00:00 2001 From: Kenneth Moreland Date: Sun, 8 Sep 2019 00:09:14 -0600 Subject: [PATCH] Update filters with secondary fields to use new policy method Rather than do a CastAndCall on all possible field types when calling a worklet with two fields (where they all typically get cast to the same type as the primary field), use the new mechanism with ArrayHandleMultiplexer to create one code path. Also update the ApplyPolicy to accept the Field type, which is used to determine any additional storage types to support. --- .../apply-policy-with-single-type.md | 13 ++-- vtkm/filter/CrossProduct.hxx | 66 ++++--------------- vtkm/filter/DotProduct.hxx | 66 ++++--------------- vtkm/filter/Filter.h | 11 +++- vtkm/filter/FilterTraits.h | 1 + vtkm/filter/PolicyBase.h | 41 ++++++------ vtkm/filter/Tube.hxx | 11 ++-- vtkm/filter/WarpScalar.h | 5 ++ vtkm/filter/WarpScalar.hxx | 17 ++--- vtkm/filter/WarpVector.h | 3 + vtkm/filter/WarpVector.hxx | 8 +-- vtkm/worklet/Tube.h | 21 +++--- vtkm/worklet/testing/UnitTestTube.cxx | 11 +++- 13 files changed, 110 insertions(+), 164 deletions(-) diff --git a/docs/changelog/apply-policy-with-single-type.md b/docs/changelog/apply-policy-with-single-type.md index fc8e46fb5..0e21a9284 100644 --- a/docs/changelog/apply-policy-with-single-type.md +++ b/docs/changelog/apply-policy-with-single-type.md @@ -17,12 +17,15 @@ have to compile for this array once. This is done through a new version of `ApplyPolicy`. This version takes a type of the array as its first template argument, which must be specified. - + This requires having a list of potential storage to try. It will use that to construct an `ArrayHandleMultiplexer` containing all potential types. This list of storages comes from the policy. A `StorageList` item was added -to the policy. - +to the policy. It is also sometimes necessary for a filter to provide its +own special storage types. Thus, an `AdditionalFieldStorage` type was added +to `Filter` which is set to a `ListTag` of storage types that should be +added to those specified by the policy. + Types are automatically converted. So if you ask for a `vtkm::Float64` and field contains a `vtkm::Float32`, it will the array wrapped in an `ArrayHandleCast` to give the expected type. @@ -33,12 +36,12 @@ result is just going to follow the type of the field. ``` cpp template -inline VTKM_CONT vtkm::cont::DataSet CrossProduct::DoExecute( +inline VTKM_CONT vtkm::cont::DataSet MyFilter::DoExecute( const vtkm::cont::DataSet& inDataSet, const vtkm::cont::ArrayHandle& field, const vtkm::filter::FieldMetadata& fieldMetadata, vtkm::filter::PolicyBase policy) { vtkm::cont::CoordinateSystem coords = inDataSet.GetCoordianteSystem(); - auto coordsArray = vtkm::filter::ApplyPolicy(coords, policy); + auto coordsArray = vtkm::filter::ApplyPolicy(coords, policy, *this); ``` diff --git a/vtkm/filter/CrossProduct.hxx b/vtkm/filter/CrossProduct.hxx index 5bd261c86..0fd3ec76f 100644 --- a/vtkm/filter/CrossProduct.hxx +++ b/vtkm/filter/CrossProduct.hxx @@ -17,31 +17,6 @@ namespace vtkm namespace filter { -namespace detail -{ - -struct CrossProductFunctor -{ - vtkm::cont::Invoker& Invoke; - CrossProductFunctor(vtkm::cont::Invoker& invoke) - : Invoke(invoke) - { - } - - template - void operator()(const SecondaryFieldType& secondaryField, - const vtkm::cont::ArrayHandle, StorageType>& primaryField, - vtkm::cont::ArrayHandle>& output) const - { - this->Invoke(vtkm::worklet::CrossProduct{}, - primaryField, - vtkm::cont::make_ArrayHandleCast>(secondaryField), - output); - } -}; - -} // namespace detail - //----------------------------------------------------------------------------- inline VTKM_CONT CrossProduct::CrossProduct() : vtkm::filter::FilterField() @@ -57,40 +32,23 @@ inline VTKM_CONT CrossProduct::CrossProduct() template inline VTKM_CONT vtkm::cont::DataSet CrossProduct::DoExecute( const vtkm::cont::DataSet& inDataSet, - const vtkm::cont::ArrayHandle, StorageType>& field, + const vtkm::cont::ArrayHandle, StorageType>& primary, const vtkm::filter::FieldMetadata& fieldMetadata, vtkm::filter::PolicyBase policy) { + vtkm::cont::Field secondaryField; + if (this->UseCoordinateSystemAsSecondaryField) + { + secondaryField = inDataSet.GetCoordinateSystem(this->GetSecondaryCoordinateSystemIndex()); + } + else + { + secondaryField = inDataSet.GetField(this->SecondaryFieldName, this->SecondaryFieldAssociation); + } + auto secondary = vtkm::filter::ApplyPolicy>(secondaryField, policy, *this); - detail::CrossProductFunctor functor(this->Invoke); vtkm::cont::ArrayHandle> output; - try - { - if (this->UseCoordinateSystemAsSecondaryField) - { - vtkm::cont::CastAndCall( - inDataSet.GetCoordinateSystem(this->GetSecondaryCoordinateSystemIndex()), - functor, - field, - output); - } - else - { - using Traits = vtkm::filter::FilterTraits; - using TypeList = vtkm::ListTagBase>; - vtkm::filter::ApplyPolicy( - inDataSet.GetField(this->SecondaryFieldName, this->SecondaryFieldAssociation), - policy, - Traits()) - .ResetTypes(TypeList()) - .CastAndCall(functor, field, output); - } - } - catch (const vtkm::cont::Error&) - { - throw vtkm::cont::ErrorExecution("failed to execute."); - } - + this->Invoke(vtkm::worklet::CrossProduct{}, primary, secondary, output); return CreateResult(inDataSet, output, this->GetOutputFieldName(), fieldMetadata); } diff --git a/vtkm/filter/DotProduct.hxx b/vtkm/filter/DotProduct.hxx index 001cde97d..acf1a1c3d 100644 --- a/vtkm/filter/DotProduct.hxx +++ b/vtkm/filter/DotProduct.hxx @@ -15,31 +15,6 @@ namespace vtkm namespace filter { -namespace detail -{ - -struct DotProductFunctor -{ - vtkm::cont::Invoker& Invoke; - DotProductFunctor(vtkm::cont::Invoker& invoke) - : Invoke(invoke) - { - } - - template - void operator()(const SecondaryFieldType& secondaryField, - const vtkm::cont::ArrayHandle, StorageType>& primaryField, - vtkm::cont::ArrayHandle& output) const - { - this->Invoke(vtkm::worklet::DotProduct{}, - primaryField, - vtkm::cont::make_ArrayHandleCast>(secondaryField), - output); - } -}; - -} // namespace detail - //----------------------------------------------------------------------------- inline VTKM_CONT DotProduct::DotProduct() : vtkm::filter::FilterField() @@ -55,38 +30,23 @@ inline VTKM_CONT DotProduct::DotProduct() template inline VTKM_CONT vtkm::cont::DataSet DotProduct::DoExecute( const vtkm::cont::DataSet& inDataSet, - const vtkm::cont::ArrayHandle, StorageType>& field, + const vtkm::cont::ArrayHandle, StorageType>& primary, const vtkm::filter::FieldMetadata& fieldMetadata, vtkm::filter::PolicyBase policy) { - detail::DotProductFunctor functor(this->Invoke); + vtkm::cont::Field secondaryField; + if (this->UseCoordinateSystemAsSecondaryField) + { + secondaryField = inDataSet.GetCoordinateSystem(this->GetSecondaryCoordinateSystemIndex()); + } + else + { + secondaryField = inDataSet.GetField(this->SecondaryFieldName, this->SecondaryFieldAssociation); + } + auto secondary = vtkm::filter::ApplyPolicy>(secondaryField, policy, *this); + vtkm::cont::ArrayHandle output; - try - { - if (this->UseCoordinateSystemAsSecondaryField) - { - vtkm::cont::CastAndCall( - inDataSet.GetCoordinateSystem(this->GetSecondaryCoordinateSystemIndex()), - functor, - field, - output); - } - else - { - using Traits = vtkm::filter::FilterTraits; - using TypeList = vtkm::ListTagBase>; - vtkm::filter::ApplyPolicy( - inDataSet.GetField(this->SecondaryFieldName, this->SecondaryFieldAssociation), - policy, - Traits()) - .ResetTypes(TypeList()) - .CastAndCall(functor, field, output); - } - } - catch (const vtkm::cont::Error&) - { - throw vtkm::cont::ErrorExecution("failed to execute."); - } + this->Invoke(vtkm::worklet::DotProduct{}, primary, secondary, output); return CreateResult(inDataSet, output, this->GetOutputFieldName(), fieldMetadata); } diff --git a/vtkm/filter/Filter.h b/vtkm/filter/Filter.h index bcd2373de..a7f021097 100644 --- a/vtkm/filter/Filter.h +++ b/vtkm/filter/Filter.h @@ -179,7 +179,6 @@ public: VTKM_CONT ~Filter(); - //@{ /// \brief Specify which subset of types a filter supports. /// /// A filter is able to state what subset of types it supports @@ -187,6 +186,16 @@ public: /// filter accepts all types specified by the users provided policy using SupportedTypes = vtkm::ListTagUniversal; + /// \brief Specify which additional field storage to support. + /// + /// When a filter gets a field value from a DataSet, it has to determine what type + /// of storage the array has. Typically this is taken from the policy passed to + /// the filter's execute. In some cases it is useful to support additional types. + /// For example, the filter might make sense to support ArrayHandleIndex or + /// ArrayHandleConstant. If so, the storage of those additional types should be + /// listed here. + using AdditionalFieldStorage = vtkm::ListTagEmpty; + //@{ /// \brief Specify which fields get passed from input to output. /// diff --git a/vtkm/filter/FilterTraits.h b/vtkm/filter/FilterTraits.h index c5fd8ae00..a3714ab1f 100644 --- a/vtkm/filter/FilterTraits.h +++ b/vtkm/filter/FilterTraits.h @@ -35,6 +35,7 @@ struct FilterTraits { using InputFieldTypeList = decltype(detail::as_list(std::declval())); + using AdditionalFieldStorage = typename Filter::AdditionalFieldStorage; }; template diff --git a/vtkm/filter/PolicyBase.h b/vtkm/filter/PolicyBase.h index e2f50da34..51b6b102c 100644 --- a/vtkm/filter/PolicyBase.h +++ b/vtkm/filter/PolicyBase.h @@ -238,12 +238,19 @@ VTKM_CONT vtkm::cont::VariantArrayHandleBase -VTKM_CONT internal::ArrayHandleMultiplexerForStorageList -ApplyPolicy(const vtkm::cont::Field& field, const vtkm::filter::PolicyBase&) +template +VTKM_CONT internal::ArrayHandleMultiplexerForStorageList< + T, + vtkm::ListTagJoin::AdditionalFieldStorage, + typename DerivedPolicy::StorageList>> +ApplyPolicy(const vtkm::cont::Field& field, + vtkm::filter::PolicyBase, + const FilterType&) { - using ArrayHandleMultiplexerType = - internal::ArrayHandleMultiplexerForStorageList; + using ArrayHandleMultiplexerType = internal::ArrayHandleMultiplexerForStorageList< + T, + vtkm::ListTagJoin>; return field.GetData().AsMultiplexer(); } @@ -253,8 +260,8 @@ VTKM_CONT vtkm::cont::VariantArrayHandleBase::InputFieldTypeList>::TypeList> ApplyPolicy(const vtkm::cont::Field& field, - const vtkm::filter::PolicyBase&, - const vtkm::filter::FilterTraits&) + vtkm::filter::PolicyBase, + vtkm::filter::FilterTraits) { using FilterTypes = typename vtkm::filter::FilterTraits::InputFieldTypeList; using TypeList = @@ -266,9 +273,7 @@ ApplyPolicy(const vtkm::cont::Field& field, template VTKM_CONT vtkm::cont::VariantArrayHandleBase< typename vtkm::filter::DeduceFilterFieldTypes::TypeList> -ApplyPolicy(const vtkm::cont::Field& field, - const vtkm::filter::PolicyBase&, - const ListOfTypes&) +ApplyPolicy(const vtkm::cont::Field& field, vtkm::filter::PolicyBase, ListOfTypes) { using TypeList = typename vtkm::filter::DeduceFilterFieldTypes::TypeList; @@ -279,7 +284,7 @@ ApplyPolicy(const vtkm::cont::Field& field, template VTKM_CONT vtkm::cont::DynamicCellSetBase ApplyPolicy( const vtkm::cont::DynamicCellSet& cellset, - const vtkm::filter::PolicyBase&) + vtkm::filter::PolicyBase) { using CellSetList = typename DerivedPolicy::AllCellSetList; return cellset.ResetCellSetList(CellSetList()); @@ -289,7 +294,7 @@ VTKM_CONT vtkm::cont::DynamicCellSetBase template VTKM_CONT vtkm::cont::DynamicCellSetBase ApplyPolicyStructured(const vtkm::cont::DynamicCellSet& cellset, - const vtkm::filter::PolicyBase&) + vtkm::filter::PolicyBase) { using CellSetList = typename DerivedPolicy::StructuredCellSetList; return cellset.ResetCellSetList(CellSetList()); @@ -299,7 +304,7 @@ ApplyPolicyStructured(const vtkm::cont::DynamicCellSet& cellset, template VTKM_CONT vtkm::cont::DynamicCellSetBase ApplyPolicyUnstructured(const vtkm::cont::DynamicCellSet& cellset, - const vtkm::filter::PolicyBase&) + vtkm::filter::PolicyBase) { using CellSetList = typename DerivedPolicy::UnstructuredCellSetList; return cellset.ResetCellSetList(CellSetList()); @@ -308,15 +313,14 @@ ApplyPolicyUnstructured(const vtkm::cont::DynamicCellSet& cellset, //----------------------------------------------------------------------------- template VTKM_CONT vtkm::cont::SerializableField -MakeSerializableField(const vtkm::filter::PolicyBase&) + MakeSerializableField(vtkm::filter::PolicyBase) { return {}; } template VTKM_CONT vtkm::cont::SerializableField -MakeSerializableField(const vtkm::cont::Field& field, - const vtkm::filter::PolicyBase&) +MakeSerializableField(const vtkm::cont::Field& field, vtkm::filter::PolicyBase) { return vtkm::cont::SerializableField{ field }; } @@ -324,7 +328,7 @@ MakeSerializableField(const vtkm::cont::Field& field, template VTKM_CONT vtkm::cont::SerializableDataSet -MakeSerializableDataSet(const vtkm::filter::PolicyBase&) + MakeSerializableDataSet(vtkm::filter::PolicyBase) { return {}; } @@ -332,8 +336,7 @@ MakeSerializableDataSet(const vtkm::filter::PolicyBase&) template VTKM_CONT vtkm::cont::SerializableDataSet -MakeSerializableDataSet(const vtkm::cont::DataSet& dataset, - const vtkm::filter::PolicyBase&) +MakeSerializableDataSet(const vtkm::cont::DataSet& dataset, vtkm::filter::PolicyBase) { return vtkm::cont::SerializableDataSet{ dataset }; diff --git a/vtkm/filter/Tube.hxx b/vtkm/filter/Tube.hxx index cec66cf66..ceca1167f 100644 --- a/vtkm/filter/Tube.hxx +++ b/vtkm/filter/Tube.hxx @@ -12,6 +12,8 @@ #include #include +#include + namespace vtkm { namespace filter @@ -27,18 +29,17 @@ inline VTKM_CONT Tube::Tube() //----------------------------------------------------------------------------- template inline VTKM_CONT vtkm::cont::DataSet Tube::DoExecute(const vtkm::cont::DataSet& input, - vtkm::filter::PolicyBase) + vtkm::filter::PolicyBase policy) { this->Worklet.SetCapping(this->Capping); this->Worklet.SetNumberOfSides(this->NumberOfSides); this->Worklet.SetRadius(this->Radius); + auto originalPoints = vtkm::filter::ApplyPolicy( + input.GetCoordinateSystem(this->GetActiveCoordinateSystemIndex()), policy, *this); vtkm::cont::ArrayHandle newPoints; vtkm::cont::CellSetSingleType<> newCells; - this->Worklet.Run(input.GetCoordinateSystem(this->GetActiveCoordinateSystemIndex()), - input.GetCellSet(), - newPoints, - newCells); + this->Worklet.Run(originalPoints, input.GetCellSet(), newPoints, newCells); vtkm::cont::DataSet outData; vtkm::cont::CoordinateSystem outCoords("coordinates", newPoints); diff --git a/vtkm/filter/WarpScalar.h b/vtkm/filter/WarpScalar.h index 0ba998081..5a9808d89 100644 --- a/vtkm/filter/WarpScalar.h +++ b/vtkm/filter/WarpScalar.h @@ -32,6 +32,11 @@ public: // WarpScalar can only applies to Float and Double Vec3 arrays using SupportedTypes = vtkm::TypeListTagFieldVec3; + // WarpScalar often operates on a constant normal value + using AdditionalFieldStorage = + vtkm::ListTagBase::StorageTag, + vtkm::cont::ArrayHandleConstant::StorageTag>; + VTKM_CONT WarpScalar(vtkm::FloatDefault scaleAmount); diff --git a/vtkm/filter/WarpScalar.hxx b/vtkm/filter/WarpScalar.hxx index 124e0c1ca..1527ddd69 100644 --- a/vtkm/filter/WarpScalar.hxx +++ b/vtkm/filter/WarpScalar.hxx @@ -35,15 +35,16 @@ inline VTKM_CONT vtkm::cont::DataSet WarpScalar::DoExecute( vtkm::filter::PolicyBase policy) { using vecType = vtkm::Vec; - auto normalF = inDataSet.GetField(this->NormalFieldName, this->NormalFieldAssociation); - auto sfF = inDataSet.GetField(this->ScalarFactorFieldName, this->ScalarFactorFieldAssociation); + vtkm::cont::Field normalF = + inDataSet.GetField(this->NormalFieldName, this->NormalFieldAssociation); + vtkm::cont::Field sfF = + inDataSet.GetField(this->ScalarFactorFieldName, this->ScalarFactorFieldAssociation); vtkm::cont::ArrayHandle result; - this->Worklet.Run( - field, - vtkm::filter::ApplyPolicy(normalF, policy, vtkm::filter::FilterTraits()), - vtkm::filter::ApplyPolicy(sfF, policy, vtkm::TypeListTagFieldScalar{}), - this->ScaleAmount, - result); + this->Worklet.Run(field, + vtkm::filter::ApplyPolicy(normalF, policy, *this), + vtkm::filter::ApplyPolicy(sfF, policy, *this), + this->ScaleAmount, + result); return CreateResult(inDataSet, result, this->GetOutputFieldName(), fieldMetadata); } diff --git a/vtkm/filter/WarpVector.h b/vtkm/filter/WarpVector.h index 243e54d76..06e5bf103 100644 --- a/vtkm/filter/WarpVector.h +++ b/vtkm/filter/WarpVector.h @@ -30,6 +30,9 @@ class WarpVector : public vtkm::filter::FilterField { public: using SupportedTypes = vtkm::TypeListTagFieldVec3; + using AdditionalFieldStorage = + vtkm::ListTagBase::StorageTag, + vtkm::cont::ArrayHandleConstant::StorageTag>; VTKM_CONT WarpVector(vtkm::FloatDefault scale); diff --git a/vtkm/filter/WarpVector.hxx b/vtkm/filter/WarpVector.hxx index b86e94ce5..df6492ea5 100644 --- a/vtkm/filter/WarpVector.hxx +++ b/vtkm/filter/WarpVector.hxx @@ -33,13 +33,11 @@ inline VTKM_CONT vtkm::cont::DataSet WarpVector::DoExecute( vtkm::filter::PolicyBase policy) { using vecType = vtkm::Vec; - auto vectorF = inDataSet.GetField(this->VectorFieldName, this->VectorFieldAssociation); + vtkm::cont::Field vectorF = + inDataSet.GetField(this->VectorFieldName, this->VectorFieldAssociation); vtkm::cont::ArrayHandle result; this->Worklet.Run( - field, - vtkm::filter::ApplyPolicy(vectorF, policy, vtkm::filter::FilterTraits()), - this->Scale, - result); + field, vtkm::filter::ApplyPolicy(vectorF, policy, *this), this->Scale, result); return CreateResult(inDataSet, result, this->GetOutputFieldName(), fieldMetadata); } diff --git a/vtkm/worklet/Tube.h b/vtkm/worklet/Tube.h index ebaae84aa..faf7fee21 100644 --- a/vtkm/worklet/Tube.h +++ b/vtkm/worklet/Tube.h @@ -536,18 +536,16 @@ public: VTKM_CONT void SetRadius(vtkm::FloatDefault r) { this->Radius = r; } - VTKM_CONT - void Run(const vtkm::cont::CoordinateSystem& coords, - const vtkm::cont::DynamicCellSet& cellset, - vtkm::cont::ArrayHandle& newPoints, - vtkm::cont::CellSetSingleType<>& newCells) + template + VTKM_CONT void Run(const vtkm::cont::ArrayHandle& coords, + const vtkm::cont::DynamicCellSet& cellset, + vtkm::cont::ArrayHandle& newPoints, + vtkm::cont::CellSetSingleType<>& newCells) { - using ExplCoordsType = vtkm::cont::ArrayHandle; using NormalsType = vtkm::cont::ArrayHandle; - if (!(coords.GetData().IsType() && - (cellset.IsSameType(vtkm::cont::CellSetExplicit<>()) || - cellset.IsSameType(vtkm::cont::CellSetSingleType<>())))) + if (!cellset.IsSameType(vtkm::cont::CellSetExplicit<>()) && + !cellset.IsSameType(vtkm::cont::CellSetSingleType<>())) { throw vtkm::cont::ErrorBadValue("Tube filter only supported for polyline data."); } @@ -574,11 +572,10 @@ public: vtkm::cont::Algorithm::ScanExclusive(segPerPolyline, segOffset); //Generate normals at each point on all polylines - ExplCoordsType inCoords = coords.GetData().Cast(); NormalsType normals; normals.Allocate(totalPolylinePts); vtkm::worklet::DispatcherMapTopology genNormalsDisp; - genNormalsDisp.Invoke(cellset, inCoords, polylinePtOffset, normals); + genNormalsDisp.Invoke(cellset, coords, polylinePtOffset, normals); //Generate the tube points newPoints.Allocate(totalTubePts); @@ -586,7 +583,7 @@ public: GeneratePoints genPts(this->Capping, this->NumSides, this->Radius); vtkm::worklet::DispatcherMapTopology genPtsDisp(genPts); genPtsDisp.Invoke(cellset, - inCoords, + coords, normals, tubePointOffsets, polylinePtOffset, diff --git a/vtkm/worklet/testing/UnitTestTube.cxx b/vtkm/worklet/testing/UnitTestTube.cxx index 36e8b9206..b98e392be 100644 --- a/vtkm/worklet/testing/UnitTestTube.cxx +++ b/vtkm/worklet/testing/UnitTestTube.cxx @@ -120,7 +120,10 @@ void TestTube(bool capEnds, vtkm::FloatDefault radius, vtkm::Id numSides, vtkm:: vtkm::worklet::Tube tubeWorklet(capEnds, numSides, radius); vtkm::cont::ArrayHandle newPoints; vtkm::cont::CellSetSingleType<> newCells; - tubeWorklet.Run(ds.GetCoordinateSystem(0), ds.GetCellSet(), newPoints, newCells); + tubeWorklet.Run(ds.GetCoordinateSystem(0).GetData().Cast>(), + ds.GetCellSet(), + newPoints, + newCells); VTKM_TEST_ASSERT(newPoints.GetNumberOfValues() == reqNumPts, "Wrong number of points in Tube worklet"); @@ -175,7 +178,11 @@ void TestLinearPolylines() vtkm::worklet::Tube tubeWorklet(capEnds, numSides, radius); vtkm::cont::ArrayHandle newPoints; vtkm::cont::CellSetSingleType<> newCells; - tubeWorklet.Run(ds.GetCoordinateSystem(0), ds.GetCellSet(), newPoints, newCells); + tubeWorklet.Run( + ds.GetCoordinateSystem(0).GetData().Cast>(), + ds.GetCellSet(), + newPoints, + newCells); VTKM_TEST_ASSERT(newPoints.GetNumberOfValues() == reqNumPts, "Wrong number of points in Tube worklet");