// vtk-m2/vtkm/cont/cuda/internal/MakeThrustIterator.h
//
// 325 lines
// 9.5 KiB
// C++

//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//
// Copyright 2014 Sandia Corporation.
// Copyright 2014 UT-Battelle, LLC.
// Copyright 2014 Los Alamos National Security.
//
// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
// the U.S. Government retains certain rights in this software.
//
// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
// Laboratory (LANL), the U.S. Government retains certain rights in
// this software.
//============================================================================
#ifndef vtk_m_cont_cuda_internal_MakeThrustIterator_h
#define vtk_m_cont_cuda_internal_MakeThrustIterator_h
#include <vtkm/Types.h>
#include <vtkm/Pair.h>
#include <vtkm/internal/ExportMacros.h>
#include <vtkm/exec/cuda/internal/ArrayPortalFromThrust.h>
// Disable warnings we check vtkm for but Thrust does not.
// The diagnostic state is pushed before the Thrust headers and popped right
// after them, so the suppression only applies to the Thrust includes.
// Clang's predefined macro is spelled __clang__ (two leading underscores).
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wconversion"
#endif // gcc || clang
#include <thrust/system/cuda/memory.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif // gcc || clang
// Forward declaration of the zip array portal class template. Only the name
// is needed here: the iterator-tag trait further down is partially
// specialized on it, so the defining header does not have to be included.
namespace vtkm {
namespace exec {
namespace internal {
template<typename ValueType_,
typename PortalTypeFirst,
typename PortalTypeSecond>
class ArrayPortalExecZip;
}
}
}
namespace vtkm {
namespace cont {
namespace cuda {
namespace internal {
namespace detail {
// Tags to specify what type of thrust iterator to use. They carry no data
// and are only used for overload / specialization selection:
//  - ThrustIteratorTransformTag: transform iterator over a counting iterator
//  - ThrustIteratorZipTag:       thrust zip iterator over two sub-iterators
//  - ThrustIteratorDevicePtrTag: thrust cuda pointer wrapping a raw pointer
struct ThrustIteratorTransformTag { };
struct ThrustIteratorZipTag { };
struct ThrustIteratorDevicePtrTag { };
// Traits to help classify what thrust iterators will be used.
// Primary template: any (portal, iterator type) pair that is not matched by
// one of the specializations below falls back to the transform strategy.
template<class PortalType, class IteratorType>
struct ThrustIteratorTag {
typedef ThrustIteratorTransformTag Type;
};
// Specialization for portals whose iterator type is a raw mutable pointer:
// these are classified as device-pointer iterators.
template<typename PortalType, typename T>
struct ThrustIteratorTag<PortalType, T *> {
typedef ThrustIteratorDevicePtrTag Type;
};
// Specialization for portals whose iterator type is a raw pointer-to-const:
// these are classified as device-pointer iterators as well.
template<typename PortalType, typename T>
struct ThrustIteratorTag<PortalType, const T*> {
typedef ThrustIteratorDevicePtrTag Type;
};
// Specialization that selects the zip strategy for ArrayPortalExecZip.
// NOTE(review): as written this only matches when the iterator type passed
// as the second argument is exactly the zip portal's first template argument
// (T). ArrayPortalExecZip is only forward declared in this header, so
// whether its IteratorType is really T must be confirmed against its
// definition.
template<typename T, typename U, typename V>
struct ThrustIteratorTag< vtkm::exec::internal::ArrayPortalExecZip< T, U, V >,
T > {
//This is a real special case. ExecZip and PortalValue don't combine
//well together: when used with a DeviceAlgorithm that has a custom operator,
//the custom operator is actually passed the PortalValue instead of
//the real values, and by that point we can't fix anything since we
//don't know what the original operator is.
typedef ThrustIteratorZipTag Type;
};
// Trait that maps a raw pointer type to its pointee type. The primary
// template is intentionally left undefined, so using it with a non-pointer
// type is a compile error.
template<typename T> struct ThrustStripPointer;
// T* -> T
template<typename T> struct ThrustStripPointer<T *> {
typedef T Type;
};
// const T* -> const T (the const qualifier of the pointee is preserved)
template<typename T> struct ThrustStripPointer<const T *> {
typedef const T Type;
};
// Proxy object standing in for a single entry of an array portal.
//
// It stores a copy of the portal together with an index. Reading the proxy
// (conversion to ValueType) forwards to PortalType::Get and assigning to it
// forwards to PortalType::Set, so the proxy can be used where a reference to
// the entry would normally be used.
template<class PortalType>
struct PortalValue {
  typedef typename PortalType::ValueType ValueType;

