Allow VecBaseCommon operators to work with any Vec-like

Previously, the arithmetic assignment operators (`+=`, `-=`, `*=`, `/=`)
on `VecBaseCommon` only accepted another subclass of `VecBaseCommon`.
Changed the operators to accept any type.

The benefit of this change is that the assignment operator of classes
that inherit from `VecBaseCommon` can now work with `Vec`-like classes
that do not inherit from `VecBaseCommon`. I think the only real downside
is that the template could match other classes that would lead to
invalid comparisons. But such use would lead to a compile error anyway.
The difference is that instead of getting an error that no overloaded
version of the operator is available, you will get an error inside the
code of the operator.
This commit is contained in:
Kenneth Moreland 2021-01-05 17:24:14 -07:00
parent 74536d4ca1
commit 1cc6dbb0c2

@ -343,15 +343,14 @@ public:
}
}
template <typename OtherComponentType, typename OtherVecType>
VTKM_EXEC_CONT DerivedClass& operator=(
const vtkm::detail::VecBaseCommon<OtherComponentType, OtherVecType>& src)
// Only works with Vec-like objects with operator[] and GetNumberOfComponents().
template <typename OtherVecType>
VTKM_EXEC_CONT DerivedClass& operator=(const OtherVecType& src)
{
const OtherVecType& srcDerived = static_cast<const OtherVecType&>(src);
VTKM_ASSERT(this->NumComponents() == srcDerived.GetNumberOfComponents());
VTKM_ASSERT(this->NumComponents() == src.GetNumberOfComponents());
for (vtkm::IdComponent i = 0; i < this->NumComponents(); ++i)
{
this->Component(i) = OtherComponentType(srcDerived[i]);
this->Component(i) = src[i];
}
return this->Derived();
}
@ -413,14 +412,12 @@ public:
}
template <typename OtherClass>
inline VTKM_EXEC_CONT DerivedClass& operator+=(
const VecBaseCommon<ComponentType, OtherClass>& other)
inline VTKM_EXEC_CONT DerivedClass& operator+=(const OtherClass& other)
{
const OtherClass& other_derived = static_cast<const OtherClass&>(other);
VTKM_ASSERT(this->NumComponents() == other_derived.GetNumberOfComponents());
VTKM_ASSERT(this->NumComponents() == other.GetNumberOfComponents());
for (vtkm::IdComponent i = 0; i < this->NumComponents(); ++i)
{
this->Component(i) += other_derived[i];
this->Component(i) += other[i];
}
return this->Derived();
}
@ -439,14 +436,12 @@ public:
}
template <typename OtherClass>
inline VTKM_EXEC_CONT DerivedClass& operator-=(
const VecBaseCommon<ComponentType, OtherClass>& other)
inline VTKM_EXEC_CONT DerivedClass& operator-=(const OtherClass& other)
{
const OtherClass& other_derived = static_cast<const OtherClass&>(other);
VTKM_ASSERT(this->NumComponents() == other_derived.GetNumberOfComponents());
VTKM_ASSERT(this->NumComponents() == other.GetNumberOfComponents());
for (vtkm::IdComponent i = 0; i < this->NumComponents(); ++i)
{
this->Component(i) -= other_derived[i];
this->Component(i) -= other[i];
}
return this->Derived();
}
@ -464,14 +459,12 @@ public:
}
template <typename OtherClass>
inline VTKM_EXEC_CONT DerivedClass& operator*=(
const VecBaseCommon<ComponentType, OtherClass>& other)
inline VTKM_EXEC_CONT DerivedClass& operator*=(const OtherClass& other)
{
const OtherClass& other_derived = static_cast<const OtherClass&>(other);
VTKM_ASSERT(this->NumComponents() == other_derived.GetNumberOfComponents());
VTKM_ASSERT(this->NumComponents() == other.GetNumberOfComponents());
for (vtkm::IdComponent i = 0; i < this->NumComponents(); ++i)
{
this->Component(i) *= other_derived[i];
this->Component(i) *= other[i];
}
return this->Derived();
}
@ -489,13 +482,12 @@ public:
}
template <typename OtherClass>
VTKM_EXEC_CONT DerivedClass& operator/=(const VecBaseCommon<ComponentType, OtherClass>& other)
VTKM_EXEC_CONT DerivedClass& operator/=(const OtherClass& other)
{
const OtherClass& other_derived = static_cast<const OtherClass&>(other);
VTKM_ASSERT(this->NumComponents() == other_derived.GetNumberOfComponents());
VTKM_ASSERT(this->NumComponents() == other.GetNumberOfComponents());
for (vtkm::IdComponent i = 0; i < this->NumComponents(); ++i)
{
this->Component(i) /= other_derived[i];
this->Component(i) /= other[i];
}
return this->Derived();
}