diff --git a/docs/changelog/apply-policy-with-single-type.md b/docs/changelog/apply-policy-with-single-type.md index fc8e46fb5..0e21a9284 100644 --- a/docs/changelog/apply-policy-with-single-type.md +++ b/docs/changelog/apply-policy-with-single-type.md @@ -17,12 +17,15 @@ have to compile for this array once. This is done through a new version of `ApplyPolicy`. This version takes a type of the array as its first template argument, which must be specified. - + This requires having a list of potential storage to try. It will use that to construct an `ArrayHandleMultiplexer` containing all potential types. This list of storages comes from the policy. A `StorageList` item was added -to the policy. - +to the policy. It is also sometimes necessary for a filter to provide its +own special storage types. Thus, an `AdditionalFieldStorage` type was added +to `Filter` which is set to a `ListTag` of storage types that should be +added to those specified by the policy. + Types are automatically converted. So if you ask for a `vtkm::Float64` and field contains a `vtkm::Float32`, it will the array wrapped in an `ArrayHandleCast` to give the expected type. @@ -33,12 +36,12 @@ result is just going to follow the type of the field. ``` cpp template -inline VTKM_CONT vtkm::cont::DataSet CrossProduct::DoExecute( +inline VTKM_CONT vtkm::cont::DataSet MyFilter::DoExecute( const vtkm::cont::DataSet& inDataSet, const vtkm::cont::ArrayHandle& field, const vtkm::filter::FieldMetadata& fieldMetadata, vtkm::filter::PolicyBase policy) { vtkm::cont::CoordinateSystem coords = inDataSet.GetCoordianteSystem(); - auto coordsArray = vtkm::filter::ApplyPolicy(coords, policy); + auto coordsArray = vtkm::filter::ApplyPolicy(coords, policy, *this); ``` diff --git a/vtkm/filter/CrossProduct.hxx b/vtkm/filter/CrossProduct.hxx index 5bd261c86..0fd3ec76f 100644 --- a/vtkm/filter/CrossProduct.hxx +++ b/vtkm/filter/CrossProduct.hxx @@ -17,31 +17,6 @@ namespace vtkm namespace filter { -namespace detail -{ - -struct CrossProductFunctor -{ - vtkm::cont::Invoker& Invoke; - CrossProductFunctor(vtkm::cont::Invoker& invoke) - : Invoke(invoke) - { - } - - template - void operator()(const SecondaryFieldType& secondaryField, - const vtkm::cont::ArrayHandle, StorageType>& primaryField, - vtkm::cont::ArrayHandle>& output) const - { - this->Invoke(vtkm::worklet::CrossProduct{}, - primaryField, - vtkm::cont::make_ArrayHandleCast>(secondaryField), - output); - } -}; - -} // namespace detail - //----------------------------------------------------------------------------- inline VTKM_CONT CrossProduct::CrossProduct() : vtkm::filter::FilterField() @@ -57,40 +32,23 @@ inline VTKM_CONT CrossProduct::CrossProduct() template inline VTKM_CONT vtkm::cont::DataSet CrossProduct::DoExecute( const vtkm::cont::DataSet& inDataSet, - const vtkm::cont::ArrayHandle, StorageType>& field, + const vtkm::cont::ArrayHandle, StorageType>& primary, const vtkm::filter::FieldMetadata& fieldMetadata, vtkm::filter::PolicyBase policy) { + vtkm::cont::Field secondaryField; + if (this->UseCoordinateSystemAsSecondaryField) + { + secondaryField = inDataSet.GetCoordinateSystem(this->GetSecondaryCoordinateSystemIndex()); + } + else + { + secondaryField = inDataSet.GetField(this->SecondaryFieldName, this->SecondaryFieldAssociation); + } + auto secondary = vtkm::filter::ApplyPolicy>(secondaryField, policy, *this); - detail::CrossProductFunctor functor(this->Invoke); vtkm::cont::ArrayHandle> output; - try - { - if (this->UseCoordinateSystemAsSecondaryField) - { - vtkm::cont::CastAndCall( - inDataSet.GetCoordinateSystem(this->GetSecondaryCoordinateSystemIndex()), - functor, - field, - output); - } - else - { - using Traits = vtkm::filter::FilterTraits; - using TypeList = vtkm::ListTagBase>; - vtkm::filter::ApplyPolicy( - inDataSet.GetField(this->SecondaryFieldName, this->SecondaryFieldAssociation), - policy, - Traits()) - .ResetTypes(TypeList()) - .CastAndCall(functor, field, output); - } - } - catch (const vtkm::cont::Error&) - { - throw vtkm::cont::ErrorExecution("failed to execute."); - } - + this->Invoke(vtkm::worklet::CrossProduct{}, primary, secondary, output); return CreateResult(inDataSet, output, this->GetOutputFieldName(), fieldMetadata); } diff --git a/vtkm/filter/DotProduct.hxx b/vtkm/filter/DotProduct.hxx index 001cde97d..acf1a1c3d 100644 --- a/vtkm/filter/DotProduct.hxx +++ b/vtkm/filter/DotProduct.hxx @@ -15,31 +15,6 @@ namespace vtkm namespace filter { -namespace detail -{ - -struct DotProductFunctor -{ - vtkm::cont::Invoker& Invoke; - DotProductFunctor(vtkm::cont::Invoker& invoke) - : Invoke(invoke) - { - } - - template - void operator()(const SecondaryFieldType& secondaryField, - const vtkm::cont::ArrayHandle, StorageType>& primaryField, - vtkm::cont::ArrayHandle& output) const - { - this->Invoke(vtkm::worklet::DotProduct{}, - primaryField, - vtkm::cont::make_ArrayHandleCast>(secondaryField), - output); - } -}; - -} // namespace detail - //----------------------------------------------------------------------------- inline VTKM_CONT DotProduct::DotProduct() : vtkm::filter::FilterField() @@ -55,38 +30,23 @@ inline VTKM_CONT DotProduct::DotProduct() template inline VTKM_CONT vtkm::cont::DataSet DotProduct::DoExecute( const vtkm::cont::DataSet& inDataSet, - const vtkm::cont::ArrayHandle, StorageType>& field, + const vtkm::cont::ArrayHandle, StorageType>& primary, const vtkm::filter::FieldMetadata& fieldMetadata, vtkm::filter::PolicyBase policy) { - detail::DotProductFunctor functor(this->Invoke); + vtkm::cont::Field secondaryField; + if (this->UseCoordinateSystemAsSecondaryField) + { + secondaryField = inDataSet.GetCoordinateSystem(this->GetSecondaryCoordinateSystemIndex()); + } + else + { + secondaryField = inDataSet.GetField(this->SecondaryFieldName, this->SecondaryFieldAssociation); + } + auto secondary = vtkm::filter::ApplyPolicy>(secondaryField, policy, *this); + vtkm::cont::ArrayHandle output; - try - { - if (this->UseCoordinateSystemAsSecondaryField) - { - vtkm::cont::CastAndCall( - inDataSet.GetCoordinateSystem(this->GetSecondaryCoordinateSystemIndex()), - functor, - field, - output); - } - else - { - using Traits = vtkm::filter::FilterTraits; - using TypeList = vtkm::ListTagBase>; - vtkm::filter::ApplyPolicy( - inDataSet.GetField(this->SecondaryFieldName, this->SecondaryFieldAssociation), - policy, - Traits()) - .ResetTypes(TypeList()) - .CastAndCall(functor, field, output); - } - } - catch (const vtkm::cont::Error&) - { - throw vtkm::cont::ErrorExecution("failed to execute."); - } + this->Invoke(vtkm::worklet::DotProduct{}, primary, secondary, output); return CreateResult(inDataSet, output, this->GetOutputFieldName(), fieldMetadata); } diff --git a/vtkm/filter/Filter.h b/vtkm/filter/Filter.h index bcd2373de..a7f021097 100644 --- a/vtkm/filter/Filter.h +++ b/vtkm/filter/Filter.h @@ -179,7 +179,6 @@ public: VTKM_CONT ~Filter(); - //@{ /// \brief Specify which subset of types a filter supports. /// /// A filter is able to state what subset of types it supports @@ -187,6 +186,16 @@ public: /// filter accepts all types specified by the users provided policy using SupportedTypes = vtkm::ListTagUniversal; + /// \brief Specify which additional field storage to support. + /// + /// When a filter gets a field value from a DataSet, it has to determine what type + /// of storage the array has. Typically this is taken from the policy passed to + /// the filter's execute. In some cases it is useful to support additional types. + /// For example, the filter might make sense to support ArrayHandleIndex or + /// ArrayHandleConstant. If so, the storage of those additional types should be + /// listed here. + using AdditionalFieldStorage = vtkm::ListTagEmpty; + //@{ /// \brief Specify which fields get passed from input to output. /// diff --git a/vtkm/filter/FilterTraits.h b/vtkm/filter/FilterTraits.h index c5fd8ae00..a3714ab1f 100644 --- a/vtkm/filter/FilterTraits.h +++ b/vtkm/filter/FilterTraits.h @@ -35,6 +35,7 @@ struct FilterTraits { using InputFieldTypeList = decltype(detail::as_list(std::declval())); + using AdditionalFieldStorage = typename Filter::AdditionalFieldStorage; }; template diff --git a/vtkm/filter/PolicyBase.h b/vtkm/filter/PolicyBase.h index e2f50da34..51b6b102c 100644 --- a/vtkm/filter/PolicyBase.h +++ b/vtkm/filter/PolicyBase.h @@ -238,12 +238,19 @@ VTKM_CONT vtkm::cont::VariantArrayHandleBase -VTKM_CONT internal::ArrayHandleMultiplexerForStorageList -ApplyPolicy(const vtkm::cont::Field& field, const vtkm::filter::PolicyBase&) +template +VTKM_CONT internal::ArrayHandleMultiplexerForStorageList< + T, + vtkm::ListTagJoin::AdditionalFieldStorage, + typename DerivedPolicy::StorageList>> +ApplyPolicy(const vtkm::cont::Field& field, + vtkm::filter::PolicyBase, + const FilterType&) { - using ArrayHandleMultiplexerType = - internal::ArrayHandleMultiplexerForStorageList; + using ArrayHandleMultiplexerType = internal::ArrayHandleMultiplexerForStorageList< + T, + vtkm::ListTagJoin>; return field.GetData().AsMultiplexer(); } @@ -253,8 +260,8 @@ VTKM_CONT vtkm::cont::VariantArrayHandleBase::InputFieldTypeList>::TypeList> ApplyPolicy(const vtkm::cont::Field& field, - const vtkm::filter::PolicyBase&, - const vtkm::filter::FilterTraits&) + vtkm::filter::PolicyBase, + vtkm::filter::FilterTraits) { using FilterTypes = typename vtkm::filter::FilterTraits::InputFieldTypeList; using TypeList = @@ -266,9 +273,7 @@ ApplyPolicy(const vtkm::cont::Field& field, template VTKM_CONT vtkm::cont::VariantArrayHandleBase< typename vtkm::filter::DeduceFilterFieldTypes::TypeList> -ApplyPolicy(const vtkm::cont::Field& field, - const vtkm::filter::PolicyBase&, - const ListOfTypes&) +ApplyPolicy(const vtkm::cont::Field& field, vtkm::filter::PolicyBase, ListOfTypes) { using TypeList = typename vtkm::filter::DeduceFilterFieldTypes::TypeList; @@ -279,7 +284,7 @@ ApplyPolicy(const vtkm::cont::Field& field, template VTKM_CONT vtkm::cont::DynamicCellSetBase ApplyPolicy( const vtkm::cont::DynamicCellSet& cellset, - const vtkm::filter::PolicyBase&) + vtkm::filter::PolicyBase) { using CellSetList = typename DerivedPolicy::AllCellSetList; return cellset.ResetCellSetList(CellSetList()); @@ -289,7 +294,7 @@ VTKM_CONT vtkm::cont::DynamicCellSetBase template VTKM_CONT vtkm::cont::DynamicCellSetBase ApplyPolicyStructured(const vtkm::cont::DynamicCellSet& cellset, - const vtkm::filter::PolicyBase&) + vtkm::filter::PolicyBase) { using CellSetList = typename DerivedPolicy::StructuredCellSetList; return cellset.ResetCellSetList(CellSetList()); @@ -299,7 +304,7 @@ ApplyPolicyStructured(const vtkm::cont::DynamicCellSet& cellset, template VTKM_CONT vtkm::cont::DynamicCellSetBase ApplyPolicyUnstructured(const vtkm::cont::DynamicCellSet& cellset, - const vtkm::filter::PolicyBase&) + vtkm::filter::PolicyBase) { using CellSetList = typename DerivedPolicy::UnstructuredCellSetList; return cellset.ResetCellSetList(CellSetList()); @@ -308,15 +313,14 @@ ApplyPolicyUnstructured(const vtkm::cont::DynamicCellSet& cellset, //----------------------------------------------------------------------------- template VTKM_CONT vtkm::cont::SerializableField -MakeSerializableField(const vtkm::filter::PolicyBase&) + MakeSerializableField(vtkm::filter::PolicyBase) { return {}; } template VTKM_CONT vtkm::cont::SerializableField -MakeSerializableField(const vtkm::cont::Field& field, - const vtkm::filter::PolicyBase&) +MakeSerializableField(const vtkm::cont::Field& field, vtkm::filter::PolicyBase) { return vtkm::cont::SerializableField{ field }; } @@ -324,7 +328,7 @@ MakeSerializableField(const vtkm::cont::Field& field, template VTKM_CONT vtkm::cont::SerializableDataSet -MakeSerializableDataSet(const vtkm::filter::PolicyBase&) + MakeSerializableDataSet(vtkm::filter::PolicyBase) { return {}; } @@ -332,8 +336,7 @@ MakeSerializableDataSet(const vtkm::filter::PolicyBase&) template VTKM_CONT vtkm::cont::SerializableDataSet -MakeSerializableDataSet(const vtkm::cont::DataSet& dataset, - const vtkm::filter::PolicyBase&) +MakeSerializableDataSet(const vtkm::cont::DataSet& dataset, vtkm::filter::PolicyBase) { return vtkm::cont::SerializableDataSet{ dataset }; diff --git a/vtkm/filter/Tube.hxx b/vtkm/filter/Tube.hxx index cec66cf66..ceca1167f 100644 --- a/vtkm/filter/Tube.hxx +++ b/vtkm/filter/Tube.hxx @@ -12,6 +12,8 @@ #include #include +#include + namespace vtkm { namespace filter @@ -27,18 +29,17 @@ inline VTKM_CONT Tube::Tube() //----------------------------------------------------------------------------- template inline VTKM_CONT vtkm::cont::DataSet Tube::DoExecute(const vtkm::cont::DataSet& input, - vtkm::filter::PolicyBase) + vtkm::filter::PolicyBase policy) { this->Worklet.SetCapping(this->Capping); this->Worklet.SetNumberOfSides(this->NumberOfSides); this->Worklet.SetRadius(this->Radius); + auto originalPoints = vtkm::filter::ApplyPolicy( + input.GetCoordinateSystem(this->GetActiveCoordinateSystemIndex()), policy, *this); vtkm::cont::ArrayHandle newPoints; vtkm::cont::CellSetSingleType<> newCells; - this->Worklet.Run(input.GetCoordinateSystem(this->GetActiveCoordinateSystemIndex()), - input.GetCellSet(), - newPoints, - newCells); + this->Worklet.Run(originalPoints, input.GetCellSet(), newPoints, newCells); vtkm::cont::DataSet outData; vtkm::cont::CoordinateSystem outCoords("coordinates", newPoints); diff --git a/vtkm/filter/WarpScalar.h b/vtkm/filter/WarpScalar.h index 0ba998081..5a9808d89 100644 --- a/vtkm/filter/WarpScalar.h +++ b/vtkm/filter/WarpScalar.h @@ -32,6 +32,11 @@ public: // WarpScalar can only applies to Float and Double Vec3 arrays using SupportedTypes = vtkm::TypeListTagFieldVec3; + // WarpScalar often operates on a constant normal value + using AdditionalFieldStorage = + vtkm::ListTagBase::StorageTag, + vtkm::cont::ArrayHandleConstant::StorageTag>; + VTKM_CONT WarpScalar(vtkm::FloatDefault scaleAmount); diff --git a/vtkm/filter/WarpScalar.hxx b/vtkm/filter/WarpScalar.hxx index 124e0c1ca..1527ddd69 100644 --- a/vtkm/filter/WarpScalar.hxx +++ b/vtkm/filter/WarpScalar.hxx @@ -35,15 +35,16 @@ inline VTKM_CONT vtkm::cont::DataSet WarpScalar::DoExecute( vtkm::filter::PolicyBase policy) { using vecType = vtkm::Vec; - auto normalF = inDataSet.GetField(this->NormalFieldName, this->NormalFieldAssociation); - auto sfF = inDataSet.GetField(this->ScalarFactorFieldName, this->ScalarFactorFieldAssociation); + vtkm::cont::Field normalF = + inDataSet.GetField(this->NormalFieldName, this->NormalFieldAssociation); + vtkm::cont::Field sfF = + inDataSet.GetField(this->ScalarFactorFieldName, this->ScalarFactorFieldAssociation); vtkm::cont::ArrayHandle result; - this->Worklet.Run( - field, - vtkm::filter::ApplyPolicy(normalF, policy, vtkm::filter::FilterTraits()), - vtkm::filter::ApplyPolicy(sfF, policy, vtkm::TypeListTagFieldScalar{}), - this->ScaleAmount, - result); + this->Worklet.Run(field, + vtkm::filter::ApplyPolicy(normalF, policy, *this), + vtkm::filter::ApplyPolicy(sfF, policy, *this), + this->ScaleAmount, + result); return CreateResult(inDataSet, result, this->GetOutputFieldName(), fieldMetadata); } diff --git a/vtkm/filter/WarpVector.h b/vtkm/filter/WarpVector.h index 243e54d76..06e5bf103 100644 --- a/vtkm/filter/WarpVector.h +++ b/vtkm/filter/WarpVector.h @@ -30,6 +30,9 @@ class WarpVector : public vtkm::filter::FilterField { public: using SupportedTypes = vtkm::TypeListTagFieldVec3; + using AdditionalFieldStorage = + vtkm::ListTagBase::StorageTag, + vtkm::cont::ArrayHandleConstant::StorageTag>; VTKM_CONT WarpVector(vtkm::FloatDefault scale); diff --git a/vtkm/filter/WarpVector.hxx b/vtkm/filter/WarpVector.hxx index b86e94ce5..df6492ea5 100644 --- a/vtkm/filter/WarpVector.hxx +++ b/vtkm/filter/WarpVector.hxx @@ -33,13 +33,11 @@ inline VTKM_CONT vtkm::cont::DataSet WarpVector::DoExecute( vtkm::filter::PolicyBase policy) { using vecType = vtkm::Vec; - auto vectorF = inDataSet.GetField(this->VectorFieldName, this->VectorFieldAssociation); + vtkm::cont::Field vectorF = + inDataSet.GetField(this->VectorFieldName, this->VectorFieldAssociation); vtkm::cont::ArrayHandle result; this->Worklet.Run( - field, - vtkm::filter::ApplyPolicy(vectorF, policy, vtkm::filter::FilterTraits()), - this->Scale, - result); + field, vtkm::filter::ApplyPolicy(vectorF, policy, *this), this->Scale, result); return CreateResult(inDataSet, result, this->GetOutputFieldName(), fieldMetadata); } diff --git a/vtkm/worklet/Tube.h b/vtkm/worklet/Tube.h index ebaae84aa..faf7fee21 100644 --- a/vtkm/worklet/Tube.h +++ b/vtkm/worklet/Tube.h @@ -536,18 +536,16 @@ public: VTKM_CONT void SetRadius(vtkm::FloatDefault r) { this->Radius = r; } - VTKM_CONT - void Run(const vtkm::cont::CoordinateSystem& coords, - const vtkm::cont::DynamicCellSet& cellset, - vtkm::cont::ArrayHandle& newPoints, - vtkm::cont::CellSetSingleType<>& newCells) + template + VTKM_CONT void Run(const vtkm::cont::ArrayHandle& coords, + const vtkm::cont::DynamicCellSet& cellset, + vtkm::cont::ArrayHandle& newPoints, + vtkm::cont::CellSetSingleType<>& newCells) { - using ExplCoordsType = vtkm::cont::ArrayHandle; using NormalsType = vtkm::cont::ArrayHandle; - if (!(coords.GetData().IsType() && - (cellset.IsSameType(vtkm::cont::CellSetExplicit<>()) || - cellset.IsSameType(vtkm::cont::CellSetSingleType<>())))) + if (!cellset.IsSameType(vtkm::cont::CellSetExplicit<>()) && + !cellset.IsSameType(vtkm::cont::CellSetSingleType<>())) { throw vtkm::cont::ErrorBadValue("Tube filter only supported for polyline data."); } @@ -574,11 +572,10 @@ public: vtkm::cont::Algorithm::ScanExclusive(segPerPolyline, segOffset); //Generate normals at each point on all polylines - ExplCoordsType inCoords = coords.GetData().Cast(); NormalsType normals; normals.Allocate(totalPolylinePts); vtkm::worklet::DispatcherMapTopology genNormalsDisp; - genNormalsDisp.Invoke(cellset, inCoords, polylinePtOffset, normals); + genNormalsDisp.Invoke(cellset, coords, polylinePtOffset, normals); //Generate the tube points newPoints.Allocate(totalTubePts); @@ -586,7 +583,7 @@ public: GeneratePoints genPts(this->Capping, this->NumSides, this->Radius); vtkm::worklet::DispatcherMapTopology genPtsDisp(genPts); genPtsDisp.Invoke(cellset, - inCoords, + coords, normals, tubePointOffsets, polylinePtOffset, diff --git a/vtkm/worklet/testing/UnitTestTube.cxx b/vtkm/worklet/testing/UnitTestTube.cxx index 36e8b9206..b98e392be 100644 --- a/vtkm/worklet/testing/UnitTestTube.cxx +++ b/vtkm/worklet/testing/UnitTestTube.cxx @@ -120,7 +120,10 @@ void TestTube(bool capEnds, vtkm::FloatDefault radius, vtkm::Id numSides, vtkm:: vtkm::worklet::Tube tubeWorklet(capEnds, numSides, radius); vtkm::cont::ArrayHandle newPoints; vtkm::cont::CellSetSingleType<> newCells; - tubeWorklet.Run(ds.GetCoordinateSystem(0), ds.GetCellSet(), newPoints, newCells); + tubeWorklet.Run(ds.GetCoordinateSystem(0).GetData().Cast>(), + ds.GetCellSet(), + newPoints, + newCells); VTKM_TEST_ASSERT(newPoints.GetNumberOfValues() == reqNumPts, "Wrong number of points in Tube worklet"); @@ -175,7 +178,11 @@ void TestLinearPolylines() vtkm::worklet::Tube tubeWorklet(capEnds, numSides, radius); vtkm::cont::ArrayHandle newPoints; vtkm::cont::CellSetSingleType<> newCells; - tubeWorklet.Run(ds.GetCoordinateSystem(0), ds.GetCellSet(), newPoints, newCells); + tubeWorklet.Run( + ds.GetCoordinateSystem(0).GetData().Cast>(), + ds.GetCellSet(), + newPoints, + newCells); VTKM_TEST_ASSERT(newPoints.GetNumberOfValues() == reqNumPts, "Wrong number of points in Tube worklet");