Add Variant::IsType

The `Variant` class was missing a way to check the type. You could do it
indirectly using `variant.GetIndex() == variant.GetIndexOf<T>()`, but
having this convenience function is more clear.
This commit is contained in:
Kenneth Moreland 2022-08-05 12:35:57 -06:00
parent 3e794d16bc
commit 2271d0ef45
2 changed files with 13 additions and 3 deletions

@ -353,6 +353,8 @@ void TestGet()
VariantType variant = expectedValue; VariantType variant = expectedValue;
VTKM_TEST_ASSERT(variant.GetIndex() == 2); VTKM_TEST_ASSERT(variant.GetIndex() == 2);
VTKM_TEST_ASSERT(variant.IsType<vtkm::Id>());
VTKM_TEST_ASSERT(!variant.IsType<vtkm::Float32>());
VTKM_TEST_ASSERT(variant.Get<2>() == expectedValue); VTKM_TEST_ASSERT(variant.Get<2>() == expectedValue);

@ -352,6 +352,14 @@ public:
return (this->Index >= 0) && (this->Index < NumberOfTypes); return (this->Index >= 0) && (this->Index < NumberOfTypes);
} }
/// Returns true if this `Variant` stores the given type
///
template <typename T>
VTK_M_DEVICE bool IsType() const
{
return (this->GetIndex() == this->GetIndexOf<T>());
}
Variant() = default; Variant() = default;
~Variant() = default; ~Variant() = default;
Variant(const Variant&) = default; Variant(const Variant&) = default;
@ -373,7 +381,7 @@ public:
template <typename T> template <typename T>
VTK_M_DEVICE Variant& operator=(const T& src) VTK_M_DEVICE Variant& operator=(const T& src)
{ {
if (this->GetIndex() == this->GetIndexOf<T>()) if (this->IsType<T>())
{ {
this->Get<T>() = src; this->Get<T>() = src;
} }
@ -474,14 +482,14 @@ private:
template <typename T> template <typename T>
VTK_M_DEVICE T& GetImpl(std::true_type) VTK_M_DEVICE T& GetImpl(std::true_type)
{ {
VTKM_ASSERT(this->GetIndexOf<T>() == this->GetIndex()); VTKM_ASSERT(this->IsType<T>());
return detail::VariantUnionGet<IndexOf<T>::value>(this->Storage); return detail::VariantUnionGet<IndexOf<T>::value>(this->Storage);
} }
template <typename T> template <typename T>
VTK_M_DEVICE const T& GetImpl(std::true_type) const VTK_M_DEVICE const T& GetImpl(std::true_type) const
{ {
VTKM_ASSERT(this->GetIndexOf<T>() == this->GetIndex()); VTKM_ASSERT(this->IsType<T>());
return detail::VariantUnionGet<IndexOf<T>::value>(this->Storage); return detail::VariantUnionGet<IndexOf<T>::value>(this->Storage);
} }