Fix dot product type promotion

Update the dot product functions to use auto returns to capture proper
type promotion.
This commit is contained in:
Kenneth Moreland 2023-02-07 21:57:53 -07:00
parent d5ba7f4d59
commit 8befbc47c7

@ -1426,26 +1426,11 @@ static inline VTKM_EXEC_CONT vtkm::VecCConst<T> make_VecC(const T* array, vtkm::
namespace detail
{
template <typename T>
struct DotType
{
//results when < 32bit can be float if somehow we are using float16/float8, otherwise is
// int32 or uint32 depending on if it signed or not.
using float_type = vtkm::Float32;
using integer_type =
typename std::conditional<std::is_signed<T>::value, vtkm::Int32, vtkm::UInt32>::type;
using promote_type =
typename std::conditional<std::is_integral<T>::value, integer_type, float_type>::type;
using type =
typename std::conditional<(sizeof(T) < sizeof(vtkm::Float32)), promote_type, T>::type;
};
template <typename T>
static inline VTKM_EXEC_CONT typename DotType<typename T::ComponentType>::type vec_dot(const T& a,
const T& b)
static inline VTKM_EXEC_CONT auto vec_dot(const T& a, const T& b)
{
using U = typename DotType<typename T::ComponentType>::type;
U result = a[0] * b[0];
auto result = a[0] * b[0];
for (vtkm::IdComponent i = 1; i < a.GetNumberOfComponents(); ++i)
{
result = result + a[i] * b[i];
@ -1453,50 +1438,44 @@ static inline VTKM_EXEC_CONT typename DotType<typename T::ComponentType>::type v
return result;
}
template <typename T, vtkm::IdComponent Size>
static inline VTKM_EXEC_CONT typename DotType<T>::type vec_dot(const vtkm::Vec<T, Size>& a,
const vtkm::Vec<T, Size>& b)
static inline VTKM_EXEC_CONT auto vec_dot(const vtkm::Vec<T, Size>& a, const vtkm::Vec<T, Size>& b)
{
using U = typename DotType<T>::type;
U result = a[0] * b[0];
auto result = a[0] * b[0];
for (vtkm::IdComponent i = 1; i < Size; ++i)
{
result = result + a[i] * b[i];
}
return result;
}
}
} // namespace detail
template <typename T>
static inline VTKM_EXEC_CONT auto Dot(const T& a, const T& b) -> decltype(detail::vec_dot(a, b))
static inline VTKM_EXEC_CONT auto Dot(const T& a, const T& b)
{
return detail::vec_dot(a, b);
}
template <typename T>
static inline VTKM_EXEC_CONT typename detail::DotType<T>::type Dot(const vtkm::Vec<T, 2>& a,
const vtkm::Vec<T, 2>& b)
static inline VTKM_EXEC_CONT auto Dot(const vtkm::Vec<T, 2>& a, const vtkm::Vec<T, 2>& b)
{
return (a[0] * b[0]) + (a[1] * b[1]);
}
template <typename T>
static inline VTKM_EXEC_CONT typename detail::DotType<T>::type Dot(const vtkm::Vec<T, 3>& a,
const vtkm::Vec<T, 3>& b)
static inline VTKM_EXEC_CONT auto Dot(const vtkm::Vec<T, 3>& a, const vtkm::Vec<T, 3>& b)
{
return (a[0] * b[0]) + (a[1] * b[1]) + (a[2] * b[2]);
}
template <typename T>
static inline VTKM_EXEC_CONT typename detail::DotType<T>::type Dot(const vtkm::Vec<T, 4>& a,
const vtkm::Vec<T, 4>& b)
static inline VTKM_EXEC_CONT auto Dot(const vtkm::Vec<T, 4>& a, const vtkm::Vec<T, 4>& b)
{
return (a[0] * b[0]) + (a[1] * b[1]) + (a[2] * b[2]) + (a[3] * b[3]);
}
// Integer types of a width less than an integer get implicitly casted to
// an integer when doing a multiplication.
#define VTK_M_SCALAR_DOT(stype) \
static inline VTKM_EXEC_CONT detail::DotType<stype>::type dot(stype a, stype b) \
{ \
return a * b; \
} /* LEGACY */ \
static inline VTKM_EXEC_CONT detail::DotType<stype>::type Dot(stype a, stype b) { return a * b; }
#define VTK_M_SCALAR_DOT(stype) \
static inline VTKM_EXEC_CONT auto dot(stype a, stype b) { return a * b; } /* LEGACY */ \
static inline VTKM_EXEC_CONT auto Dot(stype a, stype b) { return a * b; }
VTK_M_SCALAR_DOT(vtkm::Int8)
VTK_M_SCALAR_DOT(vtkm::UInt8)
VTK_M_SCALAR_DOT(vtkm::Int16)
@ -1515,20 +1494,17 @@ static inline VTKM_EXEC_CONT auto dot(const T& a, const T& b) -> decltype(detail
return vtkm::Dot(a, b);
}
template <typename T>
static inline VTKM_EXEC_CONT typename detail::DotType<T>::type dot(const vtkm::Vec<T, 2>& a,
const vtkm::Vec<T, 2>& b)
static inline VTKM_EXEC_CONT auto dot(const vtkm::Vec<T, 2>& a, const vtkm::Vec<T, 2>& b)
{
return vtkm::Dot(a, b);
}
template <typename T>
static inline VTKM_EXEC_CONT typename detail::DotType<T>::type dot(const vtkm::Vec<T, 3>& a,
const vtkm::Vec<T, 3>& b)
static inline VTKM_EXEC_CONT auto dot(const vtkm::Vec<T, 3>& a, const vtkm::Vec<T, 3>& b)
{
return vtkm::Dot(a, b);
}
template <typename T>
static inline VTKM_EXEC_CONT typename detail::DotType<T>::type dot(const vtkm::Vec<T, 4>& a,
const vtkm::Vec<T, 4>& b)
static inline VTKM_EXEC_CONT auto dot(const vtkm::Vec<T, 4>& a, const vtkm::Vec<T, 4>& b)
{
return vtkm::Dot(a, b);
}