ScopedRuntimeDeviceTracker requires a device to execute on when constructed.

To simplify using the ScopedRuntimeDeviceTracker it now takes the
device id you want to run on during construction.
This commit is contained in:
Robert Maynard 2019-05-20 15:04:09 -04:00
parent 4020f51988
commit fa03dc664a
8 changed files with 64 additions and 24 deletions

@ -39,6 +39,7 @@ the following block forces execution to only occur on
```cpp
{
vtkm::cont::DeviceAdapterTagCuda cuda;
auto& tracker = vtkm::cont::GetRuntimeDeviceTracker();
tracker->ForceDevice(cuda);
vtkm::worklet::Invoker invoke;
@ -55,8 +56,8 @@ correctly restore the threads `RuntimeDeviceTracker` state when `tracker`
goes out of scope.
```cpp
{
vtkm::cont::ScopedRuntimeDeviceTracker tracker;
tracker.ForceDevice(cuda);
vtkm::cont::DeviceAdapterTagCuda cuda;
vtkm::cont::ScopedRuntimeDeviceTracker tracker(cuda);
vtkm::worklet::Invoker invoke;
invoke(LightTask{}, input, output);
}

@ -312,8 +312,7 @@ public:
PortalConstType GetPortalConst() const
{
VTKM_ASSERT(this->Valid);
vtkm::cont::ScopedRuntimeDeviceTracker trackerScope;
trackerScope.ForceDevice(vtkm::cont::DeviceAdapterTagSerial());
vtkm::cont::ScopedRuntimeDeviceTracker trackerScope(vtkm::cont::DeviceAdapterTagSerial{});
return PortalConstType(this->Array.GetPortalConstControl(), this->Functor.PrepareForControl());
}
@ -404,8 +403,7 @@ public:
PortalType GetPortal()
{
VTKM_ASSERT(this->Valid);
vtkm::cont::ScopedRuntimeDeviceTracker trackerScope;
trackerScope.ForceDevice(vtkm::cont::DeviceAdapterTagSerial());
vtkm::cont::ScopedRuntimeDeviceTracker trackerScope(vtkm::cont::DeviceAdapterTagSerial{});
return PortalType(this->Array.GetPortalControl(),
this->Functor.PrepareForControl(),
this->InverseFunctor.PrepareForControl());
@ -415,8 +413,7 @@ public:
PortalConstType GetPortalConst() const
{
VTKM_ASSERT(this->Valid);
vtkm::cont::ScopedRuntimeDeviceTracker trackerScope;
trackerScope.ForceDevice(vtkm::cont::DeviceAdapterTagSerial());
vtkm::cont::ScopedRuntimeDeviceTracker trackerScope(vtkm::cont::DeviceAdapterTagSerial{});
return PortalConstType(this->Array.GetPortalConstControl(),
this->Functor.PrepareForControl(),
this->InverseFunctor.PrepareForControl());

@ -122,17 +122,37 @@ void RuntimeDeviceTracker::ForceDeviceImpl(vtkm::cont::DeviceAdapterId deviceId,
VTKM_CONT
void RuntimeDeviceTracker::ForceDevice(DeviceAdapterId deviceId)
{
vtkm::cont::RuntimeDeviceInformation runtimeDevice;
this->ForceDeviceImpl(deviceId, runtimeDevice.Exists(deviceId));
if (deviceId == vtkm::cont::DeviceAdapterTagAny{})
{
this->Reset();
}
else
{
vtkm::cont::RuntimeDeviceInformation runtimeDevice;
this->ForceDeviceImpl(deviceId, runtimeDevice.Exists(deviceId));
}
}
VTKM_CONT
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker()
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device)
: RuntimeDeviceTracker(GetRuntimeDeviceTracker().Internals, false)
, SavedState(new detail::RuntimeDeviceTrackerInternals())
{
std::copy_n(
this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, this->SavedState->RuntimeAllowed);
this->ForceDevice(device);
}
VTKM_CONT
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
vtkm::cont::DeviceAdapterId device,
const vtkm::cont::RuntimeDeviceTracker& tracker)
: RuntimeDeviceTracker(tracker.Internals, false)
, SavedState(new detail::RuntimeDeviceTrackerInternals())
{
std::copy_n(
this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, this->SavedState->RuntimeAllowed);
this->ForceDevice(device);
}
VTKM_CONT

@ -103,9 +103,9 @@ public:
/// The main intention of \c RuntimeDeviceTracker is to keep track of what
/// devices are working for VTK-m. However, it can also be used to turn
/// devices on and off. Use this method to disable all devices except one
/// to effectively force VTK-m to use that device. Use \c Reset restore
/// all devices to their default values. You can also use the \c DeepCopy
/// methods to save and restore the state.
/// to effectively force VTK-m to use that device. Either pass the
/// DeviceAdapterTagAny to this function or call \c Reset to restore
/// all devices to their default state.
///
/// This method will throw a \c ErrorBadValue if the given device does not
/// exist on the system.
@ -147,9 +147,35 @@ private:
///
struct VTKM_CONT_EXPORT ScopedRuntimeDeviceTracker : public vtkm::cont::RuntimeDeviceTracker
{
/// Construct a ScopedRuntimeDeviceTracker where the only active device
/// for the current thread is the one provided by the constructor. Passing
/// DeviceAdapterTagAny to this function will reset all devices to their
/// default state.
///
/// Constructor is not thread safe
VTKM_CONT ScopedRuntimeDeviceTracker();
VTKM_CONT ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device);
/// Construct a ScopedRuntimeDeviceTracker associated with the thread
/// associated with the provided tracker. The only active device
/// for this thread is the one provided by the constructor. Passing
/// DeviceAdapterTagAny to this function will reset all devices to their
/// default state.
///
/// Any modifications to the ScopedRuntimeDeviceTracker will effect what
/// ever thread the \c tracker is associated with, which might not be
/// the thread which ScopedRuntimeDeviceTracker was constructed on.
///
/// Constructor is not thread safe
VTKM_CONT ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device,
const vtkm::cont::RuntimeDeviceTracker& tracker);
/// Construct a ScopedRuntimeDeviceTracker associated with the thread
/// associated with the provided tracker.
///
/// Any modifications to the ScopedRuntimeDeviceTracker will effect what
/// ever thread the \c tracker is associated with, which might not be
/// the thread which ScopedRuntimeDeviceTracker was constructed on.
///
/// Constructor is not thread safe
VTKM_CONT ScopedRuntimeDeviceTracker(const vtkm::cont::RuntimeDeviceTracker& tracker);

