Add specialized operators for ArrayPortalValueReference

The ArrayPortalValueReference is supposed to behave just like the value
it encapsulates and does so by automatically converting to the base type
when necessary. However, when it is possible to convert that to
something else, it is possible to get errors about ambiguous overloads.
To avoid these, add specialized versions of the operators to specify
which ones should be used.

Also consolidated the CUDA version of an ArrayPortalValueReference to the
standard one. The two implementations were equivalent and we would like
changes to apply to both.
This commit is contained in:
Kenneth Moreland 2019-02-09 16:49:17 -07:00
parent 6851077ebb
commit 1ca55ac319
6 changed files with 910 additions and 77 deletions

@ -0,0 +1,12 @@
# Added specialized operators for ArrayPortalValueReference
The ArrayPortalValueReference is supposed to behave just like the value it
encapsulates and does so by automatically converting to the base type when
necessary. However, when it is possible to convert that to something else,
it is possible to get errors about ambiguous overloads. To avoid these, add
specialized versions of the operators to specify which ones should be used.
Also consolidated the CUDA version of an ArrayPortalValueReference to the
standard one. The two implementations were equivalent and we would like
changes to apply to both.

@ -38,7 +38,7 @@ namespace internal
// Binary function object wrapper which can detect and handle calling the
// wrapped operator with complex value types such as
// IteratorFromArrayPortalValue which happen when passed an input array that
// ArrayPortalValueReference which happen when passed an input array that
// is implicit.
template <typename ResultType, typename Function>
struct WrappedBinaryOperator

