Update MapFieldMergeAverage/Permuation to use new CastAndCall

These functions now use
`UnknownArrayHandle::CastAndCallWithExtractedArray` to reduce the number
of times the worklet is run.
This commit is contained in:
Kenneth Moreland 2021-01-05 13:02:31 -07:00
parent f90c2bfd0b
commit 06c59fed13
7 changed files with 93 additions and 105 deletions

@ -18,3 +18,9 @@ you to use the functionality to transform the unknown array handle to a
form of `ArrayHandle` that depends only on this base component type. This
method internally uses a new `ArrayHandleRecombineVec` class, but this
class is mostly intended for internal use by this class.
As an added convenience, `UnknownArrayHandle` now also provides the
`CastAndCallWithExtractedArray` method. This method works like other
`CastAndCall`s except that it uses the `ExtractArrayFromComponents` feature
to allow you to handle most `ArrayHandle` types with few template
instances.

@ -68,20 +68,16 @@ public:
VTKM_EXEC_CONT vtkm::Id GetIndex() const { return this->Index; }
VTKM_EXEC_CONT RecombineVec& operator=(const RecombineVec& src)
{
this->DoCopy(src);
return *this;
}
template <typename T, typename = typename std::enable_if<vtkm::HasVecTraits<T>::value>::type>
VTKM_EXEC_CONT RecombineVec& operator=(const T& src)
{
using VTraits = vtkm::VecTraits<T>;
vtkm::IdComponent numComponents = VTraits::GetNumberOfComponents(src);
if (numComponents > this->GetNumberOfComponents())
{
numComponents = this->GetNumberOfComponents();
}
for (vtkm::IdComponent cIndex = 0; cIndex < numComponents; ++cIndex)
{
this->Portals[cIndex].Set(this->Index, VTraits::GetComponent(src, cIndex));
}
this->DoCopy(src);
return *this;
}
@ -94,6 +90,33 @@ public:
this->CopyInto(result);
return result;
}
private:
template <typename T>
VTKM_EXEC_CONT void DoCopy(const T& src)
{
using VTraits = vtkm::VecTraits<T>;
vtkm::IdComponent numComponents = VTraits::GetNumberOfComponents(src);
if (numComponents > 1)
{
if (numComponents > this->GetNumberOfComponents())
{
numComponents = this->GetNumberOfComponents();
}
for (vtkm::IdComponent cIndex = 0; cIndex < numComponents; ++cIndex)
{
this->Portals[cIndex].Set(this->Index, VTraits::GetComponent(src, cIndex));
}
}
else
{
// Special case when copying from a scalar
for (vtkm::IdComponent cIndex = 0; cIndex < this->GetNumberOfComponents(); ++cIndex)
{
this->Portals[cIndex].Set(this->Index, VTraits::GetComponent(src, 0));
}
}
}
};
} // namespace internal

@ -35,7 +35,7 @@ namespace cont
/// constructors for points / whole mesh
VTKM_CONT
Field::Field(std::string name, Association association, const vtkm::cont::VariantArrayHandle& data)
Field::Field(std::string name, Association association, const vtkm::cont::UnknownArrayHandle& data)
: Name(name)
, FieldAssociation(association)
, Data(data)

@ -58,13 +58,13 @@ public:
Field() = default;
VTKM_CONT
Field(std::string name, Association association, const vtkm::cont::VariantArrayHandle& data);
Field(std::string name, Association association, const vtkm::cont::UnknownArrayHandle& data);
template <typename T, typename Storage>
VTKM_CONT Field(std::string name,
Association association,
const vtkm::cont::ArrayHandle<T, Storage>& data)
: Field(name, association, vtkm::cont::VariantArrayHandle{ data })
: Field(name, association, vtkm::cont::UnknownArrayHandle{ data })
{
}

