mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-09-16 17:22:55 +00:00
Save device choice on spawned control threads
The `RuntimeDeviceTracker` is a thread-local variable that monitors the devices to use separately on each thread. This is an important feature to allow different threads to control different devices. When a tracker is created on a new thread, it was simply reset, which makes sense. However, the reset does not take into account the device selected by `vtkm::cont::Initialize`. This means that if VTK-m was used in a different thread than it was initialized, it would ignore the `--vtkm-device` parameter. To get around this problem, keep track of the `RuntimeDeviceTracker` on the "main" thread. When a `RuntimeDeviceTracker` is created on a new thread, it copies the state from that tracker.
This commit is contained in:
parent
822eb8f165
commit
3feff36891
@ -276,32 +276,41 @@ ScopedRuntimeDeviceTracker::~ScopedRuntimeDeviceTracker()
|
||||
VTKM_CONT
|
||||
vtkm::cont::RuntimeDeviceTracker& GetRuntimeDeviceTracker()
|
||||
{
|
||||
#if defined(VTKM_CLANG) && defined(__apple_build_version__) && (__apple_build_version__ < 8000000)
|
||||
static std::mutex mtx;
|
||||
static std::map<std::thread::id, vtkm::cont::RuntimeDeviceTracker*> globalTrackers;
|
||||
static std::map<std::thread::id, vtkm::cont::detail::RuntimeDeviceTrackerInternals*>
|
||||
globalTrackerInternals;
|
||||
std::thread::id this_id = std::this_thread::get_id();
|
||||
using SharedTracker = std::shared_ptr<vtkm::cont::RuntimeDeviceTracker>;
|
||||
static thread_local vtkm::cont::detail::RuntimeDeviceTrackerInternals details;
|
||||
static thread_local SharedTracker runtimeDeviceTracker;
|
||||
static std::weak_ptr<vtkm::cont::RuntimeDeviceTracker> defaultRuntimeDeviceTracker;
|
||||
|
||||
std::unique_lock<std::mutex> lock(mtx);
|
||||
auto iter = globalTrackers.find(this_id);
|
||||
if (iter != globalTrackers.end())
|
||||
if (runtimeDeviceTracker)
|
||||
{
|
||||
return *iter->second;
|
||||
return *runtimeDeviceTracker;
|
||||
}
|
||||
|
||||
// The RuntimeDeviceTracker for this thread has not been created. Create a new one.
|
||||
runtimeDeviceTracker = SharedTracker(new vtkm::cont::RuntimeDeviceTracker(&details, true));
|
||||
|
||||
// Get the default details, which are a global variable, with thread safety
|
||||
static std::mutex mutex;
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
|
||||
SharedTracker defaultTracker = defaultRuntimeDeviceTracker.lock();
|
||||
|
||||
if (defaultTracker)
|
||||
{
|
||||
// We already have a default tracker, so copy the state from there. We don't need to
|
||||
// keep our mutex locked because we already have a safe handle to the defaultTracker.
|
||||
lock.unlock();
|
||||
runtimeDeviceTracker->CopyStateFrom(*defaultTracker);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto* details = new vtkm::cont::detail::RuntimeDeviceTrackerInternals();
|
||||
vtkm::cont::RuntimeDeviceTracker* tracker = new vtkm::cont::RuntimeDeviceTracker(details, true);
|
||||
globalTrackers[this_id] = tracker;
|
||||
globalTrackerInternals[this_id] = details;
|
||||
return *tracker;
|
||||
// There is no default tracker yet. It has never been created (or it was on a thread
|
||||
// that got deleted). Use the current thread's details as the default.
|
||||
defaultRuntimeDeviceTracker = runtimeDeviceTracker;
|
||||
}
|
||||
#else
|
||||
static thread_local vtkm::cont::detail::RuntimeDeviceTrackerInternals details;
|
||||
static thread_local vtkm::cont::RuntimeDeviceTracker runtimeDeviceTracker(&details, true);
|
||||
return runtimeDeviceTracker;
|
||||
#endif
|
||||
|
||||
return *runtimeDeviceTracker;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace vtkm::cont
|
||||
|
@ -73,6 +73,7 @@ set(unit_tests
|
||||
UnitTestDataSetUniform.cxx
|
||||
UnitTestDeviceAdapterAlgorithmDependency.cxx
|
||||
UnitTestDeviceAdapterAlgorithmGeneral.cxx
|
||||
UnitTestDeviceSelectOnThreads.cxx
|
||||
UnitTestDynamicCellSet.cxx
|
||||
UnitTestError.cxx
|
||||
UnitTestFieldRangeCompute.cxx
|
||||
|
83
vtkm/cont/testing/UnitTestDeviceSelectOnThreads.cxx
Normal file
83
vtkm/cont/testing/UnitTestDeviceSelectOnThreads.cxx
Normal file
@ -0,0 +1,83 @@
|
||||
//============================================================================
|
||||
// Copyright (c) Kitware, Inc.
|
||||
// All rights reserved.
|
||||
// See LICENSE.txt for details.
|
||||
//
|
||||
// This software is distributed WITHOUT ANY WARRANTY; without even
|
||||
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
|
||||
// PURPOSE. See the above copyright notice for more information.
|
||||
//============================================================================
|
||||
|
||||
#include <vtkm/cont/Initialize.h>
|
||||
#include <vtkm/cont/RuntimeDeviceTracker.h>
|
||||
|
||||
#include <vtkm/cont/testing/Testing.h>
|
||||
|
||||
#include <future>
|
||||
#include <vector>
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
bool CheckLocalRuntime()
|
||||
{
|
||||
if (!vtkm::cont::GetRuntimeDeviceTracker().CanRunOn(vtkm::cont::DeviceAdapterTagSerial{}))
|
||||
{
|
||||
std::cout << "Serial device not runable" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
for (vtkm::Int8 deviceIndex = 0; deviceIndex < VTKM_MAX_DEVICE_ADAPTER_ID; ++deviceIndex)
|
||||
{
|
||||
vtkm::cont::DeviceAdapterId device = vtkm::cont::make_DeviceAdapterId(deviceIndex);
|
||||
if (!device.IsValueValid() || (deviceIndex == VTKM_DEVICE_ADAPTER_SERIAL))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if (vtkm::cont::GetRuntimeDeviceTracker().CanRunOn(device))
|
||||
{
|
||||
std::cout << "Device " << device.GetName() << " declared as runnable" << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void DoTest()
|
||||
{
|
||||
VTKM_TEST_ASSERT(CheckLocalRuntime(),
|
||||
"Runtime check failed on main thread. Did you try to set a device argument?");
|
||||
|
||||
// Now check on a new thread. The runtime is a thread-local object so that each thread can
|
||||
// use its own device. But when you start a thread, you want the default to be what the
|
||||
// user selected on the main thread.
|
||||
VTKM_TEST_ASSERT(std::async(std::launch::async, CheckLocalRuntime).get(),
|
||||
"Runtime loses defaults in spawned thread.");
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
int UnitTestDeviceSelectOnThreads(int argc, char* argv[])
|
||||
{
|
||||
// This test is checking to make sure that a device selected in the command line
|
||||
// argument is the default for all threads. We will test this by adding an argument
|
||||
// to select the serial device, which is always available. The test might fail if
|
||||
// a different device is also selected.
|
||||
std::string deviceSelectString("--vtkm-device=serial");
|
||||
std::vector<char> deviceSelectArg(deviceSelectString.size());
|
||||
std::copy(deviceSelectString.begin(), deviceSelectString.end(), deviceSelectArg.begin());
|
||||
deviceSelectArg.push_back('\0');
|
||||
|
||||
std::vector<char*> newArgs;
|
||||
for (int i = 0; i < argc; ++i)
|
||||
{
|
||||
newArgs.push_back(argv[i]);
|
||||
}
|
||||
newArgs.push_back(deviceSelectArg.data());
|
||||
newArgs.push_back(nullptr);
|
||||
|
||||
int newArgc = argc + 1;
|
||||
|
||||
return vtkm::cont::testing::Testing::Run(DoTest, newArgc, newArgs.data());
|
||||
}
|
Loading…
Reference in New Issue
Block a user