Add option to thread-efficient mem alloc

This commit is contained in:
Dave Pugmire 2022-08-12 15:10:38 -04:00
parent 3920806a66
commit 9f4b786a7f
2 changed files with 101 additions and 24 deletions

@ -28,7 +28,57 @@ namespace detail
struct RuntimeDeviceTrackerInternals
{
RuntimeDeviceTrackerInternals() = default;
RuntimeDeviceTrackerInternals(const RuntimeDeviceTrackerInternals* v) { this->CopyFrom(v); }
RuntimeDeviceTrackerInternals& operator=(const RuntimeDeviceTrackerInternals* v)
{
this->CopyFrom(v);
return *this;
}
bool GetRuntimeAllowed(std::size_t deviceId) const { return this->RuntimeAllowed[deviceId]; }
void SetRuntimeAllowed(std::size_t deviceId, bool flag) { this->RuntimeAllowed[deviceId] = flag; }
bool GetThreadFriendlyMemAlloc(std::size_t deviceId) const
{
return this->ThreadFriendlyMemAlloc[deviceId];
}
void SetThreadFriendlyMemAlloc(std::size_t deviceId, bool flag)
{
this->ThreadFriendlyMemAlloc[deviceId] = flag;
}
void ResetRuntimeAllowed()
{
std::fill_n(this->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
}
void ResetThreadFriendlyMemAlloc()
{
std::fill_n(this->ThreadFriendlyMemAlloc, VTKM_MAX_DEVICE_ADAPTER_ID, false);
}
void Reset()
{
this->ResetRuntimeAllowed();
this->ResetThreadFriendlyMemAlloc();
}
private:
void CopyFrom(const RuntimeDeviceTrackerInternals* v)
{
std::copy(std::cbegin(v->RuntimeAllowed),
std::cend(v->RuntimeAllowed),
std::begin(this->RuntimeAllowed));
std::copy(std::cbegin(v->ThreadFriendlyMemAlloc),
std::cend(v->ThreadFriendlyMemAlloc),
std::begin(this->ThreadFriendlyMemAlloc));
}
bool RuntimeAllowed[VTKM_MAX_DEVICE_ADAPTER_ID];
bool ThreadFriendlyMemAlloc[VTKM_MAX_DEVICE_ADAPTER_ID];
};
}
@ -65,7 +115,7 @@ bool RuntimeDeviceTracker::CanRunOn(vtkm::cont::DeviceAdapterId deviceId) const
{ //If at least a single device is enabled, than any device is enabled
for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
{
if (this->Internals->RuntimeAllowed[static_cast<std::size_t>(i)])
if (this->Internals->GetRuntimeAllowed(static_cast<std::size_t>(i)))
{
return true;
}
@ -75,7 +125,28 @@ bool RuntimeDeviceTracker::CanRunOn(vtkm::cont::DeviceAdapterId deviceId) const
else
{
this->CheckDevice(deviceId);
return this->Internals->RuntimeAllowed[deviceId.GetValue()];
return this->Internals->GetRuntimeAllowed(deviceId.GetValue());
}
}
VTKM_CONT
bool RuntimeDeviceTracker::GetThreadFriendlyMemAlloc(vtkm::cont::DeviceAdapterId deviceId) const
{
if (deviceId == vtkm::cont::DeviceAdapterTagAny{})
{ //If at least a single device is enabled, than any device is enabled
for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
{
if (this->Internals->GetThreadFriendlyMemAlloc(static_cast<std::size_t>(i)))
{
return true;
}
}
return false;
}
else
{
this->CheckDevice(deviceId);
return this->Internals->GetThreadFriendlyMemAlloc(deviceId.GetValue());
}
}
@ -84,9 +155,17 @@ void RuntimeDeviceTracker::SetDeviceState(vtkm::cont::DeviceAdapterId deviceId,
{
this->CheckDevice(deviceId);
this->Internals->RuntimeAllowed[deviceId.GetValue()] = state;
this->Internals->SetRuntimeAllowed(deviceId.GetValue(), state);
}
VTKM_CONT
void RuntimeDeviceTracker::SetThreadFriendlyMemAlloc(vtkm::cont::DeviceAdapterId deviceId,
bool state)
{
this->CheckDevice(deviceId);
this->Internals->SetThreadFriendlyMemAlloc(deviceId.GetValue(), state);
}
VTKM_CONT void RuntimeDeviceTracker::ResetDevice(vtkm::cont::DeviceAdapterId deviceId)
{
@ -106,7 +185,7 @@ VTKM_CONT void RuntimeDeviceTracker::ResetDevice(vtkm::cont::DeviceAdapterId dev
VTKM_CONT
void RuntimeDeviceTracker::Reset()
{
std::fill_n(this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
this->Internals->Reset();
// We use this instead of calling CheckDevice/SetDeviceState so that
// when we use logging we get better messages stating we are reseting
@ -118,7 +197,7 @@ void RuntimeDeviceTracker::Reset()
if (device.IsValueValid())
{
const bool state = runtimeDevice.Exists(device);
this->Internals->RuntimeAllowed[device.GetValue()] = state;
this->Internals->SetRuntimeAllowed(device.GetValue(), state);
}
}
this->LogEnabledDevices();
@ -128,7 +207,7 @@ VTKM_CONT void RuntimeDeviceTracker::DisableDevice(vtkm::cont::DeviceAdapterId d
{
if (deviceId == vtkm::cont::DeviceAdapterTagAny{})
{
std::fill_n(this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
this->Internals->ResetRuntimeAllowed();
}
else
{
@ -157,18 +236,15 @@ void RuntimeDeviceTracker::ForceDevice(DeviceAdapterId deviceId)
throw vtkm::cont::ErrorBadValue(message.str());
}
std::fill_n(this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
this->Internals->RuntimeAllowed[deviceId.GetValue()] = runtimeExists;
this->Internals->ResetRuntimeAllowed();
this->Internals->SetRuntimeAllowed(deviceId.GetValue(), runtimeExists);
this->LogEnabledDevices();
}
}
VTKM_CONT void RuntimeDeviceTracker::CopyStateFrom(const vtkm::cont::RuntimeDeviceTracker& tracker)
{
std::copy(std::cbegin(tracker.Internals->RuntimeAllowed),
std::cend(tracker.Internals->RuntimeAllowed),
std::begin(this->Internals->RuntimeAllowed));
*(this->Internals) = tracker.Internals;
}
VTKM_CONT
@ -208,11 +284,9 @@ VTKM_CONT
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device,
RuntimeDeviceTrackerMode mode)
: RuntimeDeviceTracker(GetRuntimeDeviceTracker().Internals, false)
, SavedState(new detail::RuntimeDeviceTrackerInternals())
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
{
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
std::copy_n(
this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, this->SavedState->RuntimeAllowed);
if (mode == RuntimeDeviceTrackerMode::Force)
{
@ -234,11 +308,10 @@ ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
RuntimeDeviceTrackerMode mode,
const vtkm::cont::RuntimeDeviceTracker& tracker)
: RuntimeDeviceTracker(tracker.Internals, false)
, SavedState(new detail::RuntimeDeviceTrackerInternals())
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
{
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
std::copy_n(
this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, this->SavedState->RuntimeAllowed);
if (mode == RuntimeDeviceTrackerMode::Force)
{
this->ForceDevice(device);
@ -257,19 +330,17 @@ VTKM_CONT
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
const vtkm::cont::RuntimeDeviceTracker& tracker)
: RuntimeDeviceTracker(tracker.Internals, false)
, SavedState(new detail::RuntimeDeviceTrackerInternals())
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
{
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
std::copy_n(
this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, this->SavedState->RuntimeAllowed);
}
VTKM_CONT
ScopedRuntimeDeviceTracker::~ScopedRuntimeDeviceTracker()
{
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Leaving scoped runtime region");
std::copy_n(
this->SavedState->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, this->Internals->RuntimeAllowed);
*(this->Internals) = this->SavedState.get();
this->LogEnabledDevices();
}

@ -110,7 +110,13 @@ public:
///
VTKM_CONT void ForceDevice(DeviceAdapterId deviceId);
/// \brief Copyies the state from the given device.
/// \brief Get/Set use of thread-friendly memory allocation for a device.
///
///
VTKM_CONT bool GetThreadFriendlyMemAlloc(DeviceAdapterId deviceId) const;
VTKM_CONT void SetThreadFriendlyMemAlloc(vtkm::cont::DeviceAdapterId deviceId, bool state);
/// \brief Copies the state from the given device.
///
/// This is a convenient way to allow the `RuntimeDeviceTracker` on one thread
/// copy the behavior from another thread.