mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-09-08 21:33:55 +00:00
Add option to thread-efficient mem alloc
This commit is contained in:
parent
3920806a66
commit
9f4b786a7f
@ -28,7 +28,57 @@ namespace detail
|
||||
|
||||
struct RuntimeDeviceTrackerInternals
|
||||
{
|
||||
RuntimeDeviceTrackerInternals() = default;
|
||||
|
||||
RuntimeDeviceTrackerInternals(const RuntimeDeviceTrackerInternals* v) { this->CopyFrom(v); }
|
||||
|
||||
RuntimeDeviceTrackerInternals& operator=(const RuntimeDeviceTrackerInternals* v)
|
||||
{
|
||||
this->CopyFrom(v);
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool GetRuntimeAllowed(std::size_t deviceId) const { return this->RuntimeAllowed[deviceId]; }
|
||||
void SetRuntimeAllowed(std::size_t deviceId, bool flag) { this->RuntimeAllowed[deviceId] = flag; }
|
||||
|
||||
bool GetThreadFriendlyMemAlloc(std::size_t deviceId) const
|
||||
{
|
||||
return this->ThreadFriendlyMemAlloc[deviceId];
|
||||
}
|
||||
void SetThreadFriendlyMemAlloc(std::size_t deviceId, bool flag)
|
||||
{
|
||||
this->ThreadFriendlyMemAlloc[deviceId] = flag;
|
||||
}
|
||||
|
||||
void ResetRuntimeAllowed()
|
||||
{
|
||||
std::fill_n(this->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
|
||||
}
|
||||
|
||||
void ResetThreadFriendlyMemAlloc()
|
||||
{
|
||||
std::fill_n(this->ThreadFriendlyMemAlloc, VTKM_MAX_DEVICE_ADAPTER_ID, false);
|
||||
}
|
||||
|
||||
void Reset()
|
||||
{
|
||||
this->ResetRuntimeAllowed();
|
||||
this->ResetThreadFriendlyMemAlloc();
|
||||
}
|
||||
|
||||
private:
|
||||
void CopyFrom(const RuntimeDeviceTrackerInternals* v)
|
||||
{
|
||||
std::copy(std::cbegin(v->RuntimeAllowed),
|
||||
std::cend(v->RuntimeAllowed),
|
||||
std::begin(this->RuntimeAllowed));
|
||||
std::copy(std::cbegin(v->ThreadFriendlyMemAlloc),
|
||||
std::cend(v->ThreadFriendlyMemAlloc),
|
||||
std::begin(this->ThreadFriendlyMemAlloc));
|
||||
}
|
||||
|
||||
bool RuntimeAllowed[VTKM_MAX_DEVICE_ADAPTER_ID];
|
||||
bool ThreadFriendlyMemAlloc[VTKM_MAX_DEVICE_ADAPTER_ID];
|
||||
};
|
||||
}
|
||||
|
||||
@ -65,7 +115,7 @@ bool RuntimeDeviceTracker::CanRunOn(vtkm::cont::DeviceAdapterId deviceId) const
|
||||
{ //If at least a single device is enabled, than any device is enabled
|
||||
for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
|
||||
{
|
||||
if (this->Internals->RuntimeAllowed[static_cast<std::size_t>(i)])
|
||||
if (this->Internals->GetRuntimeAllowed(static_cast<std::size_t>(i)))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
@ -75,7 +125,28 @@ bool RuntimeDeviceTracker::CanRunOn(vtkm::cont::DeviceAdapterId deviceId) const
|
||||
else
|
||||
{
|
||||
this->CheckDevice(deviceId);
|
||||
return this->Internals->RuntimeAllowed[deviceId.GetValue()];
|
||||
return this->Internals->GetRuntimeAllowed(deviceId.GetValue());
|
||||
}
|
||||
}
|
||||
|
||||
VTKM_CONT
|
||||
bool RuntimeDeviceTracker::GetThreadFriendlyMemAlloc(vtkm::cont::DeviceAdapterId deviceId) const
|
||||
{
|
||||
if (deviceId == vtkm::cont::DeviceAdapterTagAny{})
|
||||
{ //If at least a single device is enabled, than any device is enabled
|
||||
for (vtkm::Int8 i = 1; i < VTKM_MAX_DEVICE_ADAPTER_ID; ++i)
|
||||
{
|
||||
if (this->Internals->GetThreadFriendlyMemAlloc(static_cast<std::size_t>(i)))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
this->CheckDevice(deviceId);
|
||||
return this->Internals->GetThreadFriendlyMemAlloc(deviceId.GetValue());
|
||||
}
|
||||
}
|
||||
|
||||
@ -84,9 +155,17 @@ void RuntimeDeviceTracker::SetDeviceState(vtkm::cont::DeviceAdapterId deviceId,
|
||||
{
|
||||
this->CheckDevice(deviceId);
|
||||
|
||||
this->Internals->RuntimeAllowed[deviceId.GetValue()] = state;
|
||||
this->Internals->SetRuntimeAllowed(deviceId.GetValue(), state);
|
||||
}
|
||||
|
||||
VTKM_CONT
|
||||
void RuntimeDeviceTracker::SetThreadFriendlyMemAlloc(vtkm::cont::DeviceAdapterId deviceId,
|
||||
bool state)
|
||||
{
|
||||
this->CheckDevice(deviceId);
|
||||
|
||||
this->Internals->SetThreadFriendlyMemAlloc(deviceId.GetValue(), state);
|
||||
}
|
||||
|
||||
VTKM_CONT void RuntimeDeviceTracker::ResetDevice(vtkm::cont::DeviceAdapterId deviceId)
|
||||
{
|
||||
@ -106,7 +185,7 @@ VTKM_CONT void RuntimeDeviceTracker::ResetDevice(vtkm::cont::DeviceAdapterId dev
|
||||
VTKM_CONT
|
||||
void RuntimeDeviceTracker::Reset()
|
||||
{
|
||||
std::fill_n(this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
|
||||
this->Internals->Reset();
|
||||
|
||||
// We use this instead of calling CheckDevice/SetDeviceState so that
|
||||
// when we use logging we get better messages stating we are reseting
|
||||
@ -118,7 +197,7 @@ void RuntimeDeviceTracker::Reset()
|
||||
if (device.IsValueValid())
|
||||
{
|
||||
const bool state = runtimeDevice.Exists(device);
|
||||
this->Internals->RuntimeAllowed[device.GetValue()] = state;
|
||||
this->Internals->SetRuntimeAllowed(device.GetValue(), state);
|
||||
}
|
||||
}
|
||||
this->LogEnabledDevices();
|
||||
@ -128,7 +207,7 @@ VTKM_CONT void RuntimeDeviceTracker::DisableDevice(vtkm::cont::DeviceAdapterId d
|
||||
{
|
||||
if (deviceId == vtkm::cont::DeviceAdapterTagAny{})
|
||||
{
|
||||
std::fill_n(this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
|
||||
this->Internals->ResetRuntimeAllowed();
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -157,18 +236,15 @@ void RuntimeDeviceTracker::ForceDevice(DeviceAdapterId deviceId)
|
||||
throw vtkm::cont::ErrorBadValue(message.str());
|
||||
}
|
||||
|
||||
std::fill_n(this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, false);
|
||||
|
||||
this->Internals->RuntimeAllowed[deviceId.GetValue()] = runtimeExists;
|
||||
this->Internals->ResetRuntimeAllowed();
|
||||
this->Internals->SetRuntimeAllowed(deviceId.GetValue(), runtimeExists);
|
||||
this->LogEnabledDevices();
|
||||
}
|
||||
}
|
||||
|
||||
VTKM_CONT void RuntimeDeviceTracker::CopyStateFrom(const vtkm::cont::RuntimeDeviceTracker& tracker)
|
||||
{
|
||||
std::copy(std::cbegin(tracker.Internals->RuntimeAllowed),
|
||||
std::cend(tracker.Internals->RuntimeAllowed),
|
||||
std::begin(this->Internals->RuntimeAllowed));
|
||||
*(this->Internals) = tracker.Internals;
|
||||
}
|
||||
|
||||
VTKM_CONT
|
||||
@ -208,11 +284,9 @@ VTKM_CONT
|
||||
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device,
|
||||
RuntimeDeviceTrackerMode mode)
|
||||
: RuntimeDeviceTracker(GetRuntimeDeviceTracker().Internals, false)
|
||||
, SavedState(new detail::RuntimeDeviceTrackerInternals())
|
||||
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
|
||||
{
|
||||
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
|
||||
std::copy_n(
|
||||
this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, this->SavedState->RuntimeAllowed);
|
||||
|
||||
if (mode == RuntimeDeviceTrackerMode::Force)
|
||||
{
|
||||
@ -234,11 +308,10 @@ ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
|
||||
RuntimeDeviceTrackerMode mode,
|
||||
const vtkm::cont::RuntimeDeviceTracker& tracker)
|
||||
: RuntimeDeviceTracker(tracker.Internals, false)
|
||||
, SavedState(new detail::RuntimeDeviceTrackerInternals())
|
||||
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
|
||||
{
|
||||
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
|
||||
std::copy_n(
|
||||
this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, this->SavedState->RuntimeAllowed);
|
||||
|
||||
if (mode == RuntimeDeviceTrackerMode::Force)
|
||||
{
|
||||
this->ForceDevice(device);
|
||||
@ -257,19 +330,17 @@ VTKM_CONT
|
||||
ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(
|
||||
const vtkm::cont::RuntimeDeviceTracker& tracker)
|
||||
: RuntimeDeviceTracker(tracker.Internals, false)
|
||||
, SavedState(new detail::RuntimeDeviceTrackerInternals())
|
||||
, SavedState(new detail::RuntimeDeviceTrackerInternals(this->Internals))
|
||||
{
|
||||
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region");
|
||||
std::copy_n(
|
||||
this->Internals->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, this->SavedState->RuntimeAllowed);
|
||||
}
|
||||
|
||||
VTKM_CONT
|
||||
ScopedRuntimeDeviceTracker::~ScopedRuntimeDeviceTracker()
|
||||
{
|
||||
VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Leaving scoped runtime region");
|
||||
std::copy_n(
|
||||
this->SavedState->RuntimeAllowed, VTKM_MAX_DEVICE_ADAPTER_ID, this->Internals->RuntimeAllowed);
|
||||
*(this->Internals) = this->SavedState.get();
|
||||
|
||||
this->LogEnabledDevices();
|
||||
}
|
||||
|
||||
|
@ -110,7 +110,13 @@ public:
|
||||
///
|
||||
VTKM_CONT void ForceDevice(DeviceAdapterId deviceId);
|
||||
|
||||
/// \brief Copyies the state from the given device.
|
||||
/// \brief Get/Set use of thread-friendly memory allocation for a device.
|
||||
///
|
||||
///
|
||||
VTKM_CONT bool GetThreadFriendlyMemAlloc(DeviceAdapterId deviceId) const;
|
||||
VTKM_CONT void SetThreadFriendlyMemAlloc(vtkm::cont::DeviceAdapterId deviceId, bool state);
|
||||
|
||||
/// \brief Copies the state from the given device.
|
||||
///
|
||||
/// This is a convenient way to allow the `RuntimeDeviceTracker` on one thread
|
||||
/// copy the behavior from another thread.
|
||||
|
Loading…
Reference in New Issue
Block a user