From 1ca55ac3196894460325e805fcc5734ceaa3bf27 Mon Sep 17 00:00:00 2001 From: Kenneth Moreland Date: Sat, 9 Feb 2019 16:49:17 -0700 Subject: [PATCH] Add specialized operators for ArrayPortalValueReference The ArrayPortalValueReference is supposed to behave just like the value it encapsulates and does so by automatically converting to the base type when necessary. However, when it is possible to convert that to something else, it is possible to get errors about ambiguous overloads. To avoid these, add specialized versions of the operators to specify which ones should be used. Also consolidated the CUDA version of an ArrayPortalValueReference to the standard one. The two implementations were equivalent and we would like changes to apply to both. --- .../portal-value-reference-operators.md | 12 + vtkm/cont/internal/FunctorsGeneral.h | 2 +- .../cuda/internal/IteratorFromArrayPortal.h | 62 +- vtkm/exec/cuda/internal/WrappedOperators.h | 38 +- vtkm/internal/ArrayPortalValueReference.h | 659 +++++++++++++++++- .../UnitTestArrayPortalValueReference.cxx | 214 +++++- 6 files changed, 910 insertions(+), 77 deletions(-) create mode 100644 docs/changelog/portal-value-reference-operators.md diff --git a/docs/changelog/portal-value-reference-operators.md b/docs/changelog/portal-value-reference-operators.md new file mode 100644 index 000000000..f9b3ca13b --- /dev/null +++ b/docs/changelog/portal-value-reference-operators.md @@ -0,0 +1,12 @@ +# Added specialized operators for ArrayPortalValueReference + +The ArrayPortalValueReference is supposed to behave just like the value it +encapsulates and does so by automatically converting to the base type when +necessary. However, when it is possible to convert that to something else, +it is possible to get errors about ambiguous overloads. To avoid these, add +specialized versions of the operators to specify which ones should be used. + +Also consolidated the CUDA version of an ArrayPortalValueReference to the +standard one. The two implementations were equivalent and we would like +changes to apply to both. + diff --git a/vtkm/cont/internal/FunctorsGeneral.h b/vtkm/cont/internal/FunctorsGeneral.h index aedc709e5..bd3d00d1b 100644 --- a/vtkm/cont/internal/FunctorsGeneral.h +++ b/vtkm/cont/internal/FunctorsGeneral.h @@ -38,7 +38,7 @@ namespace internal // Binary function object wrapper which can detect and handle calling the // wrapped operator with complex value types such as -// IteratorFromArrayPortalValue which happen when passed an input array that +// ArrayPortalValueReference which happen when passed an input array that // is implicit. template struct WrappedBinaryOperator diff --git a/vtkm/exec/cuda/internal/IteratorFromArrayPortal.h b/vtkm/exec/cuda/internal/IteratorFromArrayPortal.h index bed83ec23..39c96e20c 100644 --- a/vtkm/exec/cuda/internal/IteratorFromArrayPortal.h +++ b/vtkm/exec/cuda/internal/IteratorFromArrayPortal.h @@ -22,6 +22,7 @@ #include #include +#include #include // Disable warnings we check vtkm for but Thrust does not. @@ -40,57 +41,13 @@ namespace cuda namespace internal { -template -struct PortalValue -{ - using ValueType = typename ArrayPortalType::ValueType; - - VTKM_EXEC_CONT - PortalValue(const ArrayPortalType& portal, vtkm::Id index) - : Portal(portal) - , Index(index) - { - } - - VTKM_EXEC - void Swap(PortalValue& rhs) throw() - { - //we need use the explicit type not a proxy temp object - //A proxy temp object would point to the same underlying data structure - //and would not hold the old value of *this once *this was set to rhs. - const ValueType aValue = *this; - *this = rhs; - rhs = aValue; - } - - VTKM_EXEC - PortalValue& operator=(const PortalValue& rhs) - { - this->Portal.Set(this->Index, rhs.Portal.Get(rhs.Index)); - return *this; - } - - VTKM_EXEC - ValueType operator=(const ValueType& value) const - { - this->Portal.Set(this->Index, value); - return value; - } - - VTKM_EXEC - operator ValueType(void) const { return this->Portal.Get(this->Index); } - - const ArrayPortalType& Portal; - vtkm::Id Index; -}; - template class IteratorFromArrayPortal : public ::thrust::iterator_facade, typename ArrayPortalType::ValueType, ::thrust::system::cuda::tag, ::thrust::random_access_traversal_tag, - PortalValue, + vtkm::internal::ArrayPortalValueReference, std::ptrdiff_t> { public: @@ -109,9 +66,11 @@ public: } VTKM_EXEC - PortalValue operator[](std::ptrdiff_t idx) const //NEEDS to be signed + vtkm::internal::ArrayPortalValueReference operator[]( + std::ptrdiff_t idx) const //NEEDS to be signed { - return PortalValue(this->Portal, this->Index + static_cast(idx)); + return vtkm::internal::ArrayPortalValueReference( + this->Portal, this->Index + static_cast(idx)); } private: @@ -122,9 +81,9 @@ private: friend class ::thrust::iterator_core_access; VTKM_EXEC - PortalValue dereference() const + vtkm::internal::ArrayPortalValueReference dereference() const { - return PortalValue(this->Portal, this->Index); + return vtkm::internal::ArrayPortalValueReference(this->Portal, this->Index); } VTKM_EXEC @@ -167,7 +126,8 @@ private: // //But for vtk-m we pass in facade objects, which are passed by value, but //must be treated as references. So do to do that properly we need to specialize -//is_non_const_reference to state a PortalValue by value is valid for writing +//is_non_const_reference to state an ArrayPortalValueReference by value is valid +//for writing namespace thrust { namespace detail @@ -177,7 +137,7 @@ template struct is_non_const_reference; template -struct is_non_const_reference> +struct is_non_const_reference> : thrust::detail::true_type { }; diff --git a/vtkm/exec/cuda/internal/WrappedOperators.h b/vtkm/exec/cuda/internal/WrappedOperators.h index 4a0166ce6..59e17d302 100644 --- a/vtkm/exec/cuda/internal/WrappedOperators.h +++ b/vtkm/exec/cuda/internal/WrappedOperators.h @@ -42,7 +42,7 @@ namespace internal // Unary function object wrapper which can detect and handle calling the // wrapped operator with complex value types such as -// PortalValue which happen when passed an input array that +// ArrayPortalValueReference which happen when passed an input array that // is implicit. template struct WrappedUnaryPredicate @@ -70,9 +70,9 @@ struct WrappedUnaryPredicate VTKM_EXEC bool operator()(const T& x) const { return m_f(x); } template - VTKM_EXEC bool operator()(const PortalValue& x) const + VTKM_EXEC bool operator()(const vtkm::internal::ArrayPortalValueReference& x) const { - return m_f((T)x); + return m_f(x.Get()); } VTKM_EXEC bool operator()(const T* x) const { return m_f(*x); } @@ -80,7 +80,7 @@ struct WrappedUnaryPredicate // Binary function object wrapper which can detect and handle calling the // wrapped operator with complex value types such as -// PortalValue which happen when passed an input array that +// ArrayPortalValueReference which happen when passed an input array that // is implicit. template struct WrappedBinaryOperator @@ -109,27 +109,24 @@ struct WrappedBinaryOperator VTKM_EXEC T operator()(const T& x, const T& y) const { return m_f(x, y); } template - VTKM_EXEC T operator()(const T& x, const PortalValue& y) const + VTKM_EXEC T operator()(const T& x, const vtkm::internal::ArrayPortalValueReference& y) const { // to support proper implicit conversion, and avoid overload // ambiguities. - T conv_y = y; - return m_f(x, conv_y); + return m_f(x, y.Get()); } template - VTKM_EXEC T operator()(const PortalValue& x, const T& y) const + VTKM_EXEC T operator()(const vtkm::internal::ArrayPortalValueReference& x, const T& y) const { - T conv_x = x; - return m_f(conv_x, y); + return m_f(x.Get(), y); } template - VTKM_EXEC T operator()(const PortalValue& x, const PortalValue& y) const + VTKM_EXEC T operator()(const vtkm::internal::ArrayPortalValueReference& x, + const vtkm::internal::ArrayPortalValueReference& y) const { - T conv_x = x; - T conv_y = y; - return m_f(conv_x, conv_y); + return m_f(x.Get(), y.Get()); } VTKM_EXEC T operator()(const T* const x, const T& y) const { return m_f(*x, y); } @@ -166,21 +163,22 @@ struct WrappedBinaryPredicate VTKM_EXEC bool operator()(const T& x, const T& y) const { return m_f(x, y); } template - VTKM_EXEC bool operator()(const T& x, const PortalValue& y) const + VTKM_EXEC bool operator()(const T& x, const vtkm::internal::ArrayPortalValueReference& y) const { - return m_f(x, (T)y); + return m_f(x, y.Get()); } template - VTKM_EXEC bool operator()(const PortalValue& x, const T& y) const + VTKM_EXEC bool operator()(const vtkm::internal::ArrayPortalValueReference& x, const T& y) const { - return m_f((T)x, y); + return m_f(x.Get(), y); } template - VTKM_EXEC bool operator()(const PortalValue& x, const PortalValue& y) const + VTKM_EXEC bool operator()(const vtkm::internal::ArrayPortalValueReference& x, + const vtkm::internal::ArrayPortalValueReference& y) const { - return m_f((T)x, (T)y); + return m_f(x.Get(), y.Get()); } VTKM_EXEC bool operator()(const T* const x, const T& y) const { return m_f(*x, y); } diff --git a/vtkm/internal/ArrayPortalValueReference.h b/vtkm/internal/ArrayPortalValueReference.h index baeb0901c..98ca8c4d4 100644 --- a/vtkm/internal/ArrayPortalValueReference.h +++ b/vtkm/internal/ArrayPortalValueReference.h @@ -57,6 +57,18 @@ struct ArrayPortalValueReference { } + VTKM_SUPPRESS_EXEC_WARNINGS + VTKM_EXEC_CONT + ValueType Get() const { return this->Portal.Get(this->Index); } + + VTKM_SUPPRESS_EXEC_WARNINGS + VTKM_EXEC_CONT + void Set(ValueType&& value) { this->Portal.Set(this->Index, std::move(value)); } + + VTKM_SUPPRESS_EXEC_WARNINGS + VTKM_EXEC_CONT + void Set(const ValueType& value) { this->Portal.Set(this->Index, value); } + VTKM_CONT void Swap(ArrayPortalValueReference& rhs) throw() { @@ -73,7 +85,7 @@ struct ArrayPortalValueReference ArrayPortalValueReference& operator=( const ArrayPortalValueReference& rhs) { - this->Portal.Set(this->Index, rhs.Portal.Get(rhs.Index)); + this->Set(rhs.Portal.Get(rhs.Index)); return *this; } @@ -81,7 +93,7 @@ struct ArrayPortalValueReference VTKM_EXEC_CONT ValueType operator=(const ValueType& value) { - this->Portal.Set(this->Index, value); + this->Set(value); return value; } @@ -89,6 +101,178 @@ struct ArrayPortalValueReference VTKM_EXEC_CONT operator ValueType(void) const { return this->Portal.Get(this->Index); } + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator+=(const T& rhs) + { + ValueType lhs = this->Get(); + lhs += rhs; + this->Set(lhs); + return lhs; + } + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator+=(const ArrayPortalValueReference& rhs) + { + ValueType lhs = this->Get(); + lhs += rhs.Get(); + this->Set(lhs); + return lhs; + } + + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator-=(const T& rhs) + { + ValueType lhs = this->Get(); + lhs -= rhs; + this->Set(lhs); + return lhs; + } + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator-=(const ArrayPortalValueReference& rhs) + { + ValueType lhs = this->Get(); + lhs -= rhs.Get(); + this->Set(lhs); + return lhs; + } + + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator*=(const T& rhs) + { + ValueType lhs = this->Get(); + lhs *= rhs; + this->Set(lhs); + return lhs; + } + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator*=(const ArrayPortalValueReference& rhs) + { + ValueType lhs = this->Get(); + lhs *= rhs.Get(); + this->Set(lhs); + return lhs; + } + + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator/=(const T& rhs) + { + ValueType lhs = this->Get(); + lhs /= rhs; + this->Set(lhs); + return lhs; + } + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator/=(const ArrayPortalValueReference& rhs) + { + ValueType lhs = this->Get(); + lhs /= rhs.Get(); + this->Set(lhs); + return lhs; + } + + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator&=(const T& rhs) + { + ValueType lhs = this->Get(); + lhs &= rhs; + this->Set(lhs); + return lhs; + } + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator&=(const ArrayPortalValueReference& rhs) + { + ValueType lhs = this->Get(); + lhs &= rhs.Get(); + this->Set(lhs); + return lhs; + } + + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator|=(const T& rhs) + { + ValueType lhs = this->Get(); + lhs |= rhs; + this->Set(lhs); + return lhs; + } + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator|=(const ArrayPortalValueReference& rhs) + { + ValueType lhs = this->Get(); + lhs |= rhs.Get(); + this->Set(lhs); + return lhs; + } + + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator^=(const T& rhs) + { + ValueType lhs = this->Get(); + lhs ^= rhs; + this->Set(lhs); + return lhs; + } + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator^=(const ArrayPortalValueReference& rhs) + { + ValueType lhs = this->Get(); + lhs ^= rhs.Get(); + this->Set(lhs); + return lhs; + } + + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator>>=(const T& rhs) + { + ValueType lhs = this->Get(); + lhs >>= rhs; + this->Set(lhs); + return lhs; + } + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator>>=(const ArrayPortalValueReference& rhs) + { + ValueType lhs = this->Get(); + lhs >>= rhs.Get(); + this->Set(lhs); + return lhs; + } + + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator<<=(const T& rhs) + { + ValueType lhs = this->Get(); + lhs <<= rhs; + this->Set(lhs); + return lhs; + } + VTKM_SUPPRESS_EXEC_WARNINGS + template + VTKM_EXEC_CONT ValueType operator<<=(const ArrayPortalValueReference& rhs) + { + ValueType lhs = this->Get(); + lhs <<= rhs.Get(); + this->Set(lhs); + return lhs; + } + +private: const ArrayPortalType& Portal; vtkm::Id Index; }; @@ -121,6 +305,477 @@ void swap(typename vtkm::internal::ArrayPortalValueReference::ValueType& a, b = a; a = tmp; } + +// The reason why all the operators on ArrayPortalValueReference are defined outside of the class +// is so that in the case that the operator in question is not defined in the value type, these +// operators will not be instantiated (and therefore cause a compile error) unless they are +// directly used (in which case a compile error is appropriate). + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator==(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() == rhs) +{ + return lhs.Get() == rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator==(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() == rhs.Get()) +{ + return lhs.Get() == rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator==(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs == rhs.Get()) +{ + return lhs == rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator!=(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() != rhs) +{ + return lhs.Get() != rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator!=(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() != rhs.Get()) +{ + return lhs.Get() != rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator!=(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs != rhs.Get()) +{ + return lhs != rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator<(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() < rhs) +{ + return lhs.Get() < rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator<(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() < rhs.Get()) +{ + return lhs.Get() < rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator<(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs < rhs.Get()) +{ + return lhs < rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator>(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() > rhs) +{ + return lhs.Get() > rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator>(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() > rhs.Get()) +{ + return lhs.Get() > rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator>(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs > rhs.Get()) +{ + return lhs > rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator<=(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() <= rhs) +{ + return lhs.Get() <= rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator<=(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() <= rhs.Get()) +{ + return lhs.Get() <= rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator<=(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs <= rhs.Get()) +{ + return lhs <= rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator>=(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() >= rhs) +{ + return lhs.Get() >= rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator>=(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() >= rhs.Get()) +{ + return lhs.Get() >= rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator>=(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs >= rhs.Get()) +{ + return lhs >= rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator+(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() + rhs) +{ + return lhs.Get() + rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator+(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() + rhs.Get()) +{ + return lhs.Get() + rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator+(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs + rhs.Get()) +{ + return lhs + rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator-(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() - rhs) +{ + return lhs.Get() - rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator-(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() - rhs.Get()) +{ + return lhs.Get() - rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator-(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs - rhs.Get()) +{ + return lhs - rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator*(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() * rhs) +{ + return lhs.Get() * rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator*(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() * rhs.Get()) +{ + return lhs.Get() * rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator*(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs * rhs.Get()) +{ + return lhs * rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator/(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() / rhs) +{ + return lhs.Get() / rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator/(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() / rhs.Get()) +{ + return lhs.Get() / rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator/(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs / rhs.Get()) +{ + return lhs / rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator%(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() % rhs) +{ + return lhs.Get() % rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator%(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() % rhs.Get()) +{ + return lhs.Get() % rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator%(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs % rhs.Get()) +{ + return lhs % rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator^(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() ^ rhs) +{ + return lhs.Get() ^ rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator^(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() ^ rhs.Get()) +{ + return lhs.Get() ^ rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator^(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs ^ rhs.Get()) +{ + return lhs ^ rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator|(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() | rhs) +{ + return lhs.Get() | rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator|(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() | rhs.Get()) +{ + return lhs.Get() | rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator|(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs | rhs.Get()) +{ + return lhs | rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator&(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() & rhs) +{ + return lhs.Get() & rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator&(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() & rhs.Get()) +{ + return lhs.Get() & rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator&(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs & rhs.Get()) +{ + return lhs & rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator<<(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() << rhs) +{ + return lhs.Get() << rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator<<(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() << rhs.Get()) +{ + return lhs.Get() << rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator<<(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs << rhs.Get()) +{ + return lhs << rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator>>(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() >> rhs) +{ + return lhs.Get() >> rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator>>(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() >> rhs.Get()) +{ + return lhs.Get() >> rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator>>(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs >> rhs.Get()) +{ + return lhs >> rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator~(const ArrayPortalValueReference& ref) + -> decltype(~ref.Get()) +{ + return ~ref.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator!(const ArrayPortalValueReference& ref) + -> decltype(!ref.Get()) +{ + return !ref.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator&&(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() && rhs) +{ + return lhs.Get() && rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator&&(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() && rhs.Get()) +{ + return lhs.Get() && rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator&&(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs && rhs.Get()) +{ + return lhs && rhs.Get(); +} + +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator||(const ArrayPortalValueReference& lhs, + const typename LhsPortalType::ValueType& rhs) + -> decltype(lhs.Get() || rhs) +{ + return lhs.Get() || rhs; +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator||(const ArrayPortalValueReference& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs.Get() || rhs.Get()) +{ + return lhs.Get() || rhs.Get(); +} +VTKM_SUPPRESS_EXEC_WARNINGS +template +VTKM_EXEC_CONT auto operator||(const typename RhsPortalType::ValueType& lhs, + const ArrayPortalValueReference& rhs) + -> decltype(lhs || rhs.Get()) +{ + return lhs || rhs.Get(); +} } } // namespace vtkm::internal diff --git a/vtkm/internal/testing/UnitTestArrayPortalValueReference.cxx b/vtkm/internal/testing/UnitTestArrayPortalValueReference.cxx index 28366d554..81c0f37a4 100644 --- a/vtkm/internal/testing/UnitTestArrayPortalValueReference.cxx +++ b/vtkm/internal/testing/UnitTestArrayPortalValueReference.cxx @@ -22,11 +22,15 @@ #include +#include + #include namespace { +static constexpr vtkm::Id ARRAY_SIZE = 10; + template void SetReference(vtkm::Id index, vtkm::internal::ArrayPortalValueReference ref) { @@ -41,7 +45,204 @@ void CheckReference(vtkm::Id index, vtkm::internal::ArrayPortalValueReference +void TryOperatorsNoVec(vtkm::Id index, + vtkm::internal::ArrayPortalValueReference ref, + vtkm::TypeTraitsScalarTag) +{ + using ValueType = typename ArrayPortalType::ValueType; + + ValueType expected = TestValue(index, ValueType()); + VTKM_TEST_ASSERT(ref.Get() == expected, "Reference did not start out as expected."); + + VTKM_TEST_ASSERT(!(ref < ref)); + VTKM_TEST_ASSERT(ref < ValueType(expected + ValueType(1))); + VTKM_TEST_ASSERT(ValueType(expected - ValueType(1)) < ref); + + VTKM_TEST_ASSERT(!(ref > ref)); + VTKM_TEST_ASSERT(ref > ValueType(expected - ValueType(1))); + VTKM_TEST_ASSERT(ValueType(expected + ValueType(1)) > ref); + + VTKM_TEST_ASSERT(ref <= ref); + VTKM_TEST_ASSERT(ref <= ValueType(expected + ValueType(1))); + VTKM_TEST_ASSERT(ValueType(expected - ValueType(1)) <= ref); + + VTKM_TEST_ASSERT(ref >= ref); + VTKM_TEST_ASSERT(ref >= ValueType(expected - ValueType(1))); + VTKM_TEST_ASSERT(ValueType(expected + ValueType(1)) >= ref); +} + +template +void TryOperatorsNoVec(vtkm::Id, + vtkm::internal::ArrayPortalValueReference, + vtkm::TypeTraitsVectorTag) +{ +} + +template +void TryOperatorsInt(vtkm::Id index, + vtkm::internal::ArrayPortalValueReference ref, + vtkm::TypeTraitsScalarTag, + vtkm::TypeTraitsIntegerTag) +{ + using ValueType = typename ArrayPortalType::ValueType; + + const ValueType operand = TestValue(ARRAY_SIZE, ValueType()); + ValueType expected = TestValue(index, ValueType()); + VTKM_TEST_ASSERT(ref.Get() == expected, "Reference did not start out as expected."); + + VTKM_TEST_ASSERT((ref % ref) == (expected % expected)); + VTKM_TEST_ASSERT((ref % expected) == (expected % expected)); + VTKM_TEST_ASSERT((expected % ref) == (expected % expected)); + + VTKM_TEST_ASSERT((ref ^ ref) == (expected ^ expected)); + VTKM_TEST_ASSERT((ref ^ expected) == (expected ^ expected)); + VTKM_TEST_ASSERT((expected ^ ref) == (expected ^ expected)); + + VTKM_TEST_ASSERT((ref | ref) == (expected | expected)); + VTKM_TEST_ASSERT((ref | expected) == (expected | expected)); + VTKM_TEST_ASSERT((expected | ref) == (expected | expected)); + + VTKM_TEST_ASSERT((ref & ref) == (expected & expected)); + VTKM_TEST_ASSERT((ref & expected) == (expected & expected)); + VTKM_TEST_ASSERT((expected & ref) == (expected & expected)); + + VTKM_TEST_ASSERT((ref << ref) == (expected << expected)); + VTKM_TEST_ASSERT((ref << expected) == (expected << expected)); + VTKM_TEST_ASSERT((expected << ref) == (expected << expected)); + + VTKM_TEST_ASSERT((ref << ref) == (expected << expected)); + VTKM_TEST_ASSERT((ref << expected) == (expected << expected)); + VTKM_TEST_ASSERT((expected << ref) == (expected << expected)); + + VTKM_TEST_ASSERT(~ref == ~expected); + + VTKM_TEST_ASSERT(!(!ref)); + + VTKM_TEST_ASSERT(ref && ref); + VTKM_TEST_ASSERT(ref && expected); + VTKM_TEST_ASSERT(expected && ref); + + VTKM_TEST_ASSERT(ref || ref); + VTKM_TEST_ASSERT(ref || expected); + VTKM_TEST_ASSERT(expected || ref); + + ref &= ref; + expected &= expected; + VTKM_TEST_ASSERT(ref == expected); + ref &= operand; + expected &= operand; + VTKM_TEST_ASSERT(ref == expected); + + ref |= ref; + expected |= expected; + VTKM_TEST_ASSERT(ref == expected); + ref |= operand; + expected |= operand; + VTKM_TEST_ASSERT(ref == expected); + + ref >>= ref; + expected >>= expected; + VTKM_TEST_ASSERT(ref == expected); + ref >>= operand; + expected >>= operand; + VTKM_TEST_ASSERT(ref == expected); + + ref <<= ref; + expected <<= expected; + VTKM_TEST_ASSERT(ref == expected); + ref <<= operand; + expected <<= operand; + VTKM_TEST_ASSERT(ref == expected); + + ref ^= ref; + expected ^= expected; + VTKM_TEST_ASSERT(ref == expected); + ref ^= operand; + expected ^= operand; + VTKM_TEST_ASSERT(ref == expected); +} + +template +void TryOperatorsInt(vtkm::Id, + vtkm::internal::ArrayPortalValueReference, + DimTag, + NumericTag) +{ +} + +template +void TryOperators(vtkm::Id index, vtkm::internal::ArrayPortalValueReference ref) +{ + using ValueType = typename ArrayPortalType::ValueType; + + const ValueType operand = TestValue(ARRAY_SIZE, ValueType()); + ValueType expected = TestValue(index, ValueType()); + VTKM_TEST_ASSERT(ref.Get() == expected, "Reference did not start out as expected."); + + // Test comparison operators. + VTKM_TEST_ASSERT(ref == ref); + VTKM_TEST_ASSERT(ref == expected); + VTKM_TEST_ASSERT(expected == ref); + + VTKM_TEST_ASSERT(!(ref != ref)); + VTKM_TEST_ASSERT(!(ref != expected)); + VTKM_TEST_ASSERT(!(expected != ref)); + + TryOperatorsNoVec(index, ref, typename vtkm::TypeTraits::DimensionalityTag()); + + VTKM_TEST_ASSERT((ref + ref) == (expected + expected)); + VTKM_TEST_ASSERT((ref + expected) == (expected + expected)); + VTKM_TEST_ASSERT((expected + ref) == (expected + expected)); + + VTKM_TEST_ASSERT((ref - ref) == (expected - expected)); + VTKM_TEST_ASSERT((ref - expected) == (expected - expected)); + VTKM_TEST_ASSERT((expected - ref) == (expected - expected)); + + VTKM_TEST_ASSERT((ref * ref) == (expected * expected)); + VTKM_TEST_ASSERT((ref * expected) == (expected * expected)); + VTKM_TEST_ASSERT((expected * ref) == (expected * expected)); + + VTKM_TEST_ASSERT((ref / ref) == (expected / expected)); + VTKM_TEST_ASSERT((ref / expected) == (expected / expected)); + VTKM_TEST_ASSERT((expected / ref) == (expected / expected)); + + ref += ref; + expected += expected; + VTKM_TEST_ASSERT(ref == expected); + ref += operand; + expected += operand; + VTKM_TEST_ASSERT(ref == expected); + + ref -= ref; + expected -= expected; + VTKM_TEST_ASSERT(ref == expected); + ref -= operand; + expected -= operand; + VTKM_TEST_ASSERT(ref == expected); + + ref *= ref; + expected *= expected; + VTKM_TEST_ASSERT(ref == expected); + ref *= operand; + expected *= operand; + VTKM_TEST_ASSERT(ref == expected); + + ref /= ref; + expected /= expected; + VTKM_TEST_ASSERT(ref == expected); + ref /= operand; + expected /= operand; + VTKM_TEST_ASSERT(ref == expected); + + // Reset ref + ref = TestValue(index, ValueType()); + + TryOperatorsInt(index, + ref, + typename vtkm::TypeTraits::DimensionalityTag(), + typename vtkm::TypeTraits::NumericTag()); +} struct DoTestForType { @@ -54,7 +255,7 @@ struct DoTestForType std::cout << "Set array using reference" << std::endl; using PortalType = typename vtkm::cont::ArrayHandle::PortalControl; PortalType portal = array.GetPortalControl(); - for (vtkm::Id index = 0; index < ARRAY_SIZE; index++) + for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index) { SetReference(index, vtkm::internal::ArrayPortalValueReference(portal, index)); } @@ -63,10 +264,17 @@ struct DoTestForType CheckPortal(portal); std::cout << "Check references in set array." << std::endl; - for (vtkm::Id index = 0; index < ARRAY_SIZE; index++) + for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index) { CheckReference(index, vtkm::internal::ArrayPortalValueReference(portal, index)); } + + std::cout << "Check that operators work." << std::endl; + // Start at 1 to avoid issues with 0. + for (vtkm::Id index = 1; index < ARRAY_SIZE; ++index) + { + TryOperators(index, vtkm::internal::ArrayPortalValueReference(portal, index)); + } } };