@ -21,31 +21,15 @@ namespace
struct DoMapFieldMerge
{
template <typename BaseComponentType>
void operator()(BaseComponentType,
const vtkm::cont::UnknownArrayHandle& input,
template <typename InputArrayType>
void operator()(const InputArrayType& input,
const vtkm::worklet::internal::KeysBase& keys,
vtkm::cont::UnknownArrayHandle& output,
bool& called) const
vtkm::cont::UnknownArrayHandle& output) const
{
if (!input.IsBaseComponentType<BaseComponentType>())
{
return;
}
using BaseComponentType = typename InputArrayType::ValueType::ComponentType;
output = input.NewInstanceBasic();
output.Allocate(keys.GetInputRange());
vtkm::IdComponent numComponents = input.GetNumberOfComponentsFlat();
for (vtkm::IdComponent cIndex = 0; cIndex < numComponents; ++cIndex)
{
vtkm::worklet::AverageByKey::Run(
keys,
input.ExtractComponent<BaseComponentType>(cIndex, vtkm::CopyFlag::On),
output.ExtractComponent<BaseComponentType>(cIndex, vtkm::CopyFlag::Off));
}
called = true;
vtkm::worklet::AverageByKey::Run(
keys, input, output.ExtractArrayFromComponents<BaseComponentType>(vtkm::CopyFlag::Off));
}
};
@ -57,23 +41,20 @@ bool vtkm::filter::MapFieldMergeAverage(const vtkm::cont::Field& inputField,
{
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
vtkm::cont::VariantArrayHandle outputArray;
bool calledMap = false;
vtkm::ListForEach(DoMapFieldMerge{},
vtkm::TypeListScalarAll{},
inputField.GetData(),
keys,
outputArray,
calledMap);
if (calledMap)
vtkm::cont::UnknownArrayHandle outputArray = inputField.GetData().NewInstanceBasic();
outputArray.Allocate(keys.GetInputRange());
try
{
inputField.GetData().CastAndCallWithExtractedArray(DoMapFieldMerge{}, keys, outputArray);
outputField = vtkm::cont::Field(inputField.GetName(), inputField.GetAssociation(), outputArray);
return true;
}
else
catch (...)
{
VTKM_LOG_S(vtkm::cont::LogLevel::Warn, "Faild to map field " << inputField.GetName());
return false;
}
return calledMap;
}
bool vtkm::filter::MapFieldMergeAverage(const vtkm::cont::Field& inputField,

@ -37,9 +37,12 @@ struct MapPermutationWorklet : vtkm::worklet::WorkletMapField
using ControlSignature = void(FieldIn permutationIndex, WholeArrayIn input, FieldOut output);
template <typename InputPortalType>
VTKM_EXEC void operator()(vtkm::Id permutationIndex, InputPortalType inputPortal, T& output) const
template <typename InputPortalType, typename OutputType>
VTKM_EXEC void operator()(vtkm::Id permutationIndex,
InputPortalType inputPortal,
OutputType& output) const
{
VTKM_STATIC_ASSERT(vtkm::HasVecTraits<OutputType>::value);
if ((permutationIndex >= 0) && (permutationIndex < inputPortal.GetNumberOfValues()))
{
output = inputPortal.Get(permutationIndex);
@ -53,36 +56,21 @@ struct MapPermutationWorklet : vtkm::worklet::WorkletMapField
struct DoMapFieldPermutation
{
template <typename BaseComponentType>
void operator()(BaseComponentType,
const vtkm::cont::UnknownArrayHandle& input,
template <typename InputArrayType>
void operator()(const InputArrayType& input,
const vtkm::cont::ArrayHandle<vtkm::Id>& permutation,
vtkm::cont::UnknownArrayHandle& output,
vtkm::Float64 invalidValue,
bool& called) const
vtkm::Float64 invalidValue) const
{
if (!input.IsBaseComponentType<BaseComponentType>())
{
return;
}
output = input.NewInstanceBasic();
output.Allocate(permutation.GetNumberOfValues());
vtkm::IdComponent numComponents = input.GetNumberOfComponentsFlat();
using BaseComponentType = typename InputArrayType::ValueType::ComponentType;
MapPermutationWorklet<BaseComponentType> worklet(
vtkm::cont::internal::CastInvalidValue<BaseComponentType>(invalidValue));
vtkm::cont::Invoker invoke;
for (vtkm::IdComponent cIndex = 0; cIndex < numComponents; ++cIndex)
{
invoke(worklet,
permutation,
input.ExtractComponent<BaseComponentType>(cIndex, vtkm::CopyFlag::On),
output.ExtractComponent<BaseComponentType>(cIndex, vtkm::CopyFlag::Off));
}
called = true;
vtkm::cont::Invoker{}(
worklet,
permutation,
input,
output.ExtractArrayFromComponents<BaseComponentType>(vtkm::CopyFlag::Off));
}
};
@ -96,24 +84,20 @@ VTKM_FILTER_COMMON_EXPORT VTKM_CONT bool vtkm::filter::MapFieldPermutation(
{
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
vtkm::cont::VariantArrayHandle outputArray;
bool calledMap = false;
vtkm::ListForEach(DoMapFieldPermutation{},
vtkm::TypeListScalarAll{},
inputField.GetData(),
permutation,
outputArray,
invalidValue,
calledMap);
if (calledMap)
vtkm::cont::UnknownArrayHandle outputArray = inputField.GetData().NewInstanceBasic();
outputArray.Allocate(permutation.GetNumberOfValues());
try
{
inputField.GetData().CastAndCallWithExtractedArray(
DoMapFieldPermutation{}, permutation, outputArray, invalidValue);
outputField = vtkm::cont::Field(inputField.GetName(), inputField.GetAssociation(), outputArray);
return true;
}
else
catch (...)
{
VTKM_LOG_S(vtkm::cont::LogLevel::Warn, "Faild to map field " << inputField.GetName());
return false;
}
return calledMap;
}
VTKM_FILTER_COMMON_EXPORT VTKM_CONT bool vtkm::filter::MapFieldPermutation(

@ -27,36 +27,30 @@ struct AverageByKey
struct AverageWorklet : public vtkm::worklet::WorkletReduceByKey
{
using ControlSignature = void(KeysIn keys, ValuesIn valuesIn, ReducedValuesOut averages);
using ExecutionSignature = _3(_2);
using ExecutionSignature = void(_2, _3);
using InputDomain = _1;
template <typename ValuesVecType>
VTKM_EXEC typename ValuesVecType::ComponentType operator()(const ValuesVecType& valuesIn) const
template <typename ValuesVecType, typename OutType>
VTKM_EXEC void operator()(const ValuesVecType& valuesIn, OutType& sum) const
{
using FieldType = typename ValuesVecType::ComponentType;
FieldType sum = valuesIn[0];
sum = valuesIn[0];
for (vtkm::IdComponent index = 1; index < valuesIn.GetNumberOfComponents(); ++index)
{
FieldType component = valuesIn[index];
// FieldType constructor is for when OutType is a Vec.
// static_cast is for when FieldType is a small int that gets promoted to int32.
sum = static_cast<FieldType>(sum + component);
sum += valuesIn[index];
}
// To get the average, we (of course) divide the sum by the amount of values, which is
// returned from valuesIn.GetNumberOfComponents(). To do this, we need to cast the number of
// components (returned as a vtkm::IdComponent) to a FieldType. This is a little more complex
// than it first seems because FieldType might be a Vec type. If you just try a
// static_cast<FieldType>(), it will use the constructor to FieldType which might be a Vec
// constructor expecting the type of the component. So, get around this problem by first
// casting to the component type of the field and then constructing a field value from that.
// We use the VecTraits class to make this work regardless of whether FieldType is a real Vec
// or just a scalar.
using ComponentType = typename vtkm::VecTraits<FieldType>::ComponentType;
// FieldType constructor is for when OutType is a Vec.
// static_cast is for when FieldType is a small int that gets promoted to int32.
return static_cast<FieldType>(
sum / FieldType(static_cast<ComponentType>(valuesIn.GetNumberOfComponents())));
// than it first seems because FieldType might be a Vec type or a Vec-like type that cannot
// be constructed. To do this safely, we will do a component-wise divide.
using VTraits = vtkm::VecTraits<OutType>;
using ComponentType = typename VTraits::ComponentType;
ComponentType divisor = static_cast<ComponentType>(valuesIn.GetNumberOfComponents());
for (vtkm::IdComponent cIndex = 0; cIndex < VTraits::GetNumberOfComponents(sum); ++cIndex)
{
VTraits::SetComponent(sum, cIndex, VTraits::GetComponent(sum, cIndex) / divisor);
}
}
};