@ -160,9 +160,8 @@ struct TransformExecObject : public vtkm::cont::ExecutionAndControlObjectBase
VTKM_CONT TransformExecObject(const FunctorType& functor)
{
// Need to make sure the serial device is supported, since that is what is used on the
// control side.
vtkm::cont::ScopedRuntimeDeviceTracker scopedTracker;
scopedTracker.ResetDevice(vtkm::cont::DeviceAdapterTagSerial());
// control side. Therefore we reset to all supported devices.
vtkm::cont::ScopedRuntimeDeviceTracker scopedTracker(vtkm::cont::DeviceAdapterTagAny{});
this->VirtualFunctor.Reset(new VirtualTransformFunctor<ValueType, FunctorType>(functor));
}

@ -157,8 +157,7 @@ void RunErrorTest(bool shouldFail, bool shouldThrow, bool shouldDisable)
bool threw = false;
bool disabled = false;
vtkm::cont::ScopedRuntimeDeviceTracker scopedTracker;
scopedTracker.ForceDevice(Device{});
vtkm::cont::ScopedRuntimeDeviceTracker scopedTracker(Device{});
try
{

@ -420,8 +420,7 @@ void Canvas::AddColorBar(const vtkm::Bounds& bounds,
vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::UInt8, 4>> colorMap;
{
vtkm::cont::ScopedRuntimeDeviceTracker tracker;
tracker.ForceDevice(vtkm::cont::DeviceAdapterTagSerial());
vtkm::cont::ScopedRuntimeDeviceTracker tracker(vtkm::cont::DeviceAdapterTagSerial{});
colorTable.Sample(static_cast<vtkm::Int32>(numSamples), colorMap);
}

@ -29,8 +29,7 @@ void Mapper::SetActiveColorTable(const vtkm::cont::ColorTable& colorTable)
vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::UInt8, 4>> temp;
{
vtkm::cont::ScopedRuntimeDeviceTracker tracker;
tracker.ForceDevice(vtkm::cont::DeviceAdapterTagSerial());
vtkm::cont::ScopedRuntimeDeviceTracker tracker(vtkm::cont::DeviceAdapterTagSerial{});
colorTable.Sample(1024, temp);
}