@ -22,6 +22,7 @@
#include <vtkm/Pair.h>
#include <vtkm/Types.h>
#include <vtkm/internal/ArrayPortalValueReference.h>
#include <vtkm/internal/ExportMacros.h>
// Disable warnings we check vtkm for but Thrust does not.
@ -40,57 +41,13 @@ namespace cuda
namespace internal
{
template <class ArrayPortalType>
struct PortalValue
{
using ValueType = typename ArrayPortalType::ValueType;
VTKM_EXEC_CONT
PortalValue(const ArrayPortalType& portal, vtkm::Id index)
: Portal(portal)
, Index(index)
{
}
VTKM_EXEC
void Swap(PortalValue<ArrayPortalType>& rhs) throw()
{
//we need use the explicit type not a proxy temp object
//A proxy temp object would point to the same underlying data structure
//and would not hold the old value of *this once *this was set to rhs.
const ValueType aValue = *this;
*this = rhs;
rhs = aValue;
}
VTKM_EXEC
PortalValue<ArrayPortalType>& operator=(const PortalValue<ArrayPortalType>& rhs)
{
this->Portal.Set(this->Index, rhs.Portal.Get(rhs.Index));
return *this;
}
VTKM_EXEC
ValueType operator=(const ValueType& value) const
{
this->Portal.Set(this->Index, value);
return value;
}
VTKM_EXEC
operator ValueType(void) const { return this->Portal.Get(this->Index); }
const ArrayPortalType& Portal;
vtkm::Id Index;
};
template <class ArrayPortalType>
class IteratorFromArrayPortal
: public ::thrust::iterator_facade<IteratorFromArrayPortal<ArrayPortalType>,
typename ArrayPortalType::ValueType,
::thrust::system::cuda::tag,
::thrust::random_access_traversal_tag,
PortalValue<ArrayPortalType>,
vtkm::internal::ArrayPortalValueReference<ArrayPortalType>,
std::ptrdiff_t>
{
public:
@ -109,9 +66,11 @@ public:
}
VTKM_EXEC
PortalValue<ArrayPortalType> operator[](std::ptrdiff_t idx) const //NEEDS to be signed
vtkm::internal::ArrayPortalValueReference<ArrayPortalType> operator[](
std::ptrdiff_t idx) const //NEEDS to be signed
{
return PortalValue<ArrayPortalType>(this->Portal, this->Index + static_cast<vtkm::Id>(idx));
return vtkm::internal::ArrayPortalValueReference<ArrayPortalType>(
this->Portal, this->Index + static_cast<vtkm::Id>(idx));
}
private:
@ -122,9 +81,9 @@ private:
friend class ::thrust::iterator_core_access;
VTKM_EXEC
PortalValue<ArrayPortalType> dereference() const
vtkm::internal::ArrayPortalValueReference<ArrayPortalType> dereference() const
{
return PortalValue<ArrayPortalType>(this->Portal, this->Index);
return vtkm::internal::ArrayPortalValueReference<ArrayPortalType>(this->Portal, this->Index);
}
VTKM_EXEC
@ -167,7 +126,8 @@ private:
//
//But for vtk-m we pass in facade objects, which are passed by value, but
//must be treated as references. So do to do that properly we need to specialize
//is_non_const_reference to state a PortalValue by value is valid for writing
//is_non_const_reference to state an ArrayPortalValueReference by value is valid
//for writing
namespace thrust
{
namespace detail
@ -177,7 +137,7 @@ template <typename T>
struct is_non_const_reference;
template <typename T>
struct is_non_const_reference<vtkm::exec::cuda::internal::PortalValue<T>>
struct is_non_const_reference<vtkm::internal::ArrayPortalValueReference<T>>
: thrust::detail::true_type
{
};

@ -42,7 +42,7 @@ namespace internal
// Unary function object wrapper which can detect and handle calling the
// wrapped operator with complex value types such as
// PortalValue which happen when passed an input array that
// ArrayPortalValueReference which happen when passed an input array that
// is implicit.
template <typename T_, typename Function>
struct WrappedUnaryPredicate
@ -70,9 +70,9 @@ struct WrappedUnaryPredicate
VTKM_EXEC bool operator()(const T& x) const { return m_f(x); }
template <typename U>
VTKM_EXEC bool operator()(const PortalValue<U>& x) const
VTKM_EXEC bool operator()(const vtkm::internal::ArrayPortalValueReference<U>& x) const
{
return m_f((T)x);
return m_f(x.Get());
}
VTKM_EXEC bool operator()(const T* x) const { return m_f(*x); }
@ -80,7 +80,7 @@ struct WrappedUnaryPredicate
// Binary function object wrapper which can detect and handle calling the
// wrapped operator with complex value types such as
// PortalValue which happen when passed an input array that
// ArrayPortalValueReference which happen when passed an input array that
// is implicit.
template <typename T_, typename Function>
struct WrappedBinaryOperator
@ -109,27 +109,24 @@ struct WrappedBinaryOperator
VTKM_EXEC T operator()(const T& x, const T& y) const { return m_f(x, y); }
template <typename U>
VTKM_EXEC T operator()(const T& x, const PortalValue<U>& y) const
VTKM_EXEC T operator()(const T& x, const vtkm::internal::ArrayPortalValueReference<U>& y) const
{
// to support proper implicit conversion, and avoid overload
// ambiguities.
T conv_y = y;
return m_f(x, conv_y);
return m_f(x, y.Get());
}
template <typename U>
VTKM_EXEC T operator()(const PortalValue<U>& x, const T& y) const
VTKM_EXEC T operator()(const vtkm::internal::ArrayPortalValueReference<U>& x, const T& y) const
{
T conv_x = x;
return m_f(conv_x, y);
return m_f(x.Get(), y);
}
template <typename U, typename V>
VTKM_EXEC T operator()(const PortalValue<U>& x, const PortalValue<V>& y) const
VTKM_EXEC T operator()(const vtkm::internal::ArrayPortalValueReference<U>& x,
const vtkm::internal::ArrayPortalValueReference<V>& y) const
{
T conv_x = x;
T conv_y = y;
return m_f(conv_x, conv_y);
return m_f(x.Get(), y.Get());
}
VTKM_EXEC T operator()(const T* const x, const T& y) const { return m_f(*x, y); }
@ -166,21 +163,22 @@ struct WrappedBinaryPredicate
VTKM_EXEC bool operator()(const T& x, const T& y) const { return m_f(x, y); }
template <typename U>
VTKM_EXEC bool operator()(const T& x, const PortalValue<U>& y) const
VTKM_EXEC bool operator()(const T& x, const vtkm::internal::ArrayPortalValueReference<U>& y) const
{
return m_f(x, (T)y);
return m_f(x, y.Get());
}
template <typename U>
VTKM_EXEC bool operator()(const PortalValue<U>& x, const T& y) const
VTKM_EXEC bool operator()(const vtkm::internal::ArrayPortalValueReference<U>& x, const T& y) const
{
return m_f((T)x, y);
return m_f(x.Get(), y);
}
template <typename U, typename V>
VTKM_EXEC bool operator()(const PortalValue<U>& x, const PortalValue<V>& y) const
VTKM_EXEC bool operator()(const vtkm::internal::ArrayPortalValueReference<U>& x,
const vtkm::internal::ArrayPortalValueReference<V>& y) const
{
return m_f((T)x, (T)y);
return m_f(x.Get(), y.Get());
}
VTKM_EXEC bool operator()(const T* const x, const T& y) const { return m_f(*x, y); }

@ -57,6 +57,18 @@ struct ArrayPortalValueReference
{
}
VTKM_SUPPRESS_EXEC_WARNINGS
VTKM_EXEC_CONT
ValueType Get() const { return this->Portal.Get(this->Index); }
VTKM_SUPPRESS_EXEC_WARNINGS
VTKM_EXEC_CONT
void Set(ValueType&& value) { this->Portal.Set(this->Index, std::move(value)); }
VTKM_SUPPRESS_EXEC_WARNINGS
VTKM_EXEC_CONT
void Set(const ValueType& value) { this->Portal.Set(this->Index, value); }
VTKM_CONT
void Swap(ArrayPortalValueReference<ArrayPortalType>& rhs) throw()
{
@ -73,7 +85,7 @@ struct ArrayPortalValueReference
ArrayPortalValueReference<ArrayPortalType>& operator=(
const ArrayPortalValueReference<ArrayPortalType>& rhs)
{
this->Portal.Set(this->Index, rhs.Portal.Get(rhs.Index));
this->Set(rhs.Portal.Get(rhs.Index));
return *this;
}
@ -81,7 +93,7 @@ struct ArrayPortalValueReference
VTKM_EXEC_CONT
ValueType operator=(const ValueType& value)
{
this->Portal.Set(this->Index, value);
this->Set(value);
return value;
}
@ -89,6 +101,178 @@ struct ArrayPortalValueReference
VTKM_EXEC_CONT
operator ValueType(void) const { return this->Portal.Get(this->Index); }
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator+=(const T& rhs)
{
ValueType lhs = this->Get();
lhs += rhs;
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator+=(const ArrayPortalValueReference<T>& rhs)
{
ValueType lhs = this->Get();
lhs += rhs.Get();
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator-=(const T& rhs)
{
ValueType lhs = this->Get();
lhs -= rhs;
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator-=(const ArrayPortalValueReference<T>& rhs)
{
ValueType lhs = this->Get();
lhs -= rhs.Get();
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator*=(const T& rhs)
{
ValueType lhs = this->Get();
lhs *= rhs;
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator*=(const ArrayPortalValueReference<T>& rhs)
{
ValueType lhs = this->Get();
lhs *= rhs.Get();
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator/=(const T& rhs)
{
ValueType lhs = this->Get();
lhs /= rhs;
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator/=(const ArrayPortalValueReference<T>& rhs)
{
ValueType lhs = this->Get();
lhs /= rhs.Get();
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator&=(const T& rhs)
{
ValueType lhs = this->Get();
lhs &= rhs;
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator&=(const ArrayPortalValueReference<T>& rhs)
{
ValueType lhs = this->Get();
lhs &= rhs.Get();
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator|=(const T& rhs)
{
ValueType lhs = this->Get();
lhs |= rhs;
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator|=(const ArrayPortalValueReference<T>& rhs)
{
ValueType lhs = this->Get();
lhs |= rhs.Get();
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator^=(const T& rhs)
{
ValueType lhs = this->Get();
lhs ^= rhs;
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator^=(const ArrayPortalValueReference<T>& rhs)
{
ValueType lhs = this->Get();
lhs ^= rhs.Get();
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator>>=(const T& rhs)
{
ValueType lhs = this->Get();
lhs >>= rhs;
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator>>=(const ArrayPortalValueReference<T>& rhs)
{
ValueType lhs = this->Get();
lhs >>= rhs.Get();
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator<<=(const T& rhs)
{
ValueType lhs = this->Get();
lhs <<= rhs;
this->Set(lhs);
return lhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename T>
VTKM_EXEC_CONT ValueType operator<<=(const ArrayPortalValueReference<T>& rhs)
{
ValueType lhs = this->Get();
lhs <<= rhs.Get();
this->Set(lhs);
return lhs;
}
private:
const ArrayPortalType& Portal;
vtkm::Id Index;
};
@ -121,6 +305,477 @@ void swap(typename vtkm::internal::ArrayPortalValueReference<T>::ValueType& a,
b = a;
a = tmp;
}
// The reason why all the operators on ArrayPortalValueReference are defined outside of the class
// is so that in the case that the operator in question is not defined in the value type, these
// operators will not be instantiated (and therefore cause a compile error) unless they are
// directly used (in which case a compile error is appropriate).
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator==(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() == rhs)
{
return lhs.Get() == rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator==(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() == rhs.Get())
{
return lhs.Get() == rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator==(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs == rhs.Get())
{
return lhs == rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator!=(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() != rhs)
{
return lhs.Get() != rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator!=(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() != rhs.Get())
{
return lhs.Get() != rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator!=(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs != rhs.Get())
{
return lhs != rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator<(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() < rhs)
{
return lhs.Get() < rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator<(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() < rhs.Get())
{
return lhs.Get() < rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator<(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs < rhs.Get())
{
return lhs < rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator>(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() > rhs)
{
return lhs.Get() > rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator>(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() > rhs.Get())
{
return lhs.Get() > rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator>(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs > rhs.Get())
{
return lhs > rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator<=(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() <= rhs)
{
return lhs.Get() <= rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator<=(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() <= rhs.Get())
{
return lhs.Get() <= rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator<=(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs <= rhs.Get())
{
return lhs <= rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator>=(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() >= rhs)
{
return lhs.Get() >= rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator>=(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() >= rhs.Get())
{
return lhs.Get() >= rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator>=(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs >= rhs.Get())
{
return lhs >= rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator+(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() + rhs)
{
return lhs.Get() + rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator+(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() + rhs.Get())
{
return lhs.Get() + rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator+(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs + rhs.Get())
{
return lhs + rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator-(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() - rhs)
{
return lhs.Get() - rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator-(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() - rhs.Get())
{
return lhs.Get() - rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator-(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs - rhs.Get())
{
return lhs - rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator*(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() * rhs)
{
return lhs.Get() * rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator*(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() * rhs.Get())
{
return lhs.Get() * rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator*(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs * rhs.Get())
{
return lhs * rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator/(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() / rhs)
{
return lhs.Get() / rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator/(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() / rhs.Get())
{
return lhs.Get() / rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator/(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs / rhs.Get())
{
return lhs / rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator%(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() % rhs)
{
return lhs.Get() % rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator%(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() % rhs.Get())
{
return lhs.Get() % rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator%(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs % rhs.Get())
{
return lhs % rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator^(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() ^ rhs)
{
return lhs.Get() ^ rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator^(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() ^ rhs.Get())
{
return lhs.Get() ^ rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator^(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs ^ rhs.Get())
{
return lhs ^ rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator|(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() | rhs)
{
return lhs.Get() | rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator|(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() | rhs.Get())
{
return lhs.Get() | rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator|(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs | rhs.Get())
{
return lhs | rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator&(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() & rhs)
{
return lhs.Get() & rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator&(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() & rhs.Get())
{
return lhs.Get() & rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator&(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs & rhs.Get())
{
return lhs & rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator<<(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() << rhs)
{
return lhs.Get() << rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator<<(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() << rhs.Get())
{
return lhs.Get() << rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator<<(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs << rhs.Get())
{
return lhs << rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator>>(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() >> rhs)
{
return lhs.Get() >> rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator>>(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() >> rhs.Get())
{
return lhs.Get() >> rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator>>(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs >> rhs.Get())
{
return lhs >> rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename PortalType>
VTKM_EXEC_CONT auto operator~(const ArrayPortalValueReference<PortalType>& ref)
-> decltype(~ref.Get())
{
return ~ref.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename PortalType>
VTKM_EXEC_CONT auto operator!(const ArrayPortalValueReference<PortalType>& ref)
-> decltype(!ref.Get())
{
return !ref.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator&&(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() && rhs)
{
return lhs.Get() && rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator&&(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() && rhs.Get())
{
return lhs.Get() && rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator&&(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs && rhs.Get())
{
return lhs && rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType>
VTKM_EXEC_CONT auto operator||(const ArrayPortalValueReference<LhsPortalType>& lhs,
const typename LhsPortalType::ValueType& rhs)
-> decltype(lhs.Get() || rhs)
{
return lhs.Get() || rhs;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename LhsPortalType, typename RhsPortalType>
VTKM_EXEC_CONT auto operator||(const ArrayPortalValueReference<LhsPortalType>& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs.Get() || rhs.Get())
{
return lhs.Get() || rhs.Get();
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <typename RhsPortalType>
VTKM_EXEC_CONT auto operator||(const typename RhsPortalType::ValueType& lhs,
const ArrayPortalValueReference<RhsPortalType>& rhs)
-> decltype(lhs || rhs.Get())
{
return lhs || rhs.Get();
}
}
} // namespace vtkm::internal

@ -22,11 +22,15 @@
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/TypeTraits.h>
#include <vtkm/cont/testing/Testing.h>
namespace
{
static constexpr vtkm::Id ARRAY_SIZE = 10;
template <typename ArrayPortalType>
void SetReference(vtkm::Id index, vtkm::internal::ArrayPortalValueReference<ArrayPortalType> ref)
{
@ -41,7 +45,204 @@ void CheckReference(vtkm::Id index, vtkm::internal::ArrayPortalValueReference<Ar
VTKM_TEST_ASSERT(test_equal(ref, TestValue(index, ValueType())), "Got bad value from reference.");
}
static constexpr vtkm::Id ARRAY_SIZE = 10;
template <typename ArrayPortalType>
void TryOperatorsNoVec(vtkm::Id index,
vtkm::internal::ArrayPortalValueReference<ArrayPortalType> ref,
vtkm::TypeTraitsScalarTag)
{
using ValueType = typename ArrayPortalType::ValueType;
ValueType expected = TestValue(index, ValueType());
VTKM_TEST_ASSERT(ref.Get() == expected, "Reference did not start out as expected.");
VTKM_TEST_ASSERT(!(ref < ref));
VTKM_TEST_ASSERT(ref < ValueType(expected + ValueType(1)));
VTKM_TEST_ASSERT(ValueType(expected - ValueType(1)) < ref);
VTKM_TEST_ASSERT(!(ref > ref));
VTKM_TEST_ASSERT(ref > ValueType(expected - ValueType(1)));
VTKM_TEST_ASSERT(ValueType(expected + ValueType(1)) > ref);
VTKM_TEST_ASSERT(ref <= ref);
VTKM_TEST_ASSERT(ref <= ValueType(expected + ValueType(1)));
VTKM_TEST_ASSERT(ValueType(expected - ValueType(1)) <= ref);
VTKM_TEST_ASSERT(ref >= ref);
VTKM_TEST_ASSERT(ref >= ValueType(expected - ValueType(1)));
VTKM_TEST_ASSERT(ValueType(expected + ValueType(1)) >= ref);
}
template <typename ArrayPortalType>
void TryOperatorsNoVec(vtkm::Id,
vtkm::internal::ArrayPortalValueReference<ArrayPortalType>,
vtkm::TypeTraitsVectorTag)
{
}
template <typename ArrayPortalType>
void TryOperatorsInt(vtkm::Id index,
vtkm::internal::ArrayPortalValueReference<ArrayPortalType> ref,
vtkm::TypeTraitsScalarTag,
vtkm::TypeTraitsIntegerTag)
{
using ValueType = typename ArrayPortalType::ValueType;
const ValueType operand = TestValue(ARRAY_SIZE, ValueType());
ValueType expected = TestValue(index, ValueType());
VTKM_TEST_ASSERT(ref.Get() == expected, "Reference did not start out as expected.");
VTKM_TEST_ASSERT((ref % ref) == (expected % expected));
VTKM_TEST_ASSERT((ref % expected) == (expected % expected));
VTKM_TEST_ASSERT((expected % ref) == (expected % expected));
VTKM_TEST_ASSERT((ref ^ ref) == (expected ^ expected));
VTKM_TEST_ASSERT((ref ^ expected) == (expected ^ expected));
VTKM_TEST_ASSERT((expected ^ ref) == (expected ^ expected));
VTKM_TEST_ASSERT((ref | ref) == (expected | expected));
VTKM_TEST_ASSERT((ref | expected) == (expected | expected));
VTKM_TEST_ASSERT((expected | ref) == (expected | expected));
VTKM_TEST_ASSERT((ref & ref) == (expected & expected));
VTKM_TEST_ASSERT((ref & expected) == (expected & expected));
VTKM_TEST_ASSERT((expected & ref) == (expected & expected));
VTKM_TEST_ASSERT((ref << ref) == (expected << expected));
VTKM_TEST_ASSERT((ref << expected) == (expected << expected));
VTKM_TEST_ASSERT((expected << ref) == (expected << expected));
VTKM_TEST_ASSERT((ref << ref) == (expected << expected));
VTKM_TEST_ASSERT((ref << expected) == (expected << expected));
VTKM_TEST_ASSERT((expected << ref) == (expected << expected));
VTKM_TEST_ASSERT(~ref == ~expected);
VTKM_TEST_ASSERT(!(!ref));
VTKM_TEST_ASSERT(ref && ref);
VTKM_TEST_ASSERT(ref && expected);
VTKM_TEST_ASSERT(expected && ref);
VTKM_TEST_ASSERT(ref || ref);
VTKM_TEST_ASSERT(ref || expected);
VTKM_TEST_ASSERT(expected || ref);
ref &= ref;
expected &= expected;
VTKM_TEST_ASSERT(ref == expected);
ref &= operand;
expected &= operand;
VTKM_TEST_ASSERT(ref == expected);
ref |= ref;
expected |= expected;
VTKM_TEST_ASSERT(ref == expected);
ref |= operand;
expected |= operand;
VTKM_TEST_ASSERT(ref == expected);
ref >>= ref;
expected >>= expected;
VTKM_TEST_ASSERT(ref == expected);
ref >>= operand;
expected >>= operand;
VTKM_TEST_ASSERT(ref == expected);
ref <<= ref;
expected <<= expected;
VTKM_TEST_ASSERT(ref == expected);
ref <<= operand;
expected <<= operand;
VTKM_TEST_ASSERT(ref == expected);
ref ^= ref;
expected ^= expected;
VTKM_TEST_ASSERT(ref == expected);
ref ^= operand;
expected ^= operand;
VTKM_TEST_ASSERT(ref == expected);
}
template <typename ArrayPortalType, typename DimTag, typename NumericTag>
void TryOperatorsInt(vtkm::Id,
vtkm::internal::ArrayPortalValueReference<ArrayPortalType>,
DimTag,
NumericTag)
{
}
template <typename ArrayPortalType>
void TryOperators(vtkm::Id index, vtkm::internal::ArrayPortalValueReference<ArrayPortalType> ref)
{
using ValueType = typename ArrayPortalType::ValueType;
const ValueType operand = TestValue(ARRAY_SIZE, ValueType());
ValueType expected = TestValue(index, ValueType());
VTKM_TEST_ASSERT(ref.Get() == expected, "Reference did not start out as expected.");
// Test comparison operators.
VTKM_TEST_ASSERT(ref == ref);
VTKM_TEST_ASSERT(ref == expected);
VTKM_TEST_ASSERT(expected == ref);
VTKM_TEST_ASSERT(!(ref != ref));
VTKM_TEST_ASSERT(!(ref != expected));
VTKM_TEST_ASSERT(!(expected != ref));
TryOperatorsNoVec(index, ref, typename vtkm::TypeTraits<ValueType>::DimensionalityTag());
VTKM_TEST_ASSERT((ref + ref) == (expected + expected));
VTKM_TEST_ASSERT((ref + expected) == (expected + expected));
VTKM_TEST_ASSERT((expected + ref) == (expected + expected));
VTKM_TEST_ASSERT((ref - ref) == (expected - expected));
VTKM_TEST_ASSERT((ref - expected) == (expected - expected));
VTKM_TEST_ASSERT((expected - ref) == (expected - expected));
VTKM_TEST_ASSERT((ref * ref) == (expected * expected));
VTKM_TEST_ASSERT((ref * expected) == (expected * expected));
VTKM_TEST_ASSERT((expected * ref) == (expected * expected));
VTKM_TEST_ASSERT((ref / ref) == (expected / expected));
VTKM_TEST_ASSERT((ref / expected) == (expected / expected));
VTKM_TEST_ASSERT((expected / ref) == (expected / expected));
ref += ref;
expected += expected;
VTKM_TEST_ASSERT(ref == expected);
ref += operand;
expected += operand;
VTKM_TEST_ASSERT(ref == expected);
ref -= ref;
expected -= expected;
VTKM_TEST_ASSERT(ref == expected);
ref -= operand;
expected -= operand;
VTKM_TEST_ASSERT(ref == expected);
ref *= ref;
expected *= expected;
VTKM_TEST_ASSERT(ref == expected);
ref *= operand;
expected *= operand;
VTKM_TEST_ASSERT(ref == expected);
ref /= ref;
expected /= expected;
VTKM_TEST_ASSERT(ref == expected);
ref /= operand;
expected /= operand;
VTKM_TEST_ASSERT(ref == expected);
// Reset ref
ref = TestValue(index, ValueType());
TryOperatorsInt(index,
ref,
typename vtkm::TypeTraits<ValueType>::DimensionalityTag(),
typename vtkm::TypeTraits<ValueType>::NumericTag());
}
struct DoTestForType
{
@ -54,7 +255,7 @@ struct DoTestForType
std::cout << "Set array using reference" << std::endl;
using PortalType = typename vtkm::cont::ArrayHandle<ValueType>::PortalControl;
PortalType portal = array.GetPortalControl();
for (vtkm::Id index = 0; index < ARRAY_SIZE; index++)
for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index)
{
SetReference(index, vtkm::internal::ArrayPortalValueReference<PortalType>(portal, index));
}
@ -63,10 +264,17 @@ struct DoTestForType
CheckPortal(portal);
std::cout << "Check references in set array." << std::endl;
for (vtkm::Id index = 0; index < ARRAY_SIZE; index++)
for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index)
{
CheckReference(index, vtkm::internal::ArrayPortalValueReference<PortalType>(portal, index));
}
std::cout << "Check that operators work." << std::endl;
// Start at 1 to avoid issues with 0.
for (vtkm::Id index = 1; index < ARRAY_SIZE; ++index)
{
TryOperators(index, vtkm::internal::ArrayPortalValueReference<PortalType>(portal, index));
}
}
};