Merge topic 'refactor-runtimedevicetracker'
c7a2a7b30 Refactor RuntimeDeviceTracker Acked-by: Kitware Robot <kwrobot@kitware.com> Merge-request: !2971
This commit is contained in:
commit
5fdb2cd770
@ -13,6 +13,7 @@
|
||||
#include <vtkm/cont/ErrorBadValue.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <sstream>
|
||||
@ -28,39 +29,11 @@ namespace detail
|
||||
|
||||
struct RuntimeDeviceTrackerInternals
|
||||
{
|
||||
RuntimeDeviceTrackerInternals() = default;
|
||||
|
||||
RuntimeDeviceTrackerInternals(const RuntimeDeviceTrackerInternals* v) { this->CopyFrom(v); }
|
||||
|
||||
RuntimeDeviceTrackerInternals& operator=(const RuntimeDeviceTrackerInternals* v)
|
||||
{
|
||||
this->CopyFrom(v);
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool GetRuntimeAllowed(std::size_t deviceId) const { return this->RuntimeAllowed[deviceId]; }
|
||||
void SetRuntimeAllowed(std::size_t deviceId, bool flag) { this->RuntimeAllowed[deviceId] = flag; }
|
||||
|
||||
bool GetThreadFriendlyMemAlloc() const { return this->ThreadFriendlyMemAlloc; }
|
||||
void SetThreadFriendlyMemAlloc(bool flag) { this->ThreadFriendlyMemAlloc = flag; }
|
||||
|
||||
void ResetRuntimeAllowed()
|
||||
{
|
||||
std::fill_n(this->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
|
||||
}
|
||||
void ResetRuntimeAllowed() { this->RuntimeAllowed.fill(false); }
|
||||
|
||||
void Reset() { this->ResetRuntimeAllowed(); }
|
||||
|
||||
private:
|
||||
void CopyFrom(const RuntimeDeviceTrackerInternals* v)
|
||||
{
|
||||
std::copy(std::cbegin(v->RuntimeAllowed),
|
||||
std::cend(v->RuntimeAllowed),
|
||||
std::begin(this->RuntimeAllowed));
|
||||
this->SetThreadFriendlyMemAlloc(v->GetThreadFriendlyMemAlloc());
|
||||
}
|
||||
|
||||
bool RuntimeAllowed[VTKM_MAX_DEVICE_ADAPTER_ID];
|
||||
std::array<bool, VTKM_MAX_DEVICE_ADAPTER_ID> RuntimeAllowed;
|
||||
bool ThreadFriendlyMemAlloc = false;
|
||||
};
|
||||
|
||||
@ -99,7 +72,7 @@ bool RuntimeDeviceTracker::CanRunOn(vtkm::cont::DeviceAdapterId deviceId) const
|
||||
{ //If at least a single device is enabled, than any device is enabled
|
||||
for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
|
||||
{
|
||||
if (this->Internals->GetRuntimeAllowed(static_cast<std::size_t>(i)))
|
||||
if (this->Internals->RuntimeAllowed[static_cast<std::size_t>(i)])
|
||||
{
|
||||
return true;
|
||||
}
|
||||
@ -109,14 +82,14 @@ bool RuntimeDeviceTracker::CanRunOn(vtkm::cont::DeviceAdapterId deviceId) const
|
||||
else
|
||||
{
|
||||
this->CheckDevice(deviceId);
|
||||
return this->Internals->GetRuntimeAllowed(deviceId.GetValue());
|
||||
return this->Internals->RuntimeAllowed[static_cast<std::size_t>(deviceId.GetValue())];
|
||||
}
|
||||
}
|
||||
|
||||
VTKM_CONT
|
||||
bool RuntimeDeviceTracker::GetThreadFriendlyMemAlloc() const
|
||||
{
|
||||
return this->Internals->GetThreadFriendlyMemAlloc();
|
||||
return this->Internals->ThreadFriendlyMemAlloc;
|
||||
}
|
||||
|
||||
VTKM_CONT
|
||||
@ -124,13 +97,13 @@ void RuntimeDeviceTracker::SetDeviceState(vtkm::cont::DeviceAdapterId deviceId,
|
||||
{
|
||||
this->CheckDevice(deviceId);
|
||||
|
||||
this->Internals->SetRuntimeAllowed(deviceId.GetValue(), state);
|
||||
this->Internals->RuntimeAllowed[static_cast<std::size_t>(deviceId.GetValue())] = state;
|
||||
}
|
||||
|
||||
VTKM_CONT
|
||||
void RuntimeDeviceTracker::SetThreadFriendlyMemAlloc(bool state)
|
||||
{
|
||||
this->Internals->SetThreadFriendlyMemAlloc(state);
|
||||
this->Internals->ThreadFriendlyMemAlloc = state;
|
||||
}
|
||||
|
||||
VTKM_CONT void RuntimeDeviceTracker::ResetDevice(vtkm::cont::DeviceAdapterId deviceId)
|
||||
@ -163,7 +136,7 @@ void RuntimeDeviceTracker::Reset()
|
||||
if (device.IsValueValid())
|
||||
{
|
||||
const bool state = runtimeDevice.Exists(device);
|
||||
this->Internals->SetRuntimeAllowed(device.GetValue(), state);
|
||||
this->Internals->RuntimeAllowed[static_cast<std::size_t>(device.GetValue())] = state;
|
||||
}
|
||||
}
|
||||
this->LogEnabledDevices();
|
||||
@ -203,14 +176,14 @@ void RuntimeDeviceTracker::ForceDevice(DeviceAdapterId deviceId)
|
||||
}
|
||||
|
||||
this->Internals->ResetRuntimeAllowed();
|
||||
this->Internals->SetRuntimeAllowed(deviceId.GetValue(), runtimeExists);
|
||||
this->Internals->RuntimeAllowed[static_cast<std::size_t>(deviceId.GetValue())] = runtimeExists;
|
||||
this->LogEnabledDevices();
|
||||
}
|
||||
}
|
||||
|
||||
VTKM_CONT void RuntimeDeviceTracker::CopyStateFrom(const vtkm::cont::RuntimeDeviceTracker& tracker)
|
||||
{
|
||||
*(this->Internals) = tracker.Internals;
|
||||
*(this->Internals) = *tracker.Internals;
|
||||
}
|
||||
|
||||
VTKM_CONT
|
||||
@ -247,13 +220,19 @@ void RuntimeDeviceTracker::LogEnabledDevices() const
|
||||
}
|
||||
|
||||
VTKM_CONT
|
||||
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device,
|
||||
RuntimeDeviceTrackerMode mode)
|
||||
: RuntimeDeviceTracker(GetRuntimeDeviceTracker().Internals, false)
|
||||
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
|
||||
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
|
||||
const vtkm::cont::RuntimeDeviceTracker& tracker)
|
||||
: RuntimeDeviceTracker(tracker.Internals, false)
|
||||
, SavedState(new detail::RuntimeDeviceTrackerInternals(*this->Internals))
|
||||
{
|
||||
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
|
||||
}
|
||||
|
||||
VTKM_CONT
|
||||
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device,
|
||||
RuntimeDeviceTrackerMode mode)
|
||||
: ScopedRuntimeDeviceTracker(GetRuntimeDeviceTracker())
|
||||
{
|
||||
if (mode == RuntimeDeviceTrackerMode::Force)
|
||||
{
|
||||
this->ForceDevice(device);
|
||||
@ -273,11 +252,8 @@ ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
|
||||
vtkm::cont::DeviceAdapterId device,
|
||||
RuntimeDeviceTrackerMode mode,
|
||||
const vtkm::cont::RuntimeDeviceTracker& tracker)
|
||||
: RuntimeDeviceTracker(tracker.Internals, false)
|
||||
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
|
||||
: ScopedRuntimeDeviceTracker(tracker)
|
||||
{
|
||||
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
|
||||
|
||||
if (mode == RuntimeDeviceTrackerMode::Force)
|
||||
{
|
||||
this->ForceDevice(device);
|
||||
@ -292,20 +268,11 @@ ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
|
||||
}
|
||||
}
|
||||
|
||||
VTKM_CONT
|
||||
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
|
||||
const vtkm::cont::RuntimeDeviceTracker& tracker)
|
||||
: RuntimeDeviceTracker(tracker.Internals, false)
|
||||
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
|
||||
{
|
||||
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
|
||||
}
|
||||
|
||||
VTKM_CONT
|
||||
ScopedRuntimeDeviceTracker::~ScopedRuntimeDeviceTracker()
|
||||
{
|
||||
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Leaving scoped runtime region");
|
||||
*(this->Internals) = this->SavedState.get();
|
||||
*(this->Internals) = *this->SavedState;
|
||||
|
||||
this->LogEnabledDevices();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user