vtk-m/vtkm/cont/internal/MapArrayPermutation.cxx
Kenneth Moreland ac889b5004 Implement VecTraits class for all types
The `VecTraits` class allows templated functions, methods, and classes to
treat type arguments uniformly as `Vec` types or to otherwise differentiate
between scalar and vector types. This only works for types that `VecTraits`
is defined for.

The `VecTraits` templated class now has a default implementation that will
be used for any type that does not have a `VecTraits` specialization. This
removes many surprising compiler errors that occurred when using a template
that, unbeknownst to you, uses `VecTraits` in its implementation.

One potential issue is that if `VecTraits` gets defined for a new type, the
behavior of `VecTraits` could change for that type in backward-incompatible
ways. If `VecTraits` is used in a purely generic way, this should not be an
issue. However, if assumptions were made about the components and length,
this could cause problems.

Fixes #589
2023-03-16 12:59:38 -06:00

100 lines
2.9 KiB
C++

//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//============================================================================
#include <vtkm/cont/internal/CastInvalidValue.h>
#include <vtkm/cont/internal/MapArrayPermutation.h>
#include <vtkm/cont/ErrorBadType.h>
#include <vtkm/worklet/WorkletMapField.h>
namespace
{
/// Worklet that gathers values from an input array through an index array.
///
/// Each invocation receives one permutation index. If the index addresses a
/// valid entry of the input (i.e. lies in `[0, inputPortal.GetNumberOfValues())`),
/// that entry is copied to the output; otherwise the output receives the
/// stored `InvalidValue`.
///
/// `T` is the type used for the fallback value (the caller in this file uses
/// the base component type of the input array).
template <typename T>
struct MapPermutationWorklet : vtkm::worklet::WorkletMapField
{
  /// Value written for any permutation index that falls outside the input.
  T InvalidValue;

  explicit MapPermutationWorklet(T fallback)
    : InvalidValue(fallback)
  {
  }

  using ControlSignature = void(FieldIn permutationIndex, WholeArrayIn input, FieldOut output);

  template <typename InputPortalType, typename OutputType>
  VTKM_EXEC void operator()(vtkm::Id permutationIndex,
                            InputPortalType inputPortal,
                            OutputType& output) const
  {
    // Negative indices and indices past the end both map to the fallback.
    const bool outsideInput =
      (permutationIndex < 0) || (permutationIndex >= inputPortal.GetNumberOfValues());
    if (outsideInput)
    {
      output = this->InvalidValue;
      return;
    }

    output = inputPortal.Get(permutationIndex);
  }
};
struct DoMapFieldPermutation
{
template <typename InputArrayType, typename PermutationArrayType>
void operator()(const InputArrayType& input,
const PermutationArrayType& permutation,
vtkm::cont::UnknownArrayHandle& output,
vtkm::Float64 invalidValue) const
{
using BaseComponentType = typename InputArrayType::ValueType::ComponentType;
MapPermutationWorklet<BaseComponentType> worklet(
vtkm::cont::internal::CastInvalidValue<BaseComponentType>(invalidValue));
vtkm::cont::Invoker{}(
worklet,
permutation,
input,
output.ExtractArrayFromComponents<BaseComponentType>(vtkm::CopyFlag::Off));
}
};
} // anonymous namespace
namespace vtkm
{
namespace cont
{
namespace internal
{
/// Builds a new array whose i-th value is `inputArray[permutation[i]]`.
///
/// The result has the same basic value type as `inputArray` (created with
/// `NewInstanceBasic`) and as many values as `permutation`. Any permutation
/// index outside the range of `inputArray` yields `invalidValue`, converted to
/// the input's component type.
///
/// @throws vtkm::cont::ErrorBadType if `permutation` does not have a base
///         component type of `vtkm::Id`.
vtkm::cont::UnknownArrayHandle MapArrayPermutation(
  const vtkm::cont::UnknownArrayHandle& inputArray,
  const vtkm::cont::UnknownArrayHandle& permutation,
  vtkm::Float64 invalidValue)
{
  const bool indicesAreIds = permutation.IsBaseComponentType<vtkm::Id>();
  if (!indicesAreIds)
  {
    throw vtkm::cont::ErrorBadType("Permutation array input to MapArrayPermutation must have "
                                   "values of vtkm::Id. Reported type is " +
                                   permutation.GetBaseComponentTypeName());
  }

  // One output value per permutation entry.
  const vtkm::Id resultSize = permutation.GetNumberOfValues();
  vtkm::cont::UnknownArrayHandle result = inputArray.NewInstanceBasic();
  result.Allocate(resultSize);

  // NOTE(review): the check above is on the *base* component type, so a
  // Vec-of-Id permutation presumably passes too; only component 0 is used.
  const auto indices = permutation.ExtractComponent<vtkm::Id>(0);
  inputArray.CastAndCallWithExtractedArray(
    DoMapFieldPermutation{}, indices, result, invalidValue);

  return result;
}
}
}
} // namespace vtkm::cont::internal