Compile ArrayGetValues implementation in library

Previously, all of the `ArrayGetValues` implementations were templated
functions that had to be built. That meant that any code using them had
to be compiled with a device compiler and create special code for it.
This change uses an `UnknownArrayHandle` to encapsulate the
`ArrayHandle` and call a per-compiled library function. This means that
the code only has to be compiled once.
This commit is contained in:
Kenneth Moreland 2021-08-05 15:58:34 -06:00
parent bfb693c1d2
commit e1ac918bc7
5 changed files with 136 additions and 54 deletions

@ -0,0 +1,20 @@
# Compile `ArrayGetValues` implementation in a library
Previously, all of the `ArrayGetValue` implementations were templated
functions that had to be built by all code that used it. That had 2
negative consequences.
1. The same code that scheduled jobs on any device had to be compiled many
times over.
2. Any code that used `ArrayGetValue` had to be compiled with a device
compiler. If you had non-worklet code that just wanted to get a single
value out of an array, that was a pain.
To get around this problem, an `ArrayGetValues` function that takes
`UnknownArrayHandle`s was created. The implementation for this function is
compiled into a library. It uses `UnknownArrayHandle`'s ability to extract
a component of the array with a uniform type to reduce the number of code
paths it generates. Although there are still several code paths, they only
have to be computed once. Plus, now any code can include `ArrayGetValues.h`
and still use a basic C++ compiler.

@ -0,0 +1,80 @@
//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//============================================================================
#include <vtkm/cont/ArrayGetValues.h>
#include <vtkm/cont/Algorithm.h>
#include <vtkm/cont/ArrayHandlePermutation.h>
#include <vtkm/cont/UnknownArrayHandle.h>
#include <vtkm/List.h>
#include <vtkm/TypeList.h>
void vtkm::cont::internal::ArrayGetValuesImpl(const vtkm::cont::UnknownArrayHandle& ids,
const vtkm::cont::UnknownArrayHandle& data,
const vtkm::cont::UnknownArrayHandle& output)
{
auto idArray = ids.ExtractComponent<vtkm::Id>(0, vtkm::CopyFlag::On);
output.Allocate(ids.GetNumberOfValues());
bool copied = false;
vtkm::ListForEach(
[&](auto base) {
using T = decltype(base);
if (!copied && data.IsBaseComponentType<T>())
{
vtkm::IdComponent numComponents = data.GetNumberOfComponentsFlat();
VTKM_ASSERT(output.GetNumberOfComponentsFlat() == numComponents);
for (vtkm::IdComponent componentIdx = 0; componentIdx < numComponents; ++componentIdx)
{
auto dataArray = data.ExtractComponent<T>(componentIdx, vtkm::CopyFlag::On);
auto outputArray = output.ExtractComponent<T>(componentIdx, vtkm::CopyFlag::Off);
auto permutedArray = vtkm::cont::make_ArrayHandlePermutation(idArray, dataArray);
bool copiedComponent = false;
if (!dataArray.IsOnHost())
{
copiedComponent = vtkm::cont::TryExecute([&](auto device) {
if (dataArray.IsOnDevice(device))
{
vtkm::cont::DeviceAdapterAlgorithm<decltype(device)>::Copy(permutedArray,
outputArray);
return true;
}
return false;
});
}
if (!copiedComponent)
{ // Fallback to a control-side copy if the device copy fails or if the device
// is undefined or if the data were already on the host. In this case, the
// best we can do is grab the portals and copy one at a time on the host with
// a for loop.
const vtkm::Id numVals = ids.GetNumberOfValues();
auto inPortal = permutedArray.ReadPortal();
auto outPortal = outputArray.WritePortal();
for (vtkm::Id i = 0; i < numVals; ++i)
{
outPortal.Set(i, inPortal.Get(i));
}
}
}
copied = true;
}
},
vtkm::TypeListBaseC{});
if (!copied)
{
throw vtkm::cont::ErrorBadType("Unable to get values from array of type " +
data.GetArrayTypeName());
}
}

