ScopedRuntimeDeviceTracker have better controls of setting devices.

The ScopedRuntimeDeviceTracker now can force, enable, or disable
devices. Additionally the ScopedRuntimeDeviceTracker and the
RuntimeDeviceTracker handle the DeviceAdapterTagAny robustly
across all methods.
This commit is contained in:
Robert Maynard 2019-05-21 15:32:10 -04:00
parent 4212d0c04f
commit bcaf7d9beb
6 changed files with 279 additions and 49 deletions

@ -63,3 +63,28 @@ goes out of scope.
}
//openmp/tbb/... are now again active
```
The `vtkm::cont::ScopedRuntimeDeviceTracker` is not limited to forcing
execution to occur on a single device. When constructed it can either force
execution to a device, disable a device or enable a device. These options
also work with the `DeviceAdapterTagAny`.
```cpp
{
//enable all devices
vtkm::cont::DeviceAdapterTagAny any;
vtkm::cont::ScopedRuntimeDeviceTracker tracker(any,
vtkm::cont::RuntimeDeviceTrackerMode::Enable);
...
}
{
//disable only cuda
vtkm::cont::DeviceAdapterTagCuda cuda;
vtkm::cont::ScopedRuntimeDeviceTracker tracker(cuda,
vtkm::cont::RuntimeDeviceTrackerMode::Disable);
...
}
```

@ -61,10 +61,24 @@ void RuntimeDeviceTracker::CheckDevice(vtkm::cont::DeviceAdapterId deviceId) con
}
VTKM_CONT
bool RuntimeDeviceTracker::CanRunOnImpl(vtkm::cont::DeviceAdapterId deviceId) const
bool RuntimeDeviceTracker::CanRunOn(vtkm::cont::DeviceAdapterId deviceId) const
{
this->CheckDevice(deviceId);
return this->Internals->RuntimeAllowed[deviceId.GetValue()];
if (deviceId == vtkm::cont::DeviceAdapterTagAny{})
{ //If at least a single device is enabled, than any device is enabled
for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
{
if (this->Internals->RuntimeAllowed[static_cast<std::size_t>(i)])
{
return true;
}
}
return false;
}
else
{
this->CheckDevice(deviceId);
return this->Internals->RuntimeAllowed[deviceId.GetValue()];
}
}
VTKM_CONT
@ -77,6 +91,21 @@ void RuntimeDeviceTracker::SetDeviceState(vtkm::cont::DeviceAdapterId deviceId,
this->Internals->RuntimeAllowed[deviceId.GetValue()] = state;
}
VTKM_CONT void RuntimeDeviceTracker::ResetDevice(vtkm::cont::DeviceAdapterId deviceId)
{
if (deviceId == vtkm::cont::DeviceAdapterTagAny{})
{
this->Reset();
}
else
{
vtkm::cont::RuntimeDeviceInformation runtimeDevice;
this->SetDeviceState(deviceId, runtimeDevice.Exists(deviceId));
}
}
VTKM_CONT
void RuntimeDeviceTracker::Reset()
{
@ -86,7 +115,7 @@ void RuntimeDeviceTracker::Reset()
// when we use logging we get better messages stating we are reseting
// the devices.
vtkm::cont::RuntimeDeviceInformation runtimeDevice;
for (vtkm::Int8 i = 0; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
{
vtkm::cont::DeviceAdapterId device = vtkm::cont::make_DeviceAdapterId(i);
if (device.IsValueValid())
@ -99,24 +128,16 @@ void RuntimeDeviceTracker::Reset()
}
}
VTKM_CONT
void RuntimeDeviceTracker::ForceDeviceImpl(vtkm::cont::DeviceAdapterId deviceId, bool runtimeExists)
VTKM_CONT void RuntimeDeviceTracker::DisableDevice(vtkm::cont::DeviceAdapterId deviceId)
{
if (!runtimeExists)
if (deviceId == vtkm::cont::DeviceAdapterTagAny{})
{
std::stringstream message;
message << "Cannot force to device '" << deviceId.GetName()
<< "' because that device is not available on this system";
throw vtkm::cont::ErrorBadValue(message.str());
std::fill_n(this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
}
else
{
this->SetDeviceState(deviceId, false);
}
this->CheckDevice(deviceId);
VTKM_LOG_S(vtkm::cont::LogLevel::Info,
"Forcing execution to occur on device '" << deviceId.GetName() << "'");
std::fill_n(this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
this->Internals->RuntimeAllowed[deviceId.GetValue()] = runtimeExists;
}
VTKM_CONT
@ -128,31 +149,71 @@ void RuntimeDeviceTracker::ForceDevice(DeviceAdapterId deviceId)
}
else
{
this->CheckDevice(deviceId);
vtkm::cont::RuntimeDeviceInformation runtimeDevice;
this->ForceDeviceImpl(deviceId, runtimeDevice.Exists(deviceId));
const bool runtimeExists = runtimeDevice.Exists(deviceId);
if (!runtimeExists)
{
std::stringstream message;
message << "Cannot force to device '" << deviceId.GetName()
<< "' because that device is not available on this system";
throw vtkm::cont::ErrorBadValue(message.str());
}
VTKM_LOG_S(vtkm::cont::LogLevel::Info,
"Forcing execution to occur on device '" << deviceId.GetName() << "'");
std::fill_n(this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
this->Internals->RuntimeAllowed[deviceId.GetValue()] = runtimeExists;
}
}
VTKM_CONT
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device)
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device,
RuntimeDeviceTrackerMode mode)
: RuntimeDeviceTracker(GetRuntimeDeviceTracker().Internals, false)
, SavedState(new detail::RuntimeDeviceTrackerInternals())
{
std::copy_n(
this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, this->SavedState->RuntimeAllowed);
this->ForceDevice(device);
if (mode == RuntimeDeviceTrackerMode::Force)
{
this->ForceDevice(device);
}
else if (mode == RuntimeDeviceTrackerMode::Enable)
{
this->ResetDevice(device);
}
else if (mode == RuntimeDeviceTrackerMode::Disable)
{
this->DisableDevice(device);
}
}
VTKM_CONT
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
vtkm::cont::DeviceAdapterId device,
RuntimeDeviceTrackerMode mode,
const vtkm::cont::RuntimeDeviceTracker& tracker)
: RuntimeDeviceTracker(tracker.Internals, false)
, SavedState(new detail::RuntimeDeviceTrackerInternals())
{
std::copy_n(
this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, this->SavedState->RuntimeAllowed);
this->ForceDevice(device);
if (mode == RuntimeDeviceTrackerMode::Force)
{
this->ForceDevice(device);
}
else if (mode == RuntimeDeviceTrackerMode::Enable)
{
this->ResetDevice(device);
}
else if (mode == RuntimeDeviceTrackerMode::Disable)
{
this->DisableDevice(device);
}
}
VTKM_CONT

@ -54,11 +54,10 @@ public:
/// Returns true if the given device adapter is supported on the current
/// machine.
///
VTKM_CONT bool CanRunOn(DeviceAdapterId device) const { return this->CanRunOnImpl(device); }
VTKM_CONT bool CanRunOn(DeviceAdapterId deviceId) const;
/// Report a failure to allocate memory on a device, this will flag the
/// device as being unusable for all future invocations of the instance of
/// the filter.
/// device as being unusable for all future invocations.
///
VTKM_CONT void ReportAllocationFailure(vtkm::cont::DeviceAdapterId deviceId,
const vtkm::cont::ErrorBadAllocation&)
@ -75,13 +74,10 @@ public:
}
/// Reset the tracker for the given device. This will discard any updates
/// caused by reported failures
/// caused by reported failures. Passing DeviceAdapterTagAny to this will
/// reset all devices ( same as \c Reset ).
///
VTKM_CONT void ResetDevice(vtkm::cont::DeviceAdapterId device)
{
vtkm::cont::RuntimeDeviceInformation runtimeDevice;
this->SetDeviceState(device, runtimeDevice.Exists(device));
}
VTKM_CONT void ResetDevice(vtkm::cont::DeviceAdapterId deviceId);
/// Reset the tracker to its default state for default devices.
/// Will discard any updates caused by reported failures.
@ -96,7 +92,9 @@ public:
/// devices on and off. Use this method to disable (turn off) a given device.
/// Use \c ResetDevice to turn the device back on (if it is supported).
///
VTKM_CONT void DisableDevice(DeviceAdapterId device) { this->SetDeviceState(device, false); }
/// Passing DeviceAdapterTagAny to this will disable all devices
///
VTKM_CONT void DisableDevice(DeviceAdapterId deviceId);
/// \brief Disable all devices except the specified one.
///
@ -110,7 +108,7 @@ public:
/// This method will throw a \c ErrorBadValue if the given device does not
/// exist on the system.
///
VTKM_CONT void ForceDevice(DeviceAdapterId id);
VTKM_CONT void ForceDevice(DeviceAdapterId deviceId);
private:
friend struct ScopedRuntimeDeviceTracker;
@ -129,14 +127,16 @@ private:
VTKM_CONT
void CheckDevice(vtkm::cont::DeviceAdapterId deviceId) const;
VTKM_CONT
bool CanRunOnImpl(vtkm::cont::DeviceAdapterId deviceId) const;
VTKM_CONT
void SetDeviceState(vtkm::cont::DeviceAdapterId deviceId, bool state);
};
VTKM_CONT
void ForceDeviceImpl(vtkm::cont::DeviceAdapterId deviceId, bool runtimeExists);
enum struct RuntimeDeviceTrackerMode
{
Force,
Enable,
Disable
};
/// A class that can be used to determine or modify which device adapter
@ -147,19 +147,39 @@ private:
///
struct VTKM_CONT_EXPORT ScopedRuntimeDeviceTracker : public vtkm::cont::RuntimeDeviceTracker
{
/// Construct a ScopedRuntimeDeviceTracker where the only active device
/// for the current thread is the one provided by the constructor. Passing
/// DeviceAdapterTagAny to this function will reset all devices to their
/// default state.
/// Construct a ScopedRuntimeDeviceTracker where the state of the active devices
/// for the current thread are determined by the parameters to the constructor.
///
/// 'Force'
/// - Force-Enable the provided single device adapter
/// - Force-Enable all device adapters when using vtkm::cont::DeviceAdaterTagAny
/// 'Enable'
/// - Enable the provided single device adapter if it was previously disabled
/// - Enable all device adapters that are currently disabled when using
/// vtkm::cont::DeviceAdaterTagAny
/// 'Disable'
/// - Disable the provided single device adapter
/// - Disable all device adapters when using vtkm::cont::DeviceAdaterTagAny
///
/// Constructor is not thread safe
VTKM_CONT ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device);
VTKM_CONT ScopedRuntimeDeviceTracker(
vtkm::cont::DeviceAdapterId device,
RuntimeDeviceTrackerMode mode = RuntimeDeviceTrackerMode::Force);
/// Construct a ScopedRuntimeDeviceTracker associated with the thread
/// associated with the provided tracker. The only active device
/// for this thread is the one provided by the constructor. Passing
/// DeviceAdapterTagAny to this function will reset all devices to their
/// default state.
/// associated with the provided tracker. The active devices
/// for the current thread are determined by the parameters to the constructor.
///
/// 'Force'
/// - Force-Enable the provided single device adapter
/// - Force-Enable all device adapters when using vtkm::cont::DeviceAdaterTagAny
/// 'Enable'
/// - Enable the provided single device adapter if it was previously disabled
/// - Enable all device adapters that are currently disabled when using
/// vtkm::cont::DeviceAdaterTagAny
/// 'Disable'
/// - Disable the provided single device adapter
/// - Disable all device adapters when using vtkm::cont::DeviceAdaterTagAny
///
/// Any modifications to the ScopedRuntimeDeviceTracker will effect what
/// ever thread the \c tracker is associated with, which might not be
@ -167,6 +187,7 @@ struct VTKM_CONT_EXPORT ScopedRuntimeDeviceTracker : public vtkm::cont::RuntimeD
///
/// Constructor is not thread safe
VTKM_CONT ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device,
RuntimeDeviceTrackerMode mode,
const vtkm::cont::RuntimeDeviceTracker& tracker);
/// Construct a ScopedRuntimeDeviceTracker associated with the thread

@ -71,6 +71,7 @@ set(unit_tests
UnitTestMultiBlock.cxx
UnitTestRuntimeDeviceInformation.cxx
UnitTestRuntimeDeviceNames.cxx
UnitTestScopedRuntimeDeviceTracker.cxx
UnitTestStorageBasic.cxx
UnitTestStorageImplicit.cxx
UnitTestStorageListTag.cxx

@ -161,7 +161,8 @@ struct TransformExecObject : public vtkm::cont::ExecutionAndControlObjectBase
{
// Need to make sure the serial device is supported, since that is what is used on the
// control side. Therefore we reset to all supported devices.
vtkm::cont::ScopedRuntimeDeviceTracker scopedTracker(vtkm::cont::DeviceAdapterTagAny{});
vtkm::cont::ScopedRuntimeDeviceTracker scopedTracker(
vtkm::cont::DeviceAdapterTagSerial{}, vtkm::cont::RuntimeDeviceTrackerMode::Enable);
this->VirtualFunctor.Reset(new VirtualTransformFunctor<ValueType, FunctorType>(functor));
}

@ -0,0 +1,121 @@
//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//============================================================================
#include <vtkm/cont/RuntimeDeviceTracker.h>
//include all backends
#include <vtkm/cont/cuda/DeviceAdapterCuda.h>
#include <vtkm/cont/openmp/DeviceAdapterOpenMP.h>
#include <vtkm/cont/serial/DeviceAdapterSerial.h>
#include <vtkm/cont/tbb/DeviceAdapterTBB.h>
#include <vtkm/cont/testing/Testing.h>
#include <algorithm>
#include <array>
namespace
{
template <typename DeviceAdapterTag>
void verify_state(DeviceAdapterTag tag, std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& defaults)
{
auto& tracker = vtkm::cont::GetRuntimeDeviceTracker();
// presumable all other devices match the defaults
for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
{
const auto deviceId = vtkm::cont::make_DeviceAdapterId(i);
if (deviceId != tag)
{
VTKM_TEST_ASSERT(defaults[static_cast<std::size_t>(i)] == tracker.CanRunOn(deviceId),
"ScopedRuntimeDeviceTracker didn't properly setup state correctly");
}
}
}
template <typename DeviceAdapterTag>
void verify_srdt_support(DeviceAdapterTag tag,
std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& force,
std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& enable,
std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID>& disable)
{
vtkm::cont::RuntimeDeviceInformation runtime;
const bool haveSupport = runtime.Exists(tag);
if (haveSupport)
{
vtkm::cont::ScopedRuntimeDeviceTracker tracker(tag,
vtkm::cont::RuntimeDeviceTrackerMode::Force);
VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == haveSupport, "");
verify_state(tag, force);
}
if (haveSupport)
{
vtkm::cont::ScopedRuntimeDeviceTracker tracker(tag,
vtkm::cont::RuntimeDeviceTrackerMode::Enable);
VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == haveSupport);
verify_state(tag, enable);
}
{
vtkm::cont::ScopedRuntimeDeviceTracker tracker(tag,
vtkm::cont::RuntimeDeviceTrackerMode::Disable);
VTKM_TEST_ASSERT(tracker.CanRunOn(tag) == false, "");
verify_state(tag, disable);
}
}
void VerifyScopedRuntimeDeviceTracker()
{
std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID> all_off;
std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID> all_on;
std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID> defaults;
all_off.fill(false);
vtkm::cont::RuntimeDeviceInformation runtime;
auto& tracker = vtkm::cont::GetRuntimeDeviceTracker();
for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
{
auto deviceId = vtkm::cont::make_DeviceAdapterId(i);
defaults[static_cast<std::size_t>(i)] = tracker.CanRunOn(deviceId);
all_on[static_cast<std::size_t>(i)] = runtime.Exists(deviceId);
}
using SerialTag = ::vtkm::cont::DeviceAdapterTagSerial;
using OpenMPTag = ::vtkm::cont::DeviceAdapterTagOpenMP;
using TBBTag = ::vtkm::cont::DeviceAdapterTagTBB;
using CudaTag = ::vtkm::cont::DeviceAdapterTagCuda;
using AnyTag = ::vtkm::cont::DeviceAdapterTagAny;
//Verify that for each device adapter we compile code for, that it
//has valid runtime support.
verify_srdt_support(SerialTag(), all_off, all_on, defaults);
verify_srdt_support(OpenMPTag(), all_off, all_on, defaults);
verify_srdt_support(CudaTag(), all_off, all_on, defaults);
verify_srdt_support(TBBTag(), all_off, all_on, defaults);
// Verify that all the ScopedRuntimeDeviceTracker changes
// have been reverted
verify_state(AnyTag(), defaults);
verify_srdt_support(AnyTag(), all_on, all_on, all_off);
// Verify that all the ScopedRuntimeDeviceTracker changes
// have been reverted
verify_state(AnyTag(), defaults);
}
} // anonymous namespace
int UnitTestScopedRuntimeDeviceTracker(int argc, char* argv[])
{
return vtkm::cont::testing::Testing::Run(VerifyScopedRuntimeDeviceTracker, argc, argv);
}