  // Default construction: default-constructed portal, index 0.
  VTKM_EXEC_EXPORT
  PortalValue()
    : Portal(),
      Index(0) { }

  // Refers to entry \c index of \c portal (the portal is copied).
  VTKM_EXEC_EXPORT
  PortalValue(const PortalType &portal, vtkm::Id index)
    : Portal(portal), Index(index) { }

  // Copying a proxy yields a second proxy for the same portal entry.
  VTKM_EXEC_EXPORT
  PortalValue(const PortalValue<PortalType> &other)
    : Portal(other.Portal), Index(other.Index) { }

  // Proxy-to-proxy assignment. Because both data members are const, no copy
  // assignment can be generated implicitly, and overload resolution would
  // still prefer that unusable operator over operator=(ValueType) for
  // "proxyA = proxyB". Provide write-through semantics explicitly: the value
  // currently referenced by \c rhs is stored into the entry referenced by
  // this proxy; the proxy itself is not re-targeted.
  VTKM_EXEC_EXPORT
  PortalValue<PortalType> &operator=(const PortalValue<PortalType> &rhs) {
    this->Portal.Set(this->Index, rhs.Portal.Get(rhs.Index));
    return *this;
  }

  // Stores \c value into the referenced portal entry and returns the value
  // that was stored.
  VTKM_EXEC_EXPORT
  ValueType operator=(ValueType value) {
    this->Portal.Set(this->Index, value);
    return value;
  }

  // Reads the referenced portal entry.
  VTKM_EXEC_EXPORT
  operator ValueType(void) const {
    return this->Portal.Get(this->Index);
  }

  // Both members are fixed at construction time.
  const PortalType Portal;
  const vtkm::Id Index;
};
// Unary function object mapping an index to a PortalValue proxy for that
// index of the stored portal. It is the functor type plugged into the
// thrust transform iterator chosen for ThrustIteratorTransformTag portals.
template<class PortalType>
class LookupFunctor
  : public ::thrust::unary_function< vtkm::Id, PortalValue<PortalType> >
{
public:
  // Creates a functor around a default-constructed portal.
  VTKM_EXEC_EXPORT
  LookupFunctor() : Portal() { }

  // Creates a functor that performs its lookups in a copy of \c portal.
  VTKM_EXEC_EXPORT
  LookupFunctor(PortalType portal) : Portal(portal) { }

  // Returns the proxy for entry \c index of the stored portal.
  VTKM_EXEC_EXPORT
  PortalValue<PortalType> operator()(vtkm::Id index)
  {
    PortalValue<PortalType> proxy(this->Portal, index);
    return proxy;
  }

private:
  PortalType Portal;
};
// Maps a (portal type, iterator tag) pair to the concrete thrust iterator
// type, exposed as the nested typedef Type. The primary template is only
// declared; each supported tag has its own specialization.
template<class PortalType, class Tag> struct IteratorChooser;
// Transform strategy: a counting iterator over vtkm::Id indices whose values
// are passed through LookupFunctor, so dereferencing yields a PortalValue.
template<class PortalType>
struct IteratorChooser<PortalType, detail::ThrustIteratorTransformTag> {
typedef ::thrust::transform_iterator<
LookupFunctor<PortalType>,
::thrust::counting_iterator<vtkm::Id> > Type;
};
// IteratorTraits is defined further down in this file, but the zip
// specialization below has to name it. The template name itself is not a
// dependent name, so it must be declared before it is used.
template<class PortalType> struct IteratorTraits;

// Zip strategy: the thrust iterator is a zip_iterator over the iterators
// chosen (recursively, through IteratorTraits) for the two sub-portals.
//
// This is a real special case. ExecZip and PortalValue don't combine
// well together: when used with a DeviceAlgorithm that has a custom operator,
// the custom operator is actually passed the PortalValue instead of
// the real values, and by that point we can't fix anything since we
// don't know what the original operator is.
// So to fix this issue we wrap the original array portals into a thrust
// zip iterator and let it handle everything.
template<class PortalType>
struct IteratorChooser<PortalType, detail::ThrustIteratorZipTag> {
  // Iterator type selected for the first zipped portal.
  typedef typename PortalType::PortalTypeFirst PortalTypeFirst;
  typedef typename IteratorTraits<PortalTypeFirst>::IteratorType FirstIterType;

