diff --git a/docs/changelog/copy-invalid-variant.md b/docs/changelog/copy-invalid-variant.md new file mode 100644 index 000000000..a72eec879 --- /dev/null +++ b/docs/changelog/copy-invalid-variant.md @@ -0,0 +1,5 @@ +# Fix bug with copying invalid variants + +There was a bug where if you attempted to copy a `Variant` that was not +valid (i.e. did not hold an object); a seg fault could happen. This has +been changed to set the target variant to also be invalid. diff --git a/vtkm/exec/internal/testing/UnitTestVariant.cxx b/vtkm/exec/internal/testing/UnitTestVariant.cxx index 0b97b6ea6..f6e5c9dd4 100644 --- a/vtkm/exec/internal/testing/UnitTestVariant.cxx +++ b/vtkm/exec/internal/testing/UnitTestVariant.cxx @@ -431,6 +431,23 @@ void TestCastAndCall() VTKM_TEST_ASSERT(test_equal(result, TestValue(26, vtkm::FloatDefault{}))); } +void TestCopyInvalid() +{ + std::cout << "Test copy invalid variant" << std::endl; + + using VariantType = vtkm::exec::internal::Variant, NonTrivial>; + + VariantType source; + source.Reset(); + + VariantType destination1(source); + VTKM_TEST_ASSERT(!destination1.IsValid()); + + VariantType destination2(TypePlaceholder<0>{}); + destination2 = source; + VTKM_TEST_ASSERT(!destination2.IsValid()); +} + struct CountConstructDestruct { vtkm::Id* Count; @@ -577,6 +594,7 @@ void RunTest() TestTriviallyCopyable(); TestGet(); TestCastAndCall(); + TestCopyInvalid(); TestCopyDestroy(); TestEmplace(); TestConstructDestruct(); diff --git a/vtkm/internal/VariantImpl.h b/vtkm/internal/VariantImpl.h index e3d08eda8..e11926865 100644 --- a/vtkm/internal/VariantImpl.h +++ b/vtkm/internal/VariantImpl.h @@ -256,21 +256,31 @@ struct VariantConstructorImpl, VTK_M_DEVICE VariantConstructorImpl(const VariantConstructorImpl& src) noexcept : VariantStorageImpl(vtkm::internal::NullType{}) { - src.CastAndCall(VariantCopyConstructFunctor{}, this->Storage); + if (src.IsValid()) + { + src.CastAndCall(VariantCopyConstructFunctor{}, this->Storage); + } this->Index = src.Index; } VTK_M_DEVICE VariantConstructorImpl& operator=(const VariantConstructorImpl& src) noexcept { - if (this->GetIndex() == src.GetIndex()) + if (src.IsValid()) { - src.CastAndCall(detail::VariantCopyFunctor{}, this->Storage); + if (this->GetIndex() == src.GetIndex()) + { + src.CastAndCall(detail::VariantCopyFunctor{}, this->Storage); + } + else + { + this->Reset(); + src.CastAndCall(detail::VariantCopyConstructFunctor{}, this->Storage); + this->Index = src.Index; + } } else { this->Reset(); - src.CastAndCall(detail::VariantCopyConstructFunctor{}, this->Storage); - this->Index = src.Index; } return *this; }