@ -10,13 +10,10 @@
#ifndef vtk_m_cont_ArrayGetValues_h
#define vtk_m_cont_ArrayGetValues_h
#include <vtkm/cont/Algorithm.h>
#include <vtkm/cont/vtkm_cont_export.h>
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandlePermutation.h>
#include <vtkm/cont/ArrayPortalToIterators.h>
#include <vtkm/cont/DeviceAdapterTag.h>
#include <vtkm/cont/ErrorExecution.h>
#include <vtkm/cont/Logging.h>
#include <vtkm/cont/UnknownArrayHandle.h>
#include <initializer_list>
#include <vector>
@ -26,32 +23,17 @@ namespace vtkm
namespace cont
{
namespace detail
// Work around circular dependancy with UnknownArrayHandle.
class UnknownArrayHandle;
namespace internal
{
struct ArrayGetValuesFunctor
{
template <typename Device, typename IdsArray, typename DataArray, typename OutputArray>
VTKM_CONT bool operator()(Device,
const IdsArray& ids,
const DataArray& data,
OutputArray& output) const
{
// Only get data on a device the data are already on.
if (data.IsOnDevice(Device{}))
{
const auto input = vtkm::cont::make_ArrayHandlePermutation(ids, data);
vtkm::cont::DeviceAdapterAlgorithm<Device>::Copy(input, output);
return true;
}
else
{
return false;
}
}
};
VTKM_CONT_EXPORT void ArrayGetValuesImpl(const vtkm::cont::UnknownArrayHandle& ids,
const vtkm::cont::UnknownArrayHandle& data,
const vtkm::cont::UnknownArrayHandle& output);
} // namespace detail
} // namespace internal
/// \brief Obtain a small set of values from an ArrayHandle with minimal device
/// transfers.
@ -115,29 +97,7 @@ VTKM_CONT void ArrayGetValues(const vtkm::cont::ArrayHandle<vtkm::Id, SIds>& ids
const vtkm::cont::ArrayHandle<T, SData>& data,
vtkm::cont::ArrayHandle<T, SOut>& output)
{
bool copyComplete = false;
// If the data are not already on the host, attempt to copy on the device.
if (!data.IsOnHost())
{
copyComplete = vtkm::cont::TryExecute(detail::ArrayGetValuesFunctor{}, ids, data, output);
}
if (!copyComplete)
{ // Fallback to a control-side copy if the device copy fails or if the device
// is undefined or if the data were already on the host. In this case, the
// best we can do is grab the portals and copy one at a time on the host with
// a for loop.
const vtkm::Id numVals = ids.GetNumberOfValues();
auto idPortal = ids.ReadPortal();
auto dataPortal = data.ReadPortal();
output.Allocate(numVals);
auto outPortal = output.WritePortal();
for (vtkm::Id i = 0; i < numVals; ++i)
{
outPortal.Set(i, dataPortal.Get(idPortal.Get(i)));
}
}
internal::ArrayGetValuesImpl(ids, data, output);
}
template <typename SIds, typename T, typename SData, typename Alloc>

@ -161,6 +161,7 @@ set(sources
# compiled with a device-specific compiler (like CUDA).
set(device_sources
ArrayCopy.cxx
ArrayGetValues.cxx
ArrayRangeCompute.cxx
AssignerPartitionedDataSet.cxx
BoundsCompute.cxx

@ -716,7 +716,15 @@ void* Buffer::GetMetaData(const std::string& type) const
bool Buffer::IsAllocatedOnHost() const
{
LockType lock = this->Internals->GetLock();
return this->Internals->GetHostBuffer(lock).UpToDate;
if (this->Internals->GetNumberOfBytes(lock) > 0)
{
return this->Internals->GetHostBuffer(lock).UpToDate;
}
else
{
// Nothing allocated. Say the data exists everywhere.
return true;
}
}
bool Buffer::IsAllocatedOnDevice(vtkm::cont::DeviceAdapterId device) const
@ -724,7 +732,15 @@ bool Buffer::IsAllocatedOnDevice(vtkm::cont::DeviceAdapterId device) const
if (device.IsValueValid())
{
LockType lock = this->Internals->GetLock();
return this->Internals->GetDeviceBuffers(lock)[device].UpToDate;
if (this->Internals->GetNumberOfBytes(lock) > 0)
{
return this->Internals->GetDeviceBuffers(lock)[device].UpToDate;
}
else
{
// Nothing allocated. Say the data exists everywhere.
return true;
}
}
else if (device == vtkm::cont::DeviceAdapterTagUndefined{})
{
@ -735,6 +751,11 @@ bool Buffer::IsAllocatedOnDevice(vtkm::cont::DeviceAdapterId device) const
{
// Return if allocated on any device.
LockType lock = this->Internals->GetLock();
if (this->Internals->GetNumberOfBytes(lock) <= 0)
{
// Nothing allocated. Say the data exists everywhere.
return true;
}
for (auto&& deviceBuffer : this->Internals->GetDeviceBuffers(lock))
{
if (deviceBuffer.second.UpToDate)