Merge topic 'refactor-runtimedevicetracker' into release-2.0

c7a2a7b30 Refactor RuntimeDeviceTracker

Acked-by: Kitware Robot <kwrobot@kitware.com>
Merge-request: !2971
This commit is contained in:
Sujin Philip 2023-01-31 15:44:30 +00:00 committed by Kitware Robot
commit ecb6f8d6d4

@ -13,6 +13,7 @@
#include <vtkm/cont/ErrorBadValue.h>
#include <algorithm>
#include <array>
#include <map>
#include <mutex>
#include <sstream>
@ -28,39 +29,11 @@ namespace detail
struct RuntimeDeviceTrackerInternals
{
RuntimeDeviceTrackerInternals() = default;
RuntimeDeviceTrackerInternals(const RuntimeDeviceTrackerInternals* v) { this->CopyFrom(v); }
RuntimeDeviceTrackerInternals& operator=(const RuntimeDeviceTrackerInternals* v)
{
this->CopyFrom(v);
return *this;
}
bool GetRuntimeAllowed(std::size_t deviceId) const { return this->RuntimeAllowed[deviceId]; }
void SetRuntimeAllowed(std::size_t deviceId, bool flag) { this->RuntimeAllowed[deviceId] = flag; }
bool GetThreadFriendlyMemAlloc() const { return this->ThreadFriendlyMemAlloc; }
void SetThreadFriendlyMemAlloc(bool flag) { this->ThreadFriendlyMemAlloc = flag; }
void ResetRuntimeAllowed()
{
std::fill_n(this->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
}
void ResetRuntimeAllowed() { this->RuntimeAllowed.fill(false); }
void Reset() { this->ResetRuntimeAllowed(); }
private:
void CopyFrom(const RuntimeDeviceTrackerInternals* v)
{
std::copy(std::cbegin(v->RuntimeAllowed),
std::cend(v->RuntimeAllowed),
std::begin(this->RuntimeAllowed));
this->SetThreadFriendlyMemAlloc(v->GetThreadFriendlyMemAlloc());
}
bool RuntimeAllowed[VTKM_MAX_DEVICE_ADAPTER_ID];
std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID> RuntimeAllowed;
bool ThreadFriendlyMemAlloc = false;
};
@ -99,7 +72,7 @@ bool RuntimeDeviceTracker::CanRunOn(vtkm::cont::DeviceAdapterId deviceId) const
{ //If at least a single device is enabled, than any device is enabled
for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
{
if (this->Internals->GetRuntimeAllowed(static_cast<std::size_t>(i)))
if (this->Internals->RuntimeAllowed[static_cast<std::size_t>(i)])
{
return true;
}
@ -109,14 +82,14 @@ bool RuntimeDeviceTracker::CanRunOn(vtkm::cont::DeviceAdapterId deviceId) const
else
{
this->CheckDevice(deviceId);
return this->Internals->GetRuntimeAllowed(deviceId.GetValue());
return this->Internals->RuntimeAllowed[static_cast<std::size_t>(deviceId.GetValue())];
}
}
VTKM_CONT
bool RuntimeDeviceTracker::GetThreadFriendlyMemAlloc() const
{
return this->Internals->GetThreadFriendlyMemAlloc();
return this->Internals->ThreadFriendlyMemAlloc;
}
VTKM_CONT
@ -124,13 +97,13 @@ void RuntimeDeviceTracker::SetDeviceState(vtkm::cont::DeviceAdapterId deviceId,
{
this->CheckDevice(deviceId);
this->Internals->SetRuntimeAllowed(deviceId.GetValue(), state);
this->Internals->RuntimeAllowed[static_cast<std::size_t>(deviceId.GetValue())] = state;
}
VTKM_CONT
void RuntimeDeviceTracker::SetThreadFriendlyMemAlloc(bool state)
{
this->Internals->SetThreadFriendlyMemAlloc(state);
this->Internals->ThreadFriendlyMemAlloc = state;
}
VTKM_CONT void RuntimeDeviceTracker::ResetDevice(vtkm::cont::DeviceAdapterId deviceId)
@ -163,7 +136,7 @@ void RuntimeDeviceTracker::Reset()
if (device.IsValueValid())
{
const bool state = runtimeDevice.Exists(device);
this->Internals->SetRuntimeAllowed(device.GetValue(), state);
this->Internals->RuntimeAllowed[static_cast<std::size_t>(device.GetValue())] = state;
}
}
this->LogEnabledDevices();
@ -203,14 +176,14 @@ void RuntimeDeviceTracker::ForceDevice(DeviceAdapterId deviceId)
}
this->Internals->ResetRuntimeAllowed();
this->Internals->SetRuntimeAllowed(deviceId.GetValue(), runtimeExists);
this->Internals->RuntimeAllowed[static_cast<std::size_t>(deviceId.GetValue())] = runtimeExists;
this->LogEnabledDevices();
}
}
VTKM_CONT void RuntimeDeviceTracker::CopyStateFrom(const vtkm::cont::RuntimeDeviceTracker& tracker)
{
*(this->Internals) = tracker.Internals;
*(this->Internals) = *tracker.Internals;
}
VTKM_CONT
@ -247,13 +220,19 @@ void RuntimeDeviceTracker::LogEnabledDevices() const
}
VTKM_CONT
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device,
RuntimeDeviceTrackerMode mode)
: RuntimeDeviceTracker(GetRuntimeDeviceTracker().Internals, false)
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
const vtkm::cont::RuntimeDeviceTracker& tracker)
: RuntimeDeviceTracker(tracker.Internals, false)
, SavedState(new detail::RuntimeDeviceTrackerInternals(*this->Internals))
{
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
}
VTKM_CONT
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device,
RuntimeDeviceTrackerMode mode)
: ScopedRuntimeDeviceTracker(GetRuntimeDeviceTracker())
{
if (mode == RuntimeDeviceTrackerMode::Force)
{
this->ForceDevice(device);
@ -273,11 +252,8 @@ ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
vtkm::cont::DeviceAdapterId device,
RuntimeDeviceTrackerMode mode,
const vtkm::cont::RuntimeDeviceTracker& tracker)
: RuntimeDeviceTracker(tracker.Internals, false)
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
: ScopedRuntimeDeviceTracker(tracker)
{
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
if (mode == RuntimeDeviceTrackerMode::Force)
{
this->ForceDevice(device);
@ -292,20 +268,11 @@ ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
}
}
VTKM_CONT
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
const vtkm::cont::RuntimeDeviceTracker& tracker)
: RuntimeDeviceTracker(tracker.Internals, false)
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
{
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
}
VTKM_CONT
ScopedRuntimeDeviceTracker::~ScopedRuntimeDeviceTracker()
{
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Leaving scoped runtime region");
*(this->Internals) = this->SavedState.get();
*(this->Internals) = *this->SavedState;
this->LogEnabledDevices();
}