Use specialized class instead of function overload for Variant::Get

Nvcc was having troubles resolving the return type of the overloaded
function to get a value out of a `VariantUnion`. Replace the
implementation with a class with specializations. This is more verbose,
but easier on the compiler.
This commit is contained in:
Kenneth Moreland 2021-03-23 13:30:55 -06:00
parent c9bcdd0195
commit cb60401a63
2 changed files with 144 additions and 52 deletions

@ -1784,8 +1784,12 @@ struct VariantUnionGetImpl;
template <typename UnionType>
struct VariantUnionGetImpl<0, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V0)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V0);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V0;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V0;
}
@ -1794,8 +1798,12 @@ struct VariantUnionGetImpl<0, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<1, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V1)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V1);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V1;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V1;
}
@ -1804,8 +1812,12 @@ struct VariantUnionGetImpl<1, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<2, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V2)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V2);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V2;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V2;
}
@ -1814,8 +1826,12 @@ struct VariantUnionGetImpl<2, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<3, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V3)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V3);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V3;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V3;
}
@ -1824,8 +1840,12 @@ struct VariantUnionGetImpl<3, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<4, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V4)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V4);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V4;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V4;
}
@ -1834,8 +1854,12 @@ struct VariantUnionGetImpl<4, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<5, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V5)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V5);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V5;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V5;
}
@ -1844,8 +1868,12 @@ struct VariantUnionGetImpl<5, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<6, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V6)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V6);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V6;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V6;
}
@ -1854,8 +1882,12 @@ struct VariantUnionGetImpl<6, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<7, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V7)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V7);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V7;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V7;
}
@ -1864,8 +1896,12 @@ struct VariantUnionGetImpl<7, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<8, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V8)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V8);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V8;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V8;
}
@ -1874,8 +1910,12 @@ struct VariantUnionGetImpl<8, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<9, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V9)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V9);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V9;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V9;
}
@ -1884,8 +1924,12 @@ struct VariantUnionGetImpl<9, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<10, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V10)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V10);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V10;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V10;
}
@ -1894,8 +1938,12 @@ struct VariantUnionGetImpl<10, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<11, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V11)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V11);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V11;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V11;
}
@ -1904,8 +1952,12 @@ struct VariantUnionGetImpl<11, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<12, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V12)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V12);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V12;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V12;
}
@ -1914,8 +1966,12 @@ struct VariantUnionGetImpl<12, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<13, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V13)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V13);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V13;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V13;
}
@ -1924,8 +1980,12 @@ struct VariantUnionGetImpl<13, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<14, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V14)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V14);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V14;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V14;
}
@ -1934,8 +1994,12 @@ struct VariantUnionGetImpl<14, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<15, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V15)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V15);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V15;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V15;
}
@ -1944,8 +2008,12 @@ struct VariantUnionGetImpl<15, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<16, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V16)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V16);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V16;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V16;
}
@ -1954,8 +2022,12 @@ struct VariantUnionGetImpl<16, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<17, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V17)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V17);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V17;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V17;
}
@ -1964,8 +2036,12 @@ struct VariantUnionGetImpl<17, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<18, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V18)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V18);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V18;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V18;
}
@ -1974,8 +2050,12 @@ struct VariantUnionGetImpl<18, UnionType>
template <typename UnionType>
struct VariantUnionGetImpl<19, UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V19)&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V19);
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V19;
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V19;
}
@ -1986,19 +2066,23 @@ template <vtkm::IdComponent I, typename UnionType>
struct VariantUnionGetImpl
{
VTKM_STATIC_ASSERT(I >= 20);
using RecursiveGet = VariantUnionGetImpl<I - 20, decltype(std::declval<UnionType>().Remaining)>;
using RecursiveGet = VariantUnionGetImpl<I - 20, decltype(std::declval<UnionType&>().Remaining)>;
using ReturnType = typename RecursiveGet::ReturnType;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return RecursiveGet::Get(storage.Remaining);
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return RecursiveGet::Get(storage.Remaining);
}
};
template <vtkm::IdComponent I, typename UnionType>
VTK_M_DEVICE typename VariantUnionGetImpl<I, UnionType>::ReturnType
VariantUnionGet(UnionType& storage) noexcept
VTK_M_DEVICE auto VariantUnionGet(UnionType& storage) noexcept
-> decltype(VariantUnionGetImpl<I, typename std::decay<UnionType>::type>::Get(storage))&
{
return VariantUnionGetImpl<I, UnionType>::Get(storage);
return VariantUnionGetImpl<I, typename std::decay<UnionType>::type>::Get(storage);
}
// --------------------------------------------------------------------------------

@ -256,8 +256,12 @@ $for(param_index in range(max_expanded))\
template <typename UnionType>
struct VariantUnionGetImpl<$(param_index), UnionType>
{
using ReturnType = decltype(std::declval<UnionType>().V$(param_index))&;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
using ReturnType = decltype(std::declval<UnionType>().V$(param_index));
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return storage.V$(param_index);
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return storage.V$(param_index);
}
@ -269,19 +273,23 @@ template <vtkm::IdComponent I, typename UnionType>
struct VariantUnionGetImpl
{
VTKM_STATIC_ASSERT(I >= $(max_expanded));
using RecursiveGet = VariantUnionGetImpl<I - $(max_expanded), decltype(std::declval<UnionType>().Remaining)>;
using RecursiveGet = VariantUnionGetImpl<I - $(max_expanded), decltype(std::declval<UnionType&>().Remaining)>;
using ReturnType = typename RecursiveGet::ReturnType;
VTK_M_DEVICE static ReturnType Get(UnionType& storage) noexcept
VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
{
return RecursiveGet::Get(storage.Remaining);
}
VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
{
return RecursiveGet::Get(storage.Remaining);
}
};
template <vtkm::IdComponent I, typename UnionType>
VTK_M_DEVICE typename VariantUnionGetImpl<I, UnionType>::ReturnType
VariantUnionGet(UnionType& storage) noexcept
VTK_M_DEVICE auto VariantUnionGet(UnionType& storage) noexcept
-> decltype(VariantUnionGetImpl<I, typename std::decay<UnionType>::type>::Get(storage))&
{
return VariantUnionGetImpl<I, UnionType>::Get(storage);
return VariantUnionGetImpl<I, typename std::decay<UnionType>::type>::Get(storage);
}
// --------------------------------------------------------------------------------