Make ForEachValidDevice internal

* Add FindDeviceAdapterTagAndCall
* Add support for multiple arguments to be passed to the functor in
  'ForEachValidDevice' and 'FindDeviceAdapterTagAndCall'.
This commit is contained in:
Sujin Philip 2017-12-12 14:12:02 -05:00
parent 59dc78fd29
commit a4d0b57ba2
4 changed files with 149 additions and 59 deletions

@ -40,47 +40,6 @@ struct DeviceAdapterListTagCommon : vtkm::ListTagBase<vtkm::cont::DeviceAdapterT
vtkm::cont::DeviceAdapterTagSerial>
{
};
namespace detail
{
template <typename FunctorType>
class ExecuteIfValidDeviceTag
{
private:
template <typename DeviceAdapter>
using EnableIfValid = std::enable_if<vtkm::cont::DeviceAdapterTraits<DeviceAdapter>::Valid>;
template <typename DeviceAdapter>
using EnableIfInvalid = std::enable_if<!vtkm::cont::DeviceAdapterTraits<DeviceAdapter>::Valid>;
public:
explicit ExecuteIfValidDeviceTag(const FunctorType& functor)
: Functor(functor)
{
}
template <typename DeviceAdapter>
typename EnableIfValid<DeviceAdapter>::type operator()(DeviceAdapter) const
{
this->Functor(DeviceAdapter());
}
template <typename DeviceAdapter>
typename EnableIfInvalid<DeviceAdapter>::type operator()(DeviceAdapter) const
{
}
private:
FunctorType Functor;
};
} // detail
template <typename DeviceList, typename Functor>
VTKM_CONT void ForEachValidDevice(DeviceList devices, const Functor& functor)
{
vtkm::ListForEach(detail::ExecuteIfValidDeviceTag<Functor>(functor), devices);
}
}
} // namespace vtkm::cont

@ -23,6 +23,7 @@
#include <vtkm/cont/DeviceAdapterListTag.h>
#include <vtkm/cont/ErrorBadType.h>
#include <vtkm/cont/ErrorBadValue.h>
#include <vtkm/cont/internal/DeviceAdapterListHelpers.h>
#include <vtkm/cont/internal/DeviceAdapterTag.h>
#include <vtkm/cont/internal/VirtualObjectTransfer.h>
@ -101,9 +102,10 @@ public:
this->Internals->VirtualObject = derived;
this->Internals->Owner = acquireOwnership;
vtkm::cont::ForEachValidDevice(
devices,
CreateTransferInterface<VirtualDerivedType>(this->Internals->Transfers.data(), derived));
vtkm::cont::internal::ForEachValidDevice(devices,
CreateTransferInterface<VirtualDerivedType>(),
this->Internals->Transfers.data(),
derived);
}
}
@ -202,18 +204,12 @@ private:
};
template <typename VirtualDerivedType>
class CreateTransferInterface
struct CreateTransferInterface
{
public:
CreateTransferInterface(std::unique_ptr<TransferInterface>* transfers,
const VirtualDerivedType* virtualObject)
: Transfers(transfers)
, VirtualObject(virtualObject)
{
}
template <typename DeviceAdapter>
void operator()(DeviceAdapter) const
VTKM_CONT void operator()(DeviceAdapter,
std::unique_ptr<TransferInterface>* transfers,
const VirtualDerivedType* virtualObject) const
{
using DeviceInfo = vtkm::cont::DeviceAdapterTraits<DeviceAdapter>;
@ -225,12 +221,8 @@ private:
throw vtkm::cont::ErrorBadType(msg);
}
using TransferImpl = TransferInterfaceImpl<VirtualDerivedType, DeviceAdapter>;
this->Transfers[DeviceInfo::GetId()].reset(new TransferImpl(this->VirtualObject));
transfers[DeviceInfo::GetId()].reset(new TransferImpl(virtualObject));
}
private:
std::unique_ptr<TransferInterface>* Transfers;
const VirtualDerivedType* VirtualObject;
};
struct InternalStruct

