diff --git a/vtkm/cont/RuntimeDeviceTracker.cxx b/vtkm/cont/RuntimeDeviceTracker.cxx index d1c97a0cf..b1c99d4dc 100644 --- a/vtkm/cont/RuntimeDeviceTracker.cxx +++ b/vtkm/cont/RuntimeDeviceTracker.cxx @@ -164,6 +164,13 @@ void RuntimeDeviceTracker::ForceDevice(DeviceAdapterId deviceId) } } +VTKM_CONT void RuntimeDeviceTracker::CopyStateFrom(const vtkm::cont::RuntimeDeviceTracker& tracker) +{ + std::copy(std::cbegin(tracker.Internals->RuntimeAllowed), + std::cend(tracker.Internals->RuntimeAllowed), + std::begin(this->Internals->RuntimeAllowed)); +} + VTKM_CONT void RuntimeDeviceTracker::PrintSummary(std::ostream& out) const { diff --git a/vtkm/cont/RuntimeDeviceTracker.h b/vtkm/cont/RuntimeDeviceTracker.h index 9d594e2ec..d964c2bb3 100644 --- a/vtkm/cont/RuntimeDeviceTracker.h +++ b/vtkm/cont/RuntimeDeviceTracker.h @@ -85,14 +85,14 @@ public: VTKM_CONT void Reset(); - /// \brief Disable the given device + /// \brief Disable the given device. /// /// The main intention of \c RuntimeDeviceTracker is to keep track of what /// devices are working for VTK-m. However, it can also be used to turn /// devices on and off. Use this method to disable (turn off) a given device. /// Use \c ResetDevice to turn the device back on (if it is supported). /// - /// Passing DeviceAdapterTagAny to this will disable all devices + /// Passing DeviceAdapterTagAny to this will disable all devices. /// VTKM_CONT void DisableDevice(DeviceAdapterId deviceId); @@ -110,6 +110,13 @@ public: /// VTKM_CONT void ForceDevice(DeviceAdapterId deviceId); + /// \brief Copyies the state from the given device. + /// + /// This is a convenient way to allow the `RuntimeDeviceTracker` on one thread + /// copy the behavior from another thread. + /// + VTKM_CONT void CopyStateFrom(const vtkm::cont::RuntimeDeviceTracker& tracker); + VTKM_CONT void PrintSummary(std::ostream& out) const; private: diff --git a/vtkm/cont/testing/UnitTestScopedRuntimeDeviceTracker.cxx b/vtkm/cont/testing/UnitTestScopedRuntimeDeviceTracker.cxx index 05d201597..2f9e00c13 100644 --- a/vtkm/cont/testing/UnitTestScopedRuntimeDeviceTracker.cxx +++ b/vtkm/cont/testing/UnitTestScopedRuntimeDeviceTracker.cxx @@ -20,12 +20,13 @@ #include #include +#include namespace { -template -void verify_state(DeviceAdapterTag tag, std::array& defaults) +void verify_state(vtkm::cont::DeviceAdapterId tag, + std::array& defaults) { auto& tracker = vtkm::cont::GetRuntimeDeviceTracker(); // presumable all other devices match the defaults @@ -40,8 +41,20 @@ void verify_state(DeviceAdapterTag tag, std::array -void verify_srdt_support(DeviceAdapterTag tag, +void verify_state_thread(vtkm::cont::DeviceAdapterId tag, + std::array& defaults, + const vtkm::cont::RuntimeDeviceTracker& tracker) +{ + // Each thread has its own RuntimeDeviceTracker (to allow you to control different devices + // on different threads). But that means that each thread creates its own tracker. We + // want all the threads to respect the runtime set up on the main thread, so copy the state + // of that tracker (passed as an argument) to this thread. + vtkm::cont::GetRuntimeDeviceTracker().CopyStateFrom(tracker); + + verify_state(tag, defaults); +} + +void verify_srdt_support(vtkm::cont::DeviceAdapterId tag, std::array& force, std::array& enable, std::array& disable) @@ -54,6 +67,7 @@ void verify_srdt_support(DeviceAdapterTag tag, vtkm::cont::RuntimeDeviceTrackerMode::Force); VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == haveSupport, ""); verify_state(tag, force); + std::thread(verify_state_thread, tag, std::ref(force), std::ref(tracker)).join(); } if (haveSupport) @@ -62,6 +76,7 @@ void verify_srdt_support(DeviceAdapterTag tag, vtkm::cont::RuntimeDeviceTrackerMode::Enable); VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == haveSupport); verify_state(tag, enable); + std::thread(verify_state_thread, tag, std::ref(enable), std::ref(tracker)).join(); } { @@ -69,6 +84,7 @@ void verify_srdt_support(DeviceAdapterTag tag, vtkm::cont::RuntimeDeviceTrackerMode::Disable); VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == false, ""); verify_state(tag, disable); + std::thread(verify_state_thread, tag, std::ref(disable), std::ref(tracker)).join(); } }