mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-09-08 13:23:51 +00:00
Update MapFieldMergeAverage/Permuation to use new CastAndCall
These functions now use `UnknownArrayHandle::CastAndCallWithExtractedArray` to reduce the number of times the worklet is run.
This commit is contained in:
parent
f90c2bfd0b
commit
06c59fed13
@ -18,3 +18,9 @@ you to use the functionality to transform the unknown array handle to a
|
||||
form of `ArrayHandle` that depends only on this base component type. This
|
||||
method internally uses a new `ArrayHandleRecombineVec` class, but this
|
||||
class is mostly intended for internal use by this class.
|
||||
|
||||
As an added convenience, `UnknownArrayHandle` now also provides the
|
||||
`CastAndCallWithExtractedArray` method. This method works like other
|
||||
`CastAndCall`s except that it uses the `ExtractArrayFromComponents` feature
|
||||
to allow you to handle most `ArrayHandle` types with few template
|
||||
instances.
|
||||
|
@ -68,20 +68,16 @@ public:
|
||||
|
||||
VTKM_EXEC_CONT vtkm::Id GetIndex() const { return this->Index; }
|
||||
|
||||
VTKM_EXEC_CONT RecombineVec& operator=(const RecombineVec& src)
|
||||
{
|
||||
this->DoCopy(src);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename T, typename = typename std::enable_if<vtkm::HasVecTraits<T>::value>::type>
|
||||
VTKM_EXEC_CONT RecombineVec& operator=(const T& src)
|
||||
{
|
||||
using VTraits = vtkm::VecTraits<T>;
|
||||
vtkm::IdComponent numComponents = VTraits::GetNumberOfComponents(src);
|
||||
if (numComponents > this->GetNumberOfComponents())
|
||||
{
|
||||
numComponents = this->GetNumberOfComponents();
|
||||
}
|
||||
for (vtkm::IdComponent cIndex = 0; cIndex < numComponents; ++cIndex)
|
||||
{
|
||||
this->Portals[cIndex].Set(this->Index, VTraits::GetComponent(src, cIndex));
|
||||
}
|
||||
|
||||
this->DoCopy(src);
|
||||
return *this;
|
||||
}
|
||||
|
||||
@ -94,6 +90,33 @@ public:
|
||||
this->CopyInto(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
VTKM_EXEC_CONT void DoCopy(const T& src)
|
||||
{
|
||||
using VTraits = vtkm::VecTraits<T>;
|
||||
vtkm::IdComponent numComponents = VTraits::GetNumberOfComponents(src);
|
||||
if (numComponents > 1)
|
||||
{
|
||||
if (numComponents > this->GetNumberOfComponents())
|
||||
{
|
||||
numComponents = this->GetNumberOfComponents();
|
||||
}
|
||||
for (vtkm::IdComponent cIndex = 0; cIndex < numComponents; ++cIndex)
|
||||
{
|
||||
this->Portals[cIndex].Set(this->Index, VTraits::GetComponent(src, cIndex));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Special case when copying from a scalar
|
||||
for (vtkm::IdComponent cIndex = 0; cIndex < this->GetNumberOfComponents(); ++cIndex)
|
||||
{
|
||||
this->Portals[cIndex].Set(this->Index, VTraits::GetComponent(src, 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
@ -35,7 +35,7 @@ namespace cont
|
||||
|
||||
/// constructors for points / whole mesh
|
||||
VTKM_CONT
|
||||
Field::Field(std::string name, Association association, const vtkm::cont::VariantArrayHandle& data)
|
||||
Field::Field(std::string name, Association association, const vtkm::cont::UnknownArrayHandle& data)
|
||||
: Name(name)
|
||||
, FieldAssociation(association)
|
||||
, Data(data)
|
||||
|
@ -58,13 +58,13 @@ public:
|
||||
Field() = default;
|
||||
|
||||
VTKM_CONT
|
||||
Field(std::string name, Association association, const vtkm::cont::VariantArrayHandle& data);
|
||||
Field(std::string name, Association association, const vtkm::cont::UnknownArrayHandle& data);
|
||||
|
||||
template <typename T, typename Storage>
|
||||
VTKM_CONT Field(std::string name,
|
||||
Association association,
|
||||
const vtkm::cont::ArrayHandle<T, Storage>& data)
|
||||
: Field(name, association, vtkm::cont::VariantArrayHandle{ data })
|
||||
: Field(name, association, vtkm::cont::UnknownArrayHandle{ data })
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -21,31 +21,15 @@ namespace
|
||||
|
||||
struct DoMapFieldMerge
|
||||
{
|
||||
template <typename BaseComponentType>
|
||||
void operator()(BaseComponentType,
|
||||
const vtkm::cont::UnknownArrayHandle& input,
|
||||
template <typename InputArrayType>
|
||||
void operator()(const InputArrayType& input,
|
||||
const vtkm::worklet::internal::KeysBase& keys,
|
||||
vtkm::cont::UnknownArrayHandle& output,
|
||||
bool& called) const
|
||||
vtkm::cont::UnknownArrayHandle& output) const
|
||||
{
|
||||
if (!input.IsBaseComponentType<BaseComponentType>())
|
||||
{
|
||||
return;
|
||||
}
|
||||
using BaseComponentType = typename InputArrayType::ValueType::ComponentType;
|
||||
|
||||
output = input.NewInstanceBasic();
|
||||
output.Allocate(keys.GetInputRange());
|
||||
|
||||
vtkm::IdComponent numComponents = input.GetNumberOfComponentsFlat();
|
||||
for (vtkm::IdComponent cIndex = 0; cIndex < numComponents; ++cIndex)
|
||||
{
|
||||
vtkm::worklet::AverageByKey::Run(
|
||||
keys,
|
||||
input.ExtractComponent<BaseComponentType>(cIndex, vtkm::CopyFlag::On),
|
||||
output.ExtractComponent<BaseComponentType>(cIndex, vtkm::CopyFlag::Off));
|
||||
}
|
||||
|
||||
called = true;
|
||||
vtkm::worklet::AverageByKey::Run(
|
||||
keys, input, output.ExtractArrayFromComponents<BaseComponentType>(vtkm::CopyFlag::Off));
|
||||
}
|
||||
};
|
||||
|
||||
@ -57,23 +41,20 @@ bool vtkm::filter::MapFieldMergeAverage(const vtkm::cont::Field& inputField,
|
||||
{
|
||||
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
|
||||
|
||||
vtkm::cont::VariantArrayHandle outputArray;
|
||||
bool calledMap = false;
|
||||
vtkm::ListForEach(DoMapFieldMerge{},
|
||||
vtkm::TypeListScalarAll{},
|
||||
inputField.GetData(),
|
||||
keys,
|
||||
outputArray,
|
||||
calledMap);
|
||||
if (calledMap)
|
||||
vtkm::cont::UnknownArrayHandle outputArray = inputField.GetData().NewInstanceBasic();
|
||||
outputArray.Allocate(keys.GetInputRange());
|
||||
|
||||
try
|
||||
{
|
||||
inputField.GetData().CastAndCallWithExtractedArray(DoMapFieldMerge{}, keys, outputArray);
|
||||
outputField = vtkm::cont::Field(inputField.GetName(), inputField.GetAssociation(), outputArray);
|
||||
return true;
|
||||
}
|
||||
else
|
||||
catch (...)
|
||||
{
|
||||
VTKM_LOG_S(vtkm::cont::LogLevel::Warn, "Faild to map field " << inputField.GetName());
|
||||
return false;
|
||||
}
|
||||
return calledMap;
|
||||
}
|
||||
|
||||
bool vtkm::filter::MapFieldMergeAverage(const vtkm::cont::Field& inputField,
|
||||
|
@ -37,9 +37,12 @@ struct MapPermutationWorklet : vtkm::worklet::WorkletMapField
|
||||
|
||||
using ControlSignature = void(FieldIn permutationIndex, WholeArrayIn input, FieldOut output);
|
||||
|
||||
template <typename InputPortalType>
|
||||
VTKM_EXEC void operator()(vtkm::Id permutationIndex, InputPortalType inputPortal, T& output) const
|
||||
template <typename InputPortalType, typename OutputType>
|
||||
VTKM_EXEC void operator()(vtkm::Id permutationIndex,
|
||||
InputPortalType inputPortal,
|
||||
OutputType& output) const
|
||||
{
|
||||
VTKM_STATIC_ASSERT(vtkm::HasVecTraits<OutputType>::value);
|
||||
if ((permutationIndex >= 0) && (permutationIndex < inputPortal.GetNumberOfValues()))
|
||||
{
|
||||
output = inputPortal.Get(permutationIndex);
|
||||
@ -53,36 +56,21 @@ struct MapPermutationWorklet : vtkm::worklet::WorkletMapField
|
||||
|
||||
struct DoMapFieldPermutation
|
||||
{
|
||||
template <typename BaseComponentType>
|
||||
void operator()(BaseComponentType,
|
||||
const vtkm::cont::UnknownArrayHandle& input,
|
||||
template <typename InputArrayType>
|
||||
void operator()(const InputArrayType& input,
|
||||
const vtkm::cont::ArrayHandle<vtkm::Id>& permutation,
|
||||
vtkm::cont::UnknownArrayHandle& output,
|
||||
vtkm::Float64 invalidValue,
|
||||
bool& called) const
|
||||
vtkm::Float64 invalidValue) const
|
||||
{
|
||||
if (!input.IsBaseComponentType<BaseComponentType>())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
output = input.NewInstanceBasic();
|
||||
output.Allocate(permutation.GetNumberOfValues());
|
||||
|
||||
vtkm::IdComponent numComponents = input.GetNumberOfComponentsFlat();
|
||||
using BaseComponentType = typename InputArrayType::ValueType::ComponentType;
|
||||
|
||||
MapPermutationWorklet<BaseComponentType> worklet(
|
||||
vtkm::cont::internal::CastInvalidValue<BaseComponentType>(invalidValue));
|
||||
vtkm::cont::Invoker invoke;
|
||||
for (vtkm::IdComponent cIndex = 0; cIndex < numComponents; ++cIndex)
|
||||
{
|
||||
invoke(worklet,
|
||||
permutation,
|
||||
input.ExtractComponent<BaseComponentType>(cIndex, vtkm::CopyFlag::On),
|
||||
output.ExtractComponent<BaseComponentType>(cIndex, vtkm::CopyFlag::Off));
|
||||
}
|
||||
|
||||
called = true;
|
||||
vtkm::cont::Invoker{}(
|
||||
worklet,
|
||||
permutation,
|
||||
input,
|
||||
output.ExtractArrayFromComponents<BaseComponentType>(vtkm::CopyFlag::Off));
|
||||
}
|
||||
};
|
||||
|
||||
@ -96,24 +84,20 @@ VTKM_FILTER_COMMON_EXPORT VTKM_CONT bool vtkm::filter::MapFieldPermutation(
|
||||
{
|
||||
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
|
||||
|
||||
vtkm::cont::VariantArrayHandle outputArray;
|
||||
bool calledMap = false;
|
||||
vtkm::ListForEach(DoMapFieldPermutation{},
|
||||
vtkm::TypeListScalarAll{},
|
||||
inputField.GetData(),
|
||||
permutation,
|
||||
outputArray,
|
||||
invalidValue,
|
||||
calledMap);
|
||||
if (calledMap)
|
||||
vtkm::cont::UnknownArrayHandle outputArray = inputField.GetData().NewInstanceBasic();
|
||||
outputArray.Allocate(permutation.GetNumberOfValues());
|
||||
try
|
||||
{
|
||||
inputField.GetData().CastAndCallWithExtractedArray(
|
||||
DoMapFieldPermutation{}, permutation, outputArray, invalidValue);
|
||||
outputField = vtkm::cont::Field(inputField.GetName(), inputField.GetAssociation(), outputArray);
|
||||
return true;
|
||||
}
|
||||
else
|
||||
catch (...)
|
||||
{
|
||||
VTKM_LOG_S(vtkm::cont::LogLevel::Warn, "Faild to map field " << inputField.GetName());
|
||||
return false;
|
||||
}
|
||||
return calledMap;
|
||||
}
|
||||
|
||||
VTKM_FILTER_COMMON_EXPORT VTKM_CONT bool vtkm::filter::MapFieldPermutation(
|
||||
|
@ -27,36 +27,30 @@ struct AverageByKey
|
||||
struct AverageWorklet : public vtkm::worklet::WorkletReduceByKey
|
||||
{
|
||||
using ControlSignature = void(KeysIn keys, ValuesIn valuesIn, ReducedValuesOut averages);
|
||||
using ExecutionSignature = _3(_2);
|
||||
using ExecutionSignature = void(_2, _3);
|
||||
using InputDomain = _1;
|
||||
|
||||
template <typename ValuesVecType>
|
||||
VTKM_EXEC typename ValuesVecType::ComponentType operator()(const ValuesVecType& valuesIn) const
|
||||
template <typename ValuesVecType, typename OutType>
|
||||
VTKM_EXEC void operator()(const ValuesVecType& valuesIn, OutType& sum) const
|
||||
{
|
||||
using FieldType = typename ValuesVecType::ComponentType;
|
||||
FieldType sum = valuesIn[0];
|
||||
sum = valuesIn[0];
|
||||
for (vtkm::IdComponent index = 1; index < valuesIn.GetNumberOfComponents(); ++index)
|
||||
{
|
||||
FieldType component = valuesIn[index];
|
||||
// FieldType constructor is for when OutType is a Vec.
|
||||
// static_cast is for when FieldType is a small int that gets promoted to int32.
|
||||
sum = static_cast<FieldType>(sum + component);
|
||||
sum += valuesIn[index];
|
||||
}
|
||||
|
||||
// To get the average, we (of course) divide the sum by the amount of values, which is
|
||||
// returned from valuesIn.GetNumberOfComponents(). To do this, we need to cast the number of
|
||||
// components (returned as a vtkm::IdComponent) to a FieldType. This is a little more complex
|
||||
// than it first seems because FieldType might be a Vec type. If you just try a
|
||||
// static_cast<FieldType>(), it will use the constructor to FieldType which might be a Vec
|
||||
// constructor expecting the type of the component. So, get around this problem by first
|
||||
// casting to the component type of the field and then constructing a field value from that.
|
||||
// We use the VecTraits class to make this work regardless of whether FieldType is a real Vec
|
||||
// or just a scalar.
|
||||
using ComponentType = typename vtkm::VecTraits<FieldType>::ComponentType;
|
||||
// FieldType constructor is for when OutType is a Vec.
|
||||
// static_cast is for when FieldType is a small int that gets promoted to int32.
|
||||
return static_cast<FieldType>(
|
||||
sum / FieldType(static_cast<ComponentType>(valuesIn.GetNumberOfComponents())));
|
||||
// than it first seems because FieldType might be a Vec type or a Vec-like type that cannot
|
||||
// be constructed. To do this safely, we will do a component-wise divide.
|
||||
using VTraits = vtkm::VecTraits<OutType>;
|
||||
using ComponentType = typename VTraits::ComponentType;
|
||||
ComponentType divisor = static_cast<ComponentType>(valuesIn.GetNumberOfComponents());
|
||||
for (vtkm::IdComponent cIndex = 0; cIndex < VTraits::GetNumberOfComponents(sum); ++cIndex)
|
||||
{
|
||||
VTraits::SetComponent(sum, cIndex, VTraits::GetComponent(sum, cIndex) / divisor);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user