Save device choice on spawned control threads

The `RuntimeDeviceTracker` is a thread-local variable that monitors the
devices to use separately on each thread. This is an important feature
to allow different threads to control different devices.

When a tracker is created on a new thread, it was simply reset, which
makes sense. However, the reset does not take into account the device
selected by `vtkm::cont::Initialize`. This means that if VTK-m was used
in a different thread than it was initialized, it would ignore the
`--vtkm-device` parameter.

To get around this problem, keep track of the `RuntimeDeviceTracker` on
the "main" thread. When a `RuntimeDeviceTracker` is created on a new
thread, it copies the state from that tracker.
This commit is contained in:
Kenneth Moreland 2021-07-29 16:45:43 -06:00
parent 822eb8f165
commit 3feff36891
3 changed files with 113 additions and 20 deletions

@ -276,32 +276,41 @@ ScopedRuntimeDeviceTracker::~ScopedRuntimeDeviceTracker()
VTKM_CONT
vtkm::cont::RuntimeDeviceTracker& GetRuntimeDeviceTracker()
{
#if defined(VTKM_CLANG) && defined(__apple_build_version__) && (__apple_build_version__ < 8000000)
static std::mutex mtx;
static std::map<std::thread::id, vtkm::cont::RuntimeDeviceTracker*> globalTrackers;
static std::map<std::thread::id, vtkm::cont::detail::RuntimeDeviceTrackerInternals*>
globalTrackerInternals;
std::thread::id this_id = std::this_thread::get_id();
using SharedTracker = std::shared_ptr<vtkm::cont::RuntimeDeviceTracker>;
static thread_local vtkm::cont::detail::RuntimeDeviceTrackerInternals details;
static thread_local SharedTracker runtimeDeviceTracker;
static std::weak_ptr<vtkm::cont::RuntimeDeviceTracker> defaultRuntimeDeviceTracker;
std::unique_lock<std::mutex> lock(mtx);
auto iter = globalTrackers.find(this_id);
if (iter != globalTrackers.end())
if (runtimeDeviceTracker)
{
return *iter->second;
return *runtimeDeviceTracker;
}
// The RuntimeDeviceTracker for this thread has not been created. Create a new one.
runtimeDeviceTracker = SharedTracker(new vtkm::cont::RuntimeDeviceTracker(&details, true));
// Get the default details, which are a global variable, with thread safety
static std::mutex mutex;
std::unique_lock<std::mutex> lock(mutex);
SharedTracker defaultTracker = defaultRuntimeDeviceTracker.lock();
if (defaultTracker)
{
// We already have a default tracker, so copy the state from there. We don't need to
// keep our mutex locked because we already have a safe handle to the defaultTracker.
lock.unlock();
runtimeDeviceTracker->CopyStateFrom(*defaultTracker);
}
else
{
auto* details = new vtkm::cont::detail::RuntimeDeviceTrackerInternals();
vtkm::cont::RuntimeDeviceTracker* tracker = new vtkm::cont::RuntimeDeviceTracker(details, true);
globalTrackers[this_id] = tracker;
globalTrackerInternals[this_id] = details;
return *tracker;
// There is no default tracker yet. It has never been created (or it was on a thread
// that got deleted). Use the current thread's details as the default.
defaultRuntimeDeviceTracker = runtimeDeviceTracker;
}
#else
static thread_local vtkm::cont::detail::RuntimeDeviceTrackerInternals details;
static thread_local vtkm::cont::RuntimeDeviceTracker runtimeDeviceTracker(&details, true);
return runtimeDeviceTracker;
#endif
return *runtimeDeviceTracker;
}
}
} // namespace vtkm::cont

@ -73,6 +73,7 @@ set(unit_tests
UnitTestDataSetUniform.cxx
UnitTestDeviceAdapterAlgorithmDependency.cxx
UnitTestDeviceAdapterAlgorithmGeneral.cxx
UnitTestDeviceSelectOnThreads.cxx
UnitTestDynamicCellSet.cxx
UnitTestError.cxx
UnitTestFieldRangeCompute.cxx

@ -0,0 +1,83 @@
//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//============================================================================
#include <vtkm/cont/Initialize.h>
#include <vtkm/cont/RuntimeDeviceTracker.h>
#include <vtkm/cont/testing/Testing.h>
#include <future>
#include <vector>
namespace
{
bool CheckLocalRuntime()
{
if (!vtkm::cont::GetRuntimeDeviceTracker().CanRunOn(vtkm::cont::DeviceAdapterTagSerial{}))
{
std::cout << "Serial device not runable" << std::endl;
return false;
}
for (vtkm::Int8 deviceIndex = 0; deviceIndex < VTKM_MAX_DEVICE_ADAPTER_ID; ++deviceIndex)
{
vtkm::cont::DeviceAdapterId device = vtkm::cont::make_DeviceAdapterId(deviceIndex);
if (!device.IsValueValid() || (deviceIndex == VTKM_DEVICE_ADAPTER_SERIAL))
{
continue;
}
if (vtkm::cont::GetRuntimeDeviceTracker().CanRunOn(device))
{
std::cout << "Device " << device.GetName() << " declared as runnable" << std::endl;
return false;
}
}
return true;
}
void DoTest()
{
VTKM_TEST_ASSERT(CheckLocalRuntime(),
"Runtime check failed on main thread. Did you try to set a device argument?");
// Now check on a new thread. The runtime is a thread-local object so that each thread can
// use its own device. But when you start a thread, you want the default to be what the
// user selected on the main thread.
VTKM_TEST_ASSERT(std::async(std::launch::async, CheckLocalRuntime).get(),
"Runtime loses defaults in spawned thread.");
}
} // anonymous namespace
int UnitTestDeviceSelectOnThreads(int argc, char* argv[])
{
// This test is checking to make sure that a device selected in the command line
// argument is the default for all threads. We will test this by adding an argument
// to select the serial device, which is always available. The test might fail if
// a different device is also selected.
std::string deviceSelectString("--vtkm-device=serial");
std::vector<char> deviceSelectArg(deviceSelectString.size());
std::copy(deviceSelectString.begin(), deviceSelectString.end(), deviceSelectArg.begin());
deviceSelectArg.push_back('\0');
std::vector<char*> newArgs;
for (int i = 0; i < argc; ++i)
{
newArgs.push_back(argv[i]);
}
newArgs.push_back(deviceSelectArg.data());
newArgs.push_back(nullptr);
int newArgc = argc + 1;
return vtkm::cont::testing::Testing::Run(DoTest, newArgc, newArgs.data());
}