Add RuntimeDeviceTracker::CopyState

It is sometimes the case that you want to copy the state of one
`RuntimeDeviceTracker` to another. This is particularly the case when
creating threads in the control environment. Each thread has its own
copy of `RuntimeDeviceTracker`, so when you spawn a thread you probably
want to copy the state of the tracker from the calling thread.
This commit is contained in:
Kenneth Moreland 2021-07-27 12:02:13 -06:00
parent bf6d6ca517
commit 6241179631
3 changed files with 36 additions and 6 deletions

@ -164,6 +164,13 @@ void RuntimeDeviceTracker::ForceDevice(DeviceAdapterId deviceId)
}
}
VTKM_CONT void RuntimeDeviceTracker::CopyStateFrom(const vtkm::cont::RuntimeDeviceTracker& tracker)
{
std::copy(std::cbegin(tracker.Internals->RuntimeAllowed),
std::cend(tracker.Internals->RuntimeAllowed),
std::begin(this->Internals->RuntimeAllowed));
}
VTKM_CONT
void RuntimeDeviceTracker::PrintSummary(std::ostream& out) const
{

@ -85,14 +85,14 @@ public:
VTKM_CONT
void Reset();
/// \brief Disable the given device
/// \brief Disable the given device.
///
/// The main intention of \c RuntimeDeviceTracker is to keep track of what
/// devices are working for VTK-m. However, it can also be used to turn
/// devices on and off. Use this method to disable (turn off) a given device.
/// Use \c ResetDevice to turn the device back on (if it is supported).
///
/// Passing DeviceAdapterTagAny to this will disable all devices
/// Passing DeviceAdapterTagAny to this will disable all devices.
///
VTKM_CONT void DisableDevice(DeviceAdapterId deviceId);
@ -110,6 +110,13 @@ public:
///
VTKM_CONT void ForceDevice(DeviceAdapterId deviceId);
/// \brief Copyies the state from the given device.
///
/// This is a convenient way to allow the `RuntimeDeviceTracker` on one thread
/// copy the behavior from another thread.
///
VTKM_CONT void CopyStateFrom(const vtkm::cont::RuntimeDeviceTracker& tracker);
VTKM_CONT void PrintSummary(std::ostream& out) const;
private:

@ -20,12 +20,13 @@
#include <algorithm>
#include <array>
#include <thread>
namespace
{
template <typename DeviceAdapterTag>
void verify_state(DeviceAdapterTag tag, std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& defaults)
void verify_state(vtkm::cont::DeviceAdapterId tag,
std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& defaults)
{
auto& tracker = vtkm::cont::GetRuntimeDeviceTracker();
// presumable all other devices match the defaults
@ -40,8 +41,20 @@ void verify_state(DeviceAdapterTag tag, std::array<bool, VTKM_MAX_DEVICE_ADAPTER
}
}
template <typename DeviceAdapterTag>
void verify_srdt_support(DeviceAdapterTag tag,
void verify_state_thread(vtkm::cont::DeviceAdapterId tag,
std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& defaults,
const vtkm::cont::RuntimeDeviceTracker& tracker)
{
// Each thread has its own RuntimeDeviceTracker (to allow you to control different devices
// on different threads). But that means that each thread creates its own tracker. We
// want all the threads to respect the runtime set up on the main thread, so copy the state
// of that tracker (passed as an argument) to this thread.
vtkm::cont::GetRuntimeDeviceTracker().CopyStateFrom(tracker);
verify_state(tag, defaults);
}
void verify_srdt_support(vtkm::cont::DeviceAdapterId tag,
std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& force,
std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& enable,
std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& disable)
@ -54,6 +67,7 @@ void verify_srdt_support(DeviceAdapterTag tag,
vtkm::cont::RuntimeDeviceTrackerMode::Force);
VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == haveSupport, "");
verify_state(tag, force);
std::thread(verify_state_thread, tag, std::ref(force), std::ref(tracker)).join();
}
if (haveSupport)
@ -62,6 +76,7 @@ void verify_srdt_support(DeviceAdapterTag tag,
vtkm::cont::RuntimeDeviceTrackerMode::Enable);
VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == haveSupport);
verify_state(tag, enable);
std::thread(verify_state_thread, tag, std::ref(enable), std::ref(tracker)).join();
}
{
@ -69,6 +84,7 @@ void verify_srdt_support(DeviceAdapterTag tag,
vtkm::cont::RuntimeDeviceTrackerMode::Disable);
VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == false, "");
verify_state(tag, disable);
std::thread(verify_state_thread, tag, std::ref(disable), std::ref(tracker)).join();
}
}