Fix issue with Managed Memory for 0 size arrays

This commit is contained in:
Sujin Philip 2017-12-19 17:18:24 -05:00
parent a6f6ea99a4
commit b530a5ce3f

@ -33,6 +33,10 @@ static bool IsInitialized = false;
// True if all devices support concurrent pageable managed memory.
static bool ManagedMemorySupported = false;
// Minimum buffer size, in bytes, that the Prepare* methods compare numBytes
// against before acting on a managed pointer (1 << 20 bytes == 1 MiB).
// Avoids the overhead of cudaMemAdvise and cudaMemPrefetchAsync for small
// buffers.
// This value should be > 0 or else these functions will error out.
static std::size_t Threshold = 1 << 20;
}
namespace vtkm
@ -94,6 +98,12 @@ bool CudaAllocator::IsManagedPointer(const void* ptr)
void* CudaAllocator::Allocate(std::size_t numBytes)
{
CudaAllocator::Initialize();
// When numBytes is zero cudaMallocManaged returns an error and the behavior
// of cudaMalloc is not documented. Just return nullptr.
if (numBytes == 0)
{
return nullptr;
}
void* ptr = nullptr;
if (ManagedMemorySupported)
@ -115,7 +125,7 @@ void CudaAllocator::Free(void* ptr)
void CudaAllocator::PrepareForControl(const void* ptr, std::size_t numBytes)
{
if (IsManagedPointer(ptr))
if (IsManagedPointer(ptr) && numBytes >= Threshold)
{
#if CUDART_VERSION >= 8000
// TODO these hints need to be benchmarked and adjusted once we start
@ -128,7 +138,7 @@ void CudaAllocator::PrepareForControl(const void* ptr, std::size_t numBytes)
void CudaAllocator::PrepareForInput(const void* ptr, std::size_t numBytes)
{
if (IsManagedPointer(ptr))
if (IsManagedPointer(ptr) && numBytes >= Threshold)
{
#if CUDART_VERSION >= 8000
int dev;
@ -143,7 +153,7 @@ void CudaAllocator::PrepareForInput(const void* ptr, std::size_t numBytes)
void CudaAllocator::PrepareForOutput(const void* ptr, std::size_t numBytes)
{
if (IsManagedPointer(ptr))
if (IsManagedPointer(ptr) && numBytes >= Threshold)
{
#if CUDART_VERSION >= 8000
int dev;
@ -158,7 +168,7 @@ void CudaAllocator::PrepareForOutput(const void* ptr, std::size_t numBytes)
void CudaAllocator::PrepareForInPlace(const void* ptr, std::size_t numBytes)
{
if (IsManagedPointer(ptr))
if (IsManagedPointer(ptr) && numBytes >= Threshold)
{
#if CUDART_VERSION >= 8000
int dev;