@ -32,6 +32,7 @@ set(headers
DeviceAdapterAlgorithmGeneral.h
DeviceAdapterDefaultSelection.h
DeviceAdapterError.h
DeviceAdapterListHelpers.h
DeviceAdapterTag.h
DynamicTransform.h
FunctorsGeneral.h

@ -0,0 +1,138 @@
//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//
// Copyright 2016 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
// Copyright 2016 UT-Battelle, LLC.
// Copyright 2016 Los Alamos National Security.
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
// Laboratory (LANL), the U.S. Government retains certain rights in
// this software.
//============================================================================
#ifndef vtk_m_cont_internal_DeviceAdapterListHelpers_h
#define vtk_m_cont_internal_DeviceAdapterListHelpers_h
#include <vtkm/ListTag.h>
#include <vtkm/cont/ErrorBadDevice.h>
#include <vtkm/cont/RuntimeDeviceTracker.h>
namespace vtkm
{
namespace cont
{
namespace internal
{
//============================================================================
template <typename FunctorType>
class ExecuteIfValidDeviceTag
{
private:
template <typename DeviceAdapter>
using EnableIfValid = std::enable_if<vtkm::cont::DeviceAdapterTraits<DeviceAdapter>::Valid>;
template <typename DeviceAdapter>
using EnableIfInvalid = std::enable_if<!vtkm::cont::DeviceAdapterTraits<DeviceAdapter>::Valid>;
public:
explicit ExecuteIfValidDeviceTag(const FunctorType& functor)
: Functor(functor)
{
}
template <typename DeviceAdapter, typename... Args>
typename EnableIfValid<DeviceAdapter>::type operator()(
DeviceAdapter device,
const vtkm::cont::RuntimeDeviceTracker& tracker,
Args&&... args) const
{
if (tracker.CanRunOn(device))
{
this->Functor(device, std::forward<Args>(args)...);
}
}
// do not generate code for invalid devices
template <typename DeviceAdapter, typename... Args>
typename EnableIfInvalid<DeviceAdapter>::type operator()(DeviceAdapter,
const vtkm::cont::RuntimeDeviceTracker&,
Args&&...) const
{
}
private:
FunctorType Functor;
};
/// Execute the given functor on each valid device in \c DeviceList.
///
template <typename DeviceList, typename Functor, typename... Args>
VTKM_CONT void ForEachValidDevice(DeviceList devices, const Functor& functor, Args&&... args)
{
auto tracker = vtkm::cont::GetGlobalRuntimeDeviceTracker();
ExecuteIfValidDeviceTag<Functor> wrapped(functor);
vtkm::ListForEach(wrapped, devices, tracker, std::forward<Args>(args)...);
}
//============================================================================
template <typename FunctorType>
class ExecuteIfSameDeviceId
{
public:
ExecuteIfSameDeviceId(FunctorType functor)
: Functor(functor)
{
}
template <typename DeviceAdapter, typename... Args>
void operator()(DeviceAdapter device,
vtkm::cont::DeviceAdapterId deviceId,
bool& status,
Args&&... args) const
{
if (vtkm::cont::DeviceAdapterTraits<DeviceAdapter>::GetId() == deviceId)
{
VTKM_ASSERT(status == false);
this->Functor(device, std::forward<Args>(args)...);
status = true;
}
}
private:
FunctorType Functor;
};
/// Finds the \c DeviceAdapterTag in \c DeviceList with id equal to deviceId
/// and executes the functor with the tag. Throws \c ErrorBadDevice if a valid
/// \c DeviceAdapterTag is not found.
///
template <typename DeviceList, typename Functor, typename... Args>
VTKM_CONT void FindDeviceAdapterTagAndCall(vtkm::cont::DeviceAdapterId deviceId,
DeviceList devices,
const Functor& functor,
Args&&... args)
{
bool status = false;
ExecuteIfSameDeviceId<Functor> wrapped(functor);
ForEachValidDevice(devices, wrapped, deviceId, status, std::forward<Args>(args)...);
if (!status)
{
std::string msg =
"Device with id " + std::to_string(deviceId) + " is either not in the list or is invalid";
throw vtkm::cont::ErrorBadDevice(msg);
}
}
}
}
} // vtkm::cont::internal
#endif // vtk_m_cont_internal_DeviceAdapterListHelpers_h