diff --git a/vtkm/cont/DeviceAdapterListTag.h b/vtkm/cont/DeviceAdapterListTag.h index d38bb09ca..a41d57c24 100644 --- a/vtkm/cont/DeviceAdapterListTag.h +++ b/vtkm/cont/DeviceAdapterListTag.h @@ -40,47 +40,6 @@ struct DeviceAdapterListTagCommon : vtkm::ListTagBase { }; - -namespace detail -{ - -template -class ExecuteIfValidDeviceTag -{ -private: - template - using EnableIfValid = std::enable_if::Valid>; - - template - using EnableIfInvalid = std::enable_if::Valid>; - -public: - explicit ExecuteIfValidDeviceTag(const FunctorType& functor) - : Functor(functor) - { - } - - template - typename EnableIfValid::type operator()(DeviceAdapter) const - { - this->Functor(DeviceAdapter()); - } - - template - typename EnableIfInvalid::type operator()(DeviceAdapter) const - { - } - -private: - FunctorType Functor; -}; -} // detail - -template -VTKM_CONT void ForEachValidDevice(DeviceList devices, const Functor& functor) -{ - vtkm::ListForEach(detail::ExecuteIfValidDeviceTag(functor), devices); -} } } // namespace vtkm::cont diff --git a/vtkm/cont/VirtualObjectHandle.h b/vtkm/cont/VirtualObjectHandle.h index c372d3095..97fad110d 100644 --- a/vtkm/cont/VirtualObjectHandle.h +++ b/vtkm/cont/VirtualObjectHandle.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -101,9 +102,10 @@ public: this->Internals->VirtualObject = derived; this->Internals->Owner = acquireOwnership; - vtkm::cont::ForEachValidDevice( - devices, - CreateTransferInterface(this->Internals->Transfers.data(), derived)); + vtkm::cont::internal::ForEachValidDevice(devices, + CreateTransferInterface(), + this->Internals->Transfers.data(), + derived); } } @@ -202,18 +204,12 @@ private: }; template - class CreateTransferInterface + struct CreateTransferInterface { - public: - CreateTransferInterface(std::unique_ptr* transfers, - const VirtualDerivedType* virtualObject) - : Transfers(transfers) - , VirtualObject(virtualObject) - { - } - template - void operator()(DeviceAdapter) const + VTKM_CONT void operator()(DeviceAdapter, + std::unique_ptr* transfers, + const VirtualDerivedType* virtualObject) const { using DeviceInfo = vtkm::cont::DeviceAdapterTraits; @@ -225,12 +221,8 @@ private: throw vtkm::cont::ErrorBadType(msg); } using TransferImpl = TransferInterfaceImpl; - this->Transfers[DeviceInfo::GetId()].reset(new TransferImpl(this->VirtualObject)); + transfers[DeviceInfo::GetId()].reset(new TransferImpl(virtualObject)); } - - private: - std::unique_ptr* Transfers; - const VirtualDerivedType* VirtualObject; }; struct InternalStruct diff --git a/vtkm/cont/internal/CMakeLists.txt b/vtkm/cont/internal/CMakeLists.txt index 96942b85f..2faf93a7e 100644 --- a/vtkm/cont/internal/CMakeLists.txt +++ b/vtkm/cont/internal/CMakeLists.txt @@ -32,6 +32,7 @@ set(headers DeviceAdapterAlgorithmGeneral.h DeviceAdapterDefaultSelection.h DeviceAdapterError.h + DeviceAdapterListHelpers.h DeviceAdapterTag.h DynamicTransform.h FunctorsGeneral.h diff --git a/vtkm/cont/internal/DeviceAdapterListHelpers.h b/vtkm/cont/internal/DeviceAdapterListHelpers.h new file mode 100644 index 000000000..8897338fb --- /dev/null +++ b/vtkm/cont/internal/DeviceAdapterListHelpers.h @@ -0,0 +1,138 @@ +//============================================================================ +// Copyright (c) Kitware, Inc. +// All rights reserved. +// See LICENSE.txt for details. +// This software is distributed WITHOUT ANY WARRANTY; without even +// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +// PURPOSE. See the above copyright notice for more information. +// +// Copyright 2016 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +// Copyright 2016 UT-Battelle, LLC. +// Copyright 2016 Los Alamos National Security. +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National +// Laboratory (LANL), the U.S. Government retains certain rights in +// this software. +//============================================================================ +#ifndef vtk_m_cont_internal_DeviceAdapterListHelpers_h +#define vtk_m_cont_internal_DeviceAdapterListHelpers_h + +#include +#include +#include + +namespace vtkm +{ +namespace cont +{ +namespace internal +{ + +//============================================================================ +template +class ExecuteIfValidDeviceTag +{ +private: + template + using EnableIfValid = std::enable_if::Valid>; + + template + using EnableIfInvalid = std::enable_if::Valid>; + +public: + explicit ExecuteIfValidDeviceTag(const FunctorType& functor) + : Functor(functor) + { + } + + template + typename EnableIfValid::type operator()( + DeviceAdapter device, + const vtkm::cont::RuntimeDeviceTracker& tracker, + Args&&... args) const + { + if (tracker.CanRunOn(device)) + { + this->Functor(device, std::forward(args)...); + } + } + + // do not generate code for invalid devices + template + typename EnableIfInvalid::type operator()(DeviceAdapter, + const vtkm::cont::RuntimeDeviceTracker&, + Args&&...) const + { + } + +private: + FunctorType Functor; +}; + +/// Execute the given functor on each valid device in \c DeviceList. +/// +template +VTKM_CONT void ForEachValidDevice(DeviceList devices, const Functor& functor, Args&&... args) +{ + auto tracker = vtkm::cont::GetGlobalRuntimeDeviceTracker(); + + ExecuteIfValidDeviceTag wrapped(functor); + vtkm::ListForEach(wrapped, devices, tracker, std::forward(args)...); +} + +//============================================================================ +template +class ExecuteIfSameDeviceId +{ +public: + ExecuteIfSameDeviceId(FunctorType functor) + : Functor(functor) + { + } + + template + void operator()(DeviceAdapter device, + vtkm::cont::DeviceAdapterId deviceId, + bool& status, + Args&&... args) const + { + if (vtkm::cont::DeviceAdapterTraits::GetId() == deviceId) + { + VTKM_ASSERT(status == false); + this->Functor(device, std::forward(args)...); + status = true; + } + } + +private: + FunctorType Functor; +}; + +/// Finds the \c DeviceAdapterTag in \c DeviceList with id equal to deviceId +/// and executes the functor with the tag. Throws \c ErrorBadDevice if a valid +/// \c DeviceAdapterTag is not found. +/// +template +VTKM_CONT void FindDeviceAdapterTagAndCall(vtkm::cont::DeviceAdapterId deviceId, + DeviceList devices, + const Functor& functor, + Args&&... args) +{ + bool status = false; + ExecuteIfSameDeviceId wrapped(functor); + ForEachValidDevice(devices, wrapped, deviceId, status, std::forward(args)...); + if (!status) + { + std::string msg = + "Device with id " + std::to_string(deviceId) + " is either not in the list or is invalid"; + throw vtkm::cont::ErrorBadDevice(msg); + } +} +} +} +} // vtkm::cont::internal + +#endif // vtk_m_cont_internal_DeviceAdapterListHelpers_h