  // Iterator type selected for the second zipped portal.
  typedef typename PortalType::PortalTypeSecond PortalTypeSecond;
  typedef typename IteratorTraits<PortalTypeSecond>::IteratorType SecondIterType;

  // Now that we have deduced the concrete types of the first and second
  // array portals of the zip we can construct a zip iterator for those.
  typedef ::thrust::tuple<FirstIterType, SecondIterType> IteratorTuple;
  typedef ::thrust::zip_iterator<IteratorTuple> Type;
};
// Device-pointer strategy: the portal's raw pointer iterator type is
// stripped of its pointer (keeping any const on the pointee) and the result
// is used as the element type of a thrust cuda pointer.
template<class PortalType>
struct IteratorChooser<PortalType, detail::ThrustIteratorDevicePtrTag> {
typedef ::thrust::cuda::pointer<
typename detail::ThrustStripPointer<
typename PortalType::IteratorType>::Type> Type;
};
// Single entry point for the iterator classification of a portal type:
//  - Tag:          the strategy tag computed by ThrustIteratorTag from the
//                  portal type and its nested IteratorType
//  - IteratorType: the thrust iterator type IteratorChooser picks for that
//                  tag
template<class PortalType>
struct IteratorTraits
{
typedef typename detail::ThrustIteratorTag<
PortalType,
typename PortalType::IteratorType>::Type Tag;
typedef typename IteratorChooser<PortalType, Tag>::Type IteratorType;
};
// Wraps a raw pointer to mutable T in a thrust cuda pointer of the same
// element type.
template<typename T>
VTKM_CONT_EXPORT static
::thrust::cuda::pointer<T>
MakeDevicePtr(T *iter)
{
  typedef ::thrust::cuda::pointer<T> DevicePointerType;
  return DevicePointerType(iter);
}
// Wraps a raw pointer to const T in a thrust cuda pointer whose element type
// is const T.
template<typename T>
VTKM_CONT_EXPORT static
::thrust::cuda::pointer<const T>
MakeDevicePtr(const T *iter)
{
  typedef ::thrust::cuda::pointer<const T> ConstDevicePointerType;
  return ConstDevicePointerType(iter);
}
template<typename T, typename U>
VTKM_CONT_EXPORT static
::thrust::zip_iterator<T,U>
MakeZipIterator(const T t, const U u)
{
//todo deduce from T and U the iterator types
this is what needs finished
return ::thrust::make_zip_iterator(
::thrust::make_tuple( IteratorBegin(t),
IteratorBegin(u) )
);
}
// Begin iterator for portals classified with ThrustIteratorTransformTag:
// a counting iterator starting at index 0 whose values are mapped through a
// LookupFunctor holding a copy of \c portal.
template<class PortalType>
VTKM_CONT_EXPORT static
typename IteratorTraits<PortalType>::IteratorType
MakeIteratorBegin(PortalType portal, detail::ThrustIteratorTransformTag)
{
  const vtkm::Id firstIndex = 0;
  LookupFunctor<PortalType> lookup(portal);
  return ::thrust::make_transform_iterator(
           ::thrust::make_counting_iterator(firstIndex),
           lookup);
}
// Begin iterator for portals classified with ThrustIteratorZipTag: the two
// sub-portals obtained from GetFirstPortal() / GetSecondPortal() are handed
// to MakeZipIterator.
// NOTE(review): assumes those getters return the portal's PortalTypeFirst /
// PortalTypeSecond so the result matches the declared return type -- the
// zip portal's definition is not visible here, confirm.
template<class PortalType>
VTKM_CONT_EXPORT static
typename IteratorTraits<PortalType>::IteratorType
MakeIteratorBegin(PortalType portal, detail::ThrustIteratorZipTag)
{
return MakeZipIterator(portal.GetFirstPortal(),
portal.GetSecondPortal()
);
}
// Begin iterator for portals classified with ThrustIteratorDevicePtrTag:
// the value returned by the portal's GetIteratorBegin() is wrapped in a
// thrust cuda pointer via MakeDevicePtr.
// NOTE(review): presumably GetIteratorBegin() returns the portal's raw
// pointer IteratorType -- confirm against the portal class.
template<class PortalType>
VTKM_CONT_EXPORT static
typename IteratorTraits<PortalType>::IteratorType
MakeIteratorBegin(PortalType portal, detail::ThrustIteratorDevicePtrTag)
{
return MakeDevicePtr(portal.GetIteratorBegin());
}
} // namespace detail
// Returns the thrust iterator positioned at the first value of \c portal.
// The iterator strategy is chosen at compile time: the tag computed by
// detail::IteratorTraits selects the matching detail::MakeIteratorBegin
// overload.
template<class PortalType>
VTKM_CONT_EXPORT
typename detail::IteratorTraits<PortalType>::IteratorType
IteratorBegin(PortalType portal)
{
  return detail::MakeIteratorBegin(
           portal,
           typename detail::IteratorTraits<PortalType>::Tag());
}
// Returns the thrust iterator one past the last value of \c portal, i.e.
// the begin iterator advanced by portal.GetNumberOfValues().
template<class PortalType>
VTKM_CONT_EXPORT
typename detail::IteratorTraits<PortalType>::IteratorType
IteratorEnd(PortalType portal)
{
  typedef typename detail::IteratorTraits<PortalType>::IteratorType
      ThrustIteratorType;
  ThrustIteratorType first = IteratorBegin(portal);
  return first + portal.GetNumberOfValues();
}
}
}
}
} //namespace vtkm::cont::cuda::internal
namespace thrust {
// Specialization of thrust::less for PortalValue proxies. The comparison is
// done on the values the proxies convert to, not on the proxies themselves.
template< typename PortalType >
struct less< vtkm::cont::cuda::internal::detail::PortalValue< PortalType > > :
  public binary_function<
    vtkm::cont::cuda::internal::detail::PortalValue< PortalType >,
    vtkm::cont::cuda::internal::detail::PortalValue< PortalType >,
    bool>
{
  typedef vtkm::cont::cuda::internal::detail::PortalValue< PortalType > T;
  typedef typename T::ValueType ValueType;

  /*! Function call operator for two proxies. Both operands are converted
   *  to ValueType; the return value is <tt>lhs < rhs</tt> on those values.
   */
  __host__ __device__ bool operator()(const T &lhs, const T &rhs) const
  {
    return static_cast<ValueType>(lhs) < static_cast<ValueType>(rhs);
  }

  /*! Function call operator for a proxy on the left and a plain value on
   *  the right. Portal values can be compared to their underlying type;
   *  the return value is <tt>lhs < rhs</tt> after converting \c lhs.
   */
  __host__ __device__ bool operator()(const T &lhs,
                                      const ValueType &rhs) const
  {
    return static_cast<ValueType>(lhs) < rhs;
  }
}; // end less
}
#endif