Merge topic 'refactor-runtimedevicetracker' into release-2.0

c7a2a7b30 Refactor RuntimeDeviceTracker

Acked-by: Kitware Robot <kwrobot@kitware.com>
Merge-request: !2971
This commit is contained in:
Sujin Philip 2023-01-31 15:44:30 +00:00 committed by Kitware Robot
commit ecb6f8d6d4

@ -13,6 +13,7 @@
#include <vtkm/cont/ErrorBadValue.h> #include <vtkm/cont/ErrorBadValue.h>
#include <algorithm> #include <algorithm>
#include <array>
#include <map> #include <map>
#include <mutex> #include <mutex>
#include <sstream> #include <sstream>
@ -28,39 +29,11 @@ namespace detail
struct RuntimeDeviceTrackerInternals struct RuntimeDeviceTrackerInternals
{ {
RuntimeDeviceTrackerInternals() = default; void ResetRuntimeAllowed() { this->RuntimeAllowed.fill(false); }
RuntimeDeviceTrackerInternals(const RuntimeDeviceTrackerInternals* v) { this->CopyFrom(v); }
RuntimeDeviceTrackerInternals& operator=(const RuntimeDeviceTrackerInternals* v)
{
this->CopyFrom(v);
return *this;
}
bool GetRuntimeAllowed(std::size_t deviceId) const { return this->RuntimeAllowed[deviceId]; }
void SetRuntimeAllowed(std::size_t deviceId, bool flag) { this->RuntimeAllowed[deviceId] = flag; }
bool GetThreadFriendlyMemAlloc() const { return this->ThreadFriendlyMemAlloc; }
void SetThreadFriendlyMemAlloc(bool flag) { this->ThreadFriendlyMemAlloc = flag; }
void ResetRuntimeAllowed()
{
std::fill_n(this->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
}
void Reset() { this->ResetRuntimeAllowed(); } void Reset() { this->ResetRuntimeAllowed(); }
private: std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID> RuntimeAllowed;
void CopyFrom(const RuntimeDeviceTrackerInternals* v)
{
std::copy(std::cbegin(v->RuntimeAllowed),
std::cend(v->RuntimeAllowed),
std::begin(this->RuntimeAllowed));
this->SetThreadFriendlyMemAlloc(v->GetThreadFriendlyMemAlloc());
}
bool RuntimeAllowed[VTKM_MAX_DEVICE_ADAPTER_ID];
bool ThreadFriendlyMemAlloc = false; bool ThreadFriendlyMemAlloc = false;
}; };
@ -99,7 +72,7 @@ bool RuntimeDeviceTracker::CanRunOn(vtkm::cont::DeviceAdapterId deviceId) const
{ //If at least a single device is enabled, than any device is enabled { //If at least a single device is enabled, than any device is enabled
for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i) for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
{ {
if (this->Internals->GetRuntimeAllowed(static_cast<std::size_t>(i))) if (this->Internals->RuntimeAllowed[static_cast<std::size_t>(i)])
{ {
return true; return true;
} }
@ -109,14 +82,14 @@ bool RuntimeDeviceTracker::CanRunOn(vtkm::cont::DeviceAdapterId deviceId) const
else else
{ {
this->CheckDevice(deviceId); this->CheckDevice(deviceId);
return this->Internals->GetRuntimeAllowed(deviceId.GetValue()); return this->Internals->RuntimeAllowed[static_cast<std::size_t>(deviceId.GetValue())];
} }
} }
VTKM_CONT VTKM_CONT
bool RuntimeDeviceTracker::GetThreadFriendlyMemAlloc() const bool RuntimeDeviceTracker::GetThreadFriendlyMemAlloc() const
{ {
return this->Internals->GetThreadFriendlyMemAlloc(); return this->Internals->ThreadFriendlyMemAlloc;
} }
VTKM_CONT VTKM_CONT
@ -124,13 +97,13 @@ void RuntimeDeviceTracker::SetDeviceState(vtkm::cont::DeviceAdapterId deviceId,
{ {
this->CheckDevice(deviceId); this->CheckDevice(deviceId);
this->Internals->SetRuntimeAllowed(deviceId.GetValue(), state); this->Internals->RuntimeAllowed[static_cast<std::size_t>(deviceId.GetValue())] = state;
} }
VTKM_CONT VTKM_CONT
void RuntimeDeviceTracker::SetThreadFriendlyMemAlloc(bool state) void RuntimeDeviceTracker::SetThreadFriendlyMemAlloc(bool state)
{ {
this->Internals->SetThreadFriendlyMemAlloc(state); this->Internals->ThreadFriendlyMemAlloc = state;
} }
VTKM_CONT void RuntimeDeviceTracker::ResetDevice(vtkm::cont::DeviceAdapterId deviceId) VTKM_CONT void RuntimeDeviceTracker::ResetDevice(vtkm::cont::DeviceAdapterId deviceId)
@ -163,7 +136,7 @@ void RuntimeDeviceTracker::Reset()
if (device.IsValueValid()) if (device.IsValueValid())
{ {
const bool state = runtimeDevice.Exists(device); const bool state = runtimeDevice.Exists(device);
this->Internals->SetRuntimeAllowed(device.GetValue(), state); this->Internals->RuntimeAllowed[static_cast<std::size_t>(device.GetValue())] = state;
} }
} }
this->LogEnabledDevices(); this->LogEnabledDevices();
@ -203,14 +176,14 @@ void RuntimeDeviceTracker::ForceDevice(DeviceAdapterId deviceId)
} }
this->Internals->ResetRuntimeAllowed(); this->Internals->ResetRuntimeAllowed();
this->Internals->SetRuntimeAllowed(deviceId.GetValue(), runtimeExists); this->Internals->RuntimeAllowed[static_cast<std::size_t>(deviceId.GetValue())] = runtimeExists;
this->LogEnabledDevices(); this->LogEnabledDevices();
} }
} }
VTKM_CONT void RuntimeDeviceTracker::CopyStateFrom(const vtkm::cont::RuntimeDeviceTracker& tracker) VTKM_CONT void RuntimeDeviceTracker::CopyStateFrom(const vtkm::cont::RuntimeDeviceTracker& tracker)
{ {
*(this->Internals) = tracker.Internals; *(this->Internals) = *tracker.Internals;
} }
VTKM_CONT VTKM_CONT
@ -247,13 +220,19 @@ void RuntimeDeviceTracker::LogEnabledDevices() const
} }
VTKM_CONT VTKM_CONT
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device, ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
RuntimeDeviceTrackerMode mode) const vtkm::cont::RuntimeDeviceTracker& tracker)
: RuntimeDeviceTracker(GetRuntimeDeviceTracker().Internals, false) : RuntimeDeviceTracker(tracker.Internals, false)
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals)) , SavedState(new detail::RuntimeDeviceTrackerInternals(*this->Internals))
{ {
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region"); VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
}
VTKM_CONT
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device,
RuntimeDeviceTrackerMode mode)
: ScopedRuntimeDeviceTracker(GetRuntimeDeviceTracker())
{
if (mode == RuntimeDeviceTrackerMode::Force) if (mode == RuntimeDeviceTrackerMode::Force)
{ {
this->ForceDevice(device); this->ForceDevice(device);
@ -273,11 +252,8 @@ ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
vtkm::cont::DeviceAdapterId device, vtkm::cont::DeviceAdapterId device,
RuntimeDeviceTrackerMode mode, RuntimeDeviceTrackerMode mode,
const vtkm::cont::RuntimeDeviceTracker& tracker) const vtkm::cont::RuntimeDeviceTracker& tracker)
: RuntimeDeviceTracker(tracker.Internals, false) : ScopedRuntimeDeviceTracker(tracker)
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
{ {
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
if (mode == RuntimeDeviceTrackerMode::Force) if (mode == RuntimeDeviceTrackerMode::Force)
{ {
this->ForceDevice(device); this->ForceDevice(device);
@ -292,20 +268,11 @@ ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
} }
} }
VTKM_CONT
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
const vtkm::cont::RuntimeDeviceTracker& tracker)
: RuntimeDeviceTracker(tracker.Internals, false)
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
{
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
}
VTKM_CONT VTKM_CONT
ScopedRuntimeDeviceTracker::~ScopedRuntimeDeviceTracker() ScopedRuntimeDeviceTracker::~ScopedRuntimeDeviceTracker()
{ {
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Leaving scoped runtime region"); VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Leaving scoped runtime region");
*(this->Internals) = this->SavedState.get(); *(this->Internals) = *this->SavedState;
this->LogEnabledDevices(); this->LogEnabledDevices();
} }