Merge topic 'refactor-runtimedevicetracker' into release-2.0
c7a2a7b30 Refactor RuntimeDeviceTracker Acked-by: Kitware Robot <kwrobot@kitware.com> Merge-request: !2971
This commit is contained in:
commit
ecb6f8d6d4
@ -13,6 +13,7 @@
|
|||||||
#include <vtkm/cont/ErrorBadValue.h>
|
#include <vtkm/cont/ErrorBadValue.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <array>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
@ -28,39 +29,11 @@ namespace detail
|
|||||||
|
|
||||||
struct RuntimeDeviceTrackerInternals
|
struct RuntimeDeviceTrackerInternals
|
||||||
{
|
{
|
||||||
RuntimeDeviceTrackerInternals() = default;
|
void ResetRuntimeAllowed() { this->RuntimeAllowed.fill(false); }
|
||||||
|
|
||||||
RuntimeDeviceTrackerInternals(const RuntimeDeviceTrackerInternals* v) { this->CopyFrom(v); }
|
|
||||||
|
|
||||||
RuntimeDeviceTrackerInternals& operator=(const RuntimeDeviceTrackerInternals* v)
|
|
||||||
{
|
|
||||||
this->CopyFrom(v);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool GetRuntimeAllowed(std::size_t deviceId) const { return this->RuntimeAllowed[deviceId]; }
|
|
||||||
void SetRuntimeAllowed(std::size_t deviceId, bool flag) { this->RuntimeAllowed[deviceId] = flag; }
|
|
||||||
|
|
||||||
bool GetThreadFriendlyMemAlloc() const { return this->ThreadFriendlyMemAlloc; }
|
|
||||||
void SetThreadFriendlyMemAlloc(bool flag) { this->ThreadFriendlyMemAlloc = flag; }
|
|
||||||
|
|
||||||
void ResetRuntimeAllowed()
|
|
||||||
{
|
|
||||||
std::fill_n(this->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Reset() { this->ResetRuntimeAllowed(); }
|
void Reset() { this->ResetRuntimeAllowed(); }
|
||||||
|
|
||||||
private:
|
std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID> RuntimeAllowed;
|
||||||
void CopyFrom(const RuntimeDeviceTrackerInternals* v)
|
|
||||||
{
|
|
||||||
std::copy(std::cbegin(v->RuntimeAllowed),
|
|
||||||
std::cend(v->RuntimeAllowed),
|
|
||||||
std::begin(this->RuntimeAllowed));
|
|
||||||
this->SetThreadFriendlyMemAlloc(v->GetThreadFriendlyMemAlloc());
|
|
||||||
}
|
|
||||||
|
|
||||||
bool RuntimeAllowed[VTKM_MAX_DEVICE_ADAPTER_ID];
|
|
||||||
bool ThreadFriendlyMemAlloc = false;
|
bool ThreadFriendlyMemAlloc = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -99,7 +72,7 @@ bool RuntimeDeviceTracker::CanRunOn(vtkm::cont::DeviceAdapterId deviceId) const
|
|||||||
{ //If at least a single device is enabled, than any device is enabled
|
{ //If at least a single device is enabled, than any device is enabled
|
||||||
for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
|
for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
|
||||||
{
|
{
|
||||||
if (this->Internals->GetRuntimeAllowed(static_cast<std::size_t>(i)))
|
if (this->Internals->RuntimeAllowed[static_cast<std::size_t>(i)])
|
||||||
{
|
{
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -109,14 +82,14 @@ bool RuntimeDeviceTracker::CanRunOn(vtkm::cont::DeviceAdapterId deviceId) const
|
|||||||
else
|
else
|
||||||
{
|
{
|
||||||
this->CheckDevice(deviceId);
|
this->CheckDevice(deviceId);
|
||||||
return this->Internals->GetRuntimeAllowed(deviceId.GetValue());
|
return this->Internals->RuntimeAllowed[static_cast<std::size_t>(deviceId.GetValue())];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
VTKM_CONT
|
VTKM_CONT
|
||||||
bool RuntimeDeviceTracker::GetThreadFriendlyMemAlloc() const
|
bool RuntimeDeviceTracker::GetThreadFriendlyMemAlloc() const
|
||||||
{
|
{
|
||||||
return this->Internals->GetThreadFriendlyMemAlloc();
|
return this->Internals->ThreadFriendlyMemAlloc;
|
||||||
}
|
}
|
||||||
|
|
||||||
VTKM_CONT
|
VTKM_CONT
|
||||||
@ -124,13 +97,13 @@ void RuntimeDeviceTracker::SetDeviceState(vtkm::cont::DeviceAdapterId deviceId,
|
|||||||
{
|
{
|
||||||
this->CheckDevice(deviceId);
|
this->CheckDevice(deviceId);
|
||||||
|
|
||||||
this->Internals->SetRuntimeAllowed(deviceId.GetValue(), state);
|
this->Internals->RuntimeAllowed[static_cast<std::size_t>(deviceId.GetValue())] = state;
|
||||||
}
|
}
|
||||||
|
|
||||||
VTKM_CONT
|
VTKM_CONT
|
||||||
void RuntimeDeviceTracker::SetThreadFriendlyMemAlloc(bool state)
|
void RuntimeDeviceTracker::SetThreadFriendlyMemAlloc(bool state)
|
||||||
{
|
{
|
||||||
this->Internals->SetThreadFriendlyMemAlloc(state);
|
this->Internals->ThreadFriendlyMemAlloc = state;
|
||||||
}
|
}
|
||||||
|
|
||||||
VTKM_CONT void RuntimeDeviceTracker::ResetDevice(vtkm::cont::DeviceAdapterId deviceId)
|
VTKM_CONT void RuntimeDeviceTracker::ResetDevice(vtkm::cont::DeviceAdapterId deviceId)
|
||||||
@ -163,7 +136,7 @@ void RuntimeDeviceTracker::Reset()
|
|||||||
if (device.IsValueValid())
|
if (device.IsValueValid())
|
||||||
{
|
{
|
||||||
const bool state = runtimeDevice.Exists(device);
|
const bool state = runtimeDevice.Exists(device);
|
||||||
this->Internals->SetRuntimeAllowed(device.GetValue(), state);
|
this->Internals->RuntimeAllowed[static_cast<std::size_t>(device.GetValue())] = state;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
this->LogEnabledDevices();
|
this->LogEnabledDevices();
|
||||||
@ -203,14 +176,14 @@ void RuntimeDeviceTracker::ForceDevice(DeviceAdapterId deviceId)
|
|||||||
}
|
}
|
||||||
|
|
||||||
this->Internals->ResetRuntimeAllowed();
|
this->Internals->ResetRuntimeAllowed();
|
||||||
this->Internals->SetRuntimeAllowed(deviceId.GetValue(), runtimeExists);
|
this->Internals->RuntimeAllowed[static_cast<std::size_t>(deviceId.GetValue())] = runtimeExists;
|
||||||
this->LogEnabledDevices();
|
this->LogEnabledDevices();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
VTKM_CONT void RuntimeDeviceTracker::CopyStateFrom(const vtkm::cont::RuntimeDeviceTracker& tracker)
|
VTKM_CONT void RuntimeDeviceTracker::CopyStateFrom(const vtkm::cont::RuntimeDeviceTracker& tracker)
|
||||||
{
|
{
|
||||||
*(this->Internals) = tracker.Internals;
|
*(this->Internals) = *tracker.Internals;
|
||||||
}
|
}
|
||||||
|
|
||||||
VTKM_CONT
|
VTKM_CONT
|
||||||
@ -247,13 +220,19 @@ void RuntimeDeviceTracker::LogEnabledDevices() const
|
|||||||
}
|
}
|
||||||
|
|
||||||
VTKM_CONT
|
VTKM_CONT
|
||||||
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device,
|
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
|
||||||
RuntimeDeviceTrackerMode mode)
|
const vtkm::cont::RuntimeDeviceTracker& tracker)
|
||||||
: RuntimeDeviceTracker(GetRuntimeDeviceTracker().Internals, false)
|
: RuntimeDeviceTracker(tracker.Internals, false)
|
||||||
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
|
, SavedState(new detail::RuntimeDeviceTrackerInternals(*this->Internals))
|
||||||
{
|
{
|
||||||
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
|
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
|
||||||
|
}
|
||||||
|
|
||||||
|
VTKM_CONT
|
||||||
|
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device,
|
||||||
|
RuntimeDeviceTrackerMode mode)
|
||||||
|
: ScopedRuntimeDeviceTracker(GetRuntimeDeviceTracker())
|
||||||
|
{
|
||||||
if (mode == RuntimeDeviceTrackerMode::Force)
|
if (mode == RuntimeDeviceTrackerMode::Force)
|
||||||
{
|
{
|
||||||
this->ForceDevice(device);
|
this->ForceDevice(device);
|
||||||
@ -273,11 +252,8 @@ ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
|
|||||||
vtkm::cont::DeviceAdapterId device,
|
vtkm::cont::DeviceAdapterId device,
|
||||||
RuntimeDeviceTrackerMode mode,
|
RuntimeDeviceTrackerMode mode,
|
||||||
const vtkm::cont::RuntimeDeviceTracker& tracker)
|
const vtkm::cont::RuntimeDeviceTracker& tracker)
|
||||||
: RuntimeDeviceTracker(tracker.Internals, false)
|
: ScopedRuntimeDeviceTracker(tracker)
|
||||||
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
|
|
||||||
{
|
{
|
||||||
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
|
|
||||||
|
|
||||||
if (mode == RuntimeDeviceTrackerMode::Force)
|
if (mode == RuntimeDeviceTrackerMode::Force)
|
||||||
{
|
{
|
||||||
this->ForceDevice(device);
|
this->ForceDevice(device);
|
||||||
@ -292,20 +268,11 @@ ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
VTKM_CONT
|
|
||||||
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
|
|
||||||
const vtkm::cont::RuntimeDeviceTracker& tracker)
|
|
||||||
: RuntimeDeviceTracker(tracker.Internals, false)
|
|
||||||
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
|
|
||||||
{
|
|
||||||
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
|
|
||||||
}
|
|
||||||
|
|
||||||
VTKM_CONT
|
VTKM_CONT
|
||||||
ScopedRuntimeDeviceTracker::~ScopedRuntimeDeviceTracker()
|
ScopedRuntimeDeviceTracker::~ScopedRuntimeDeviceTracker()
|
||||||
{
|
{
|
||||||
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Leaving scoped runtime region");
|
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Leaving scoped runtime region");
|
||||||
*(this->Internals) = this->SavedState.get();
|
*(this->Internals) = *this->SavedState;
|
||||||
|
|
||||||
this->LogEnabledDevices();
|
this->LogEnabledDevices();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user