mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-10-08 11:29:02 +00:00
Merge branch 'simplify_cuda_iterator_detection' into 'master'
Simplify cuda iterator detection See merge request !83
This commit is contained in:
commit
74377ecc92
@ -40,9 +40,6 @@ class ArrayPortalCounting
|
||||
public:
|
||||
typedef CountingValueType ValueType;
|
||||
|
||||
typedef vtkm::cont::internal::IteratorFromArrayPortal<
|
||||
ArrayPortalCounting < ValueType> > IteratorType;
|
||||
|
||||
VTKM_EXEC_CONT_EXPORT
|
||||
ArrayPortalCounting() :
|
||||
StartingValue(),
|
||||
|
@ -44,7 +44,6 @@ class ArrayPortalImplicit
|
||||
{
|
||||
public:
|
||||
typedef ValueType_ ValueType;
|
||||
typedef ValueType_ IteratorType;
|
||||
typedef FunctorType_ FunctorType;
|
||||
|
||||
VTKM_EXEC_CONT_EXPORT
|
||||
|
@ -36,7 +36,6 @@ class ArrayPortalPermutationExec
|
||||
{
|
||||
public:
|
||||
typedef typename ValuePortalType::ValueType ValueType;
|
||||
typedef ValueType IteratorType;
|
||||
|
||||
VTKM_EXEC_CONT_EXPORT
|
||||
ArrayPortalPermutationExec( )
|
||||
@ -102,7 +101,6 @@ class ArrayPortalPermutationCont
|
||||
{
|
||||
public:
|
||||
typedef typename ValuePortalType::ValueType ValueType;
|
||||
typedef ValueType IteratorType;
|
||||
|
||||
VTKM_CONT_EXPORT
|
||||
ArrayPortalPermutationCont( )
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include <vtkm/Types.h>
|
||||
#include <vtkm/Pair.h>
|
||||
#include <vtkm/internal/ExportMacros.h>
|
||||
#include <vtkm/cont/ArrayPortalToIterators.h>
|
||||
|
||||
#include <vtkm/exec/cuda/internal/ArrayPortalFromThrust.h>
|
||||
#include <vtkm/exec/cuda/internal/WrappedOperators.h>
|
||||
@ -51,89 +52,70 @@ namespace internal {
|
||||
namespace detail {
|
||||
|
||||
// Tags to specify what type of thrust iterator to use.
|
||||
struct ThrustIteratorTransformTag { };
|
||||
struct ThrustIteratorFromArrayPortalTag { };
|
||||
struct ThrustIteratorDevicePtrTag { };
|
||||
|
||||
// Traits to help classify what thrust iterators will be used.
|
||||
template<class PortalType, class IteratorType>
|
||||
template<typename IteratorType>
|
||||
struct ThrustIteratorTag {
|
||||
typedef ThrustIteratorTransformTag Type;
|
||||
typedef ThrustIteratorFromArrayPortalTag Type;
|
||||
};
|
||||
template<typename PortalType, typename T>
|
||||
struct ThrustIteratorTag<PortalType, T *> {
|
||||
template<typename T>
|
||||
struct ThrustIteratorTag< thrust::system::cuda::pointer<T> > {
|
||||
typedef ThrustIteratorDevicePtrTag Type;
|
||||
};
|
||||
template<typename PortalType, typename T>
|
||||
struct ThrustIteratorTag<PortalType, const T*> {
|
||||
template<typename T>
|
||||
struct ThrustIteratorTag< thrust::system::cuda::pointer<const T> > {
|
||||
typedef ThrustIteratorDevicePtrTag Type;
|
||||
};
|
||||
|
||||
template<typename T> struct ThrustStripPointer;
|
||||
template<typename T> struct ThrustStripPointer<T *> {
|
||||
typedef T Type;
|
||||
};
|
||||
template<typename T> struct ThrustStripPointer<const T *> {
|
||||
typedef const T Type;
|
||||
};
|
||||
|
||||
template<class PortalType, class Tag> struct IteratorChooser;
|
||||
template<class PortalType>
|
||||
struct IteratorChooser<PortalType, detail::ThrustIteratorTransformTag> {
|
||||
template<typename PortalType, typename Tag> struct IteratorChooser;
|
||||
template<typename PortalType>
|
||||
struct IteratorChooser<PortalType, detail::ThrustIteratorFromArrayPortalTag> {
|
||||
typedef vtkm::exec::cuda::internal::IteratorFromArrayPortal<PortalType> Type;
|
||||
};
|
||||
template<class PortalType>
|
||||
template<typename PortalType>
|
||||
struct IteratorChooser<PortalType, detail::ThrustIteratorDevicePtrTag> {
|
||||
typedef ::thrust::cuda::pointer<
|
||||
typename detail::ThrustStripPointer<
|
||||
typename PortalType::IteratorType>::Type> Type;
|
||||
typedef vtkm::cont::ArrayPortalToIterators<PortalType> PortalToIteratorType;
|
||||
|
||||
typedef typename PortalToIteratorType::IteratorType Type;
|
||||
|
||||
};
|
||||
|
||||
template<class PortalType>
|
||||
template<typename PortalType>
|
||||
struct IteratorTraits
|
||||
{
|
||||
typedef vtkm::cont::ArrayPortalToIterators<PortalType> PortalToIteratorType;
|
||||
typedef typename detail::ThrustIteratorTag<
|
||||
typename PortalToIteratorType::IteratorType>::Type Tag;
|
||||
typedef typename IteratorChooser<
|
||||
PortalType,
|
||||
typename PortalType::IteratorType>::Type Tag;
|
||||
typedef typename IteratorChooser<PortalType, Tag>::Type IteratorType;
|
||||
Tag
|
||||
>::Type IteratorType;
|
||||
};
|
||||
|
||||
|
||||
template<typename T>
|
||||
VTKM_CONT_EXPORT
|
||||
::thrust::cuda::pointer<T>
|
||||
MakeDevicePtr(T *iter)
|
||||
{
|
||||
return::thrust::cuda::pointer<T>(iter);
|
||||
}
|
||||
template<typename T>
|
||||
VTKM_CONT_EXPORT
|
||||
::thrust::cuda::pointer<const T>
|
||||
MakeDevicePtr(const T *iter)
|
||||
{
|
||||
return ::thrust::cuda::pointer<const T>(iter);
|
||||
}
|
||||
|
||||
template<class PortalType>
|
||||
template<typename PortalType>
|
||||
VTKM_CONT_EXPORT
|
||||
typename IteratorTraits<PortalType>::IteratorType
|
||||
MakeIteratorBegin(PortalType portal, detail::ThrustIteratorTransformTag)
|
||||
MakeIteratorBegin(PortalType portal, detail::ThrustIteratorFromArrayPortalTag)
|
||||
{
|
||||
return vtkm::exec::cuda::internal::IteratorFromArrayPortal<PortalType>(portal,0);
|
||||
return vtkm::exec::cuda::internal::IteratorFromArrayPortal<PortalType>(portal);
|
||||
}
|
||||
|
||||
template<class PortalType>
|
||||
template<typename PortalType>
|
||||
VTKM_CONT_EXPORT
|
||||
typename IteratorTraits<PortalType>::IteratorType
|
||||
MakeIteratorBegin(PortalType portal, detail::ThrustIteratorDevicePtrTag)
|
||||
{
|
||||
return MakeDevicePtr(portal.GetIteratorBegin());
|
||||
vtkm::cont::ArrayPortalToIterators<PortalType> iterators(portal);
|
||||
return iterators.GetBegin();
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
|
||||
|
||||
template<class PortalType>
|
||||
template<typename PortalType>
|
||||
VTKM_CONT_EXPORT
|
||||
typename detail::IteratorTraits<PortalType>::IteratorType
|
||||
IteratorBegin(PortalType portal)
|
||||
@ -142,7 +124,7 @@ IteratorBegin(PortalType portal)
|
||||
return detail::MakeIteratorBegin(portal, IteratorTag());
|
||||
}
|
||||
|
||||
template<class PortalType>
|
||||
template<typename PortalType>
|
||||
VTKM_CONT_EXPORT
|
||||
typename detail::IteratorTraits<PortalType>::IteratorType
|
||||
IteratorEnd(PortalType portal)
|
||||
|
@ -21,6 +21,7 @@
|
||||
#define vtk_m_exec_cuda_internal_ArrayPortalFromThrust_h
|
||||
|
||||
#include <vtkm/Types.h>
|
||||
#include <vtkm/cont/ArrayPortalToIterators.h>
|
||||
|
||||
#include <iterator>
|
||||
#include <boost/type_traits/remove_const.hpp>
|
||||
@ -239,13 +240,13 @@ class ArrayPortalFromThrust : public ArrayPortalFromThrustBase
|
||||
{
|
||||
public:
|
||||
typedef T ValueType;
|
||||
typedef typename thrust::system::cuda::pointer< T > PointerType;
|
||||
typedef T* IteratorType;
|
||||
typedef thrust::system::cuda::pointer< T > IteratorType;
|
||||
|
||||
VTKM_EXEC_CONT_EXPORT ArrayPortalFromThrust() { }
|
||||
|
||||
VTKM_CONT_EXPORT
|
||||
ArrayPortalFromThrust(PointerType begin, PointerType end)
|
||||
ArrayPortalFromThrust(thrust::system::cuda::pointer< T > begin,
|
||||
thrust::system::cuda::pointer< T > end)
|
||||
: BeginIterator( begin ),
|
||||
EndIterator( end )
|
||||
{ }
|
||||
@ -269,29 +270,23 @@ public:
|
||||
|
||||
VTKM_EXEC_EXPORT
|
||||
ValueType Get(vtkm::Id index) const {
|
||||
return *this->IteratorAt(index);
|
||||
return *(this->BeginIterator + index);
|
||||
}
|
||||
|
||||
VTKM_EXEC_EXPORT
|
||||
void Set(vtkm::Id index, ValueType value) const {
|
||||
*this->IteratorAt(index) = value;
|
||||
*(this->BeginIterator + index) = value;
|
||||
}
|
||||
|
||||
VTKM_EXEC_CONT_EXPORT
|
||||
IteratorType GetIteratorBegin() const { return this->BeginIterator.get(); }
|
||||
IteratorType GetIteratorBegin() const { return this->BeginIterator; }
|
||||
|
||||
VTKM_EXEC_CONT_EXPORT
|
||||
IteratorType GetIteratorEnd() const { return this->EndIterator.get(); }
|
||||
IteratorType GetIteratorEnd() const { return this->EndIterator; }
|
||||
|
||||
private:
|
||||
PointerType BeginIterator;
|
||||
PointerType EndIterator;
|
||||
|
||||
VTKM_EXEC_EXPORT
|
||||
PointerType IteratorAt(vtkm::Id index) const {
|
||||
// Not using std::advance because on CUDA it cannot be used on a device.
|
||||
return (this->BeginIterator + index);
|
||||
}
|
||||
IteratorType BeginIterator;
|
||||
IteratorType EndIterator;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
@ -300,13 +295,13 @@ class ConstArrayPortalFromThrust : public ArrayPortalFromThrustBase
|
||||
public:
|
||||
|
||||
typedef T ValueType;
|
||||
typedef typename thrust::system::cuda::pointer< T > PointerType;
|
||||
typedef const T* IteratorType;
|
||||
typedef thrust::system::cuda::pointer< const T > IteratorType;
|
||||
|
||||
VTKM_EXEC_CONT_EXPORT ConstArrayPortalFromThrust() { }
|
||||
|
||||
VTKM_CONT_EXPORT
|
||||
ConstArrayPortalFromThrust(const PointerType begin, const PointerType end)
|
||||
ConstArrayPortalFromThrust(const thrust::system::cuda::pointer< T > begin,
|
||||
const thrust::system::cuda::pointer< T > end)
|
||||
: BeginIterator( begin ),
|
||||
EndIterator( end )
|
||||
{
|
||||
@ -333,29 +328,23 @@ public:
|
||||
|
||||
VTKM_EXEC_EXPORT
|
||||
ValueType Get(vtkm::Id index) const {
|
||||
return vtkm::exec::cuda::internal::load_through_texture<ValueType>::get( this->IteratorAt(index) );
|
||||
return vtkm::exec::cuda::internal::load_through_texture<ValueType>::get( this->BeginIterator + index );
|
||||
}
|
||||
|
||||
VTKM_EXEC_EXPORT
|
||||
void Set(vtkm::Id index, ValueType value) const {
|
||||
*this->IteratorAt(index) = value;
|
||||
*(this->BeginIterator + index) = value;
|
||||
}
|
||||
|
||||
VTKM_EXEC_CONT_EXPORT
|
||||
IteratorType GetIteratorBegin() const { return this->BeginIterator.get(); }
|
||||
IteratorType GetIteratorBegin() const { return this->BeginIterator; }
|
||||
|
||||
VTKM_EXEC_CONT_EXPORT
|
||||
IteratorType GetIteratorEnd() const { return this->EndIterator.get(); }
|
||||
IteratorType GetIteratorEnd() const { return this->EndIterator; }
|
||||
|
||||
private:
|
||||
PointerType BeginIterator;
|
||||
PointerType EndIterator;
|
||||
|
||||
VTKM_EXEC_EXPORT
|
||||
PointerType IteratorAt(vtkm::Id index) const {
|
||||
// Not using std::advance because on CUDA it cannot be used on a device.
|
||||
return (this->BeginIterator + index);
|
||||
}
|
||||
IteratorType BeginIterator;
|
||||
IteratorType EndIterator;
|
||||
};
|
||||
|
||||
}
|
||||
@ -364,4 +353,75 @@ private:
|
||||
} // namespace vtkm::exec::cuda::internal
|
||||
|
||||
|
||||
namespace vtkm {
|
||||
namespace cont {
|
||||
|
||||
/// Partial specialization of \c ArrayPortalToIterators for \c
|
||||
/// ArrayPortalFromThrust. Returns the original array rather than
|
||||
/// the portal wrapped in an \c IteratorFromArrayPortal.
|
||||
///
|
||||
template<typename T>
|
||||
class ArrayPortalToIterators<
|
||||
vtkm::exec::cuda::internal::ArrayPortalFromThrust<T> >
|
||||
{
|
||||
typedef vtkm::exec::cuda::internal::ArrayPortalFromThrust<T>
|
||||
PortalType;
|
||||
public:
|
||||
|
||||
typedef typename PortalType::IteratorType IteratorType;
|
||||
|
||||
VTKM_CONT_EXPORT
|
||||
ArrayPortalToIterators(const PortalType &portal)
|
||||
: BIterator(portal.GetIteratorBegin()),
|
||||
EIterator(portal.GetIteratorEnd())
|
||||
{ }
|
||||
|
||||
VTKM_CONT_EXPORT
|
||||
IteratorType GetBegin() const { return this->BIterator; }
|
||||
|
||||
VTKM_CONT_EXPORT
|
||||
IteratorType GetEnd() const { return this->EIterator; }
|
||||
|
||||
private:
|
||||
IteratorType BIterator;
|
||||
IteratorType EIterator;
|
||||
vtkm::Id NumberOfValues;
|
||||
};
|
||||
|
||||
/// Partial specialization of \c ArrayPortalToIterators for \c
|
||||
/// ConstArrayPortalFromThrust. Returns the original array rather than
|
||||
/// the portal wrapped in an \c IteratorFromArrayPortal.
|
||||
///
|
||||
template<typename T>
|
||||
class ArrayPortalToIterators<
|
||||
vtkm::exec::cuda::internal::ConstArrayPortalFromThrust<T> >
|
||||
{
|
||||
typedef vtkm::exec::cuda::internal::ConstArrayPortalFromThrust<T>
|
||||
PortalType;
|
||||
public:
|
||||
|
||||
typedef typename PortalType::IteratorType IteratorType;
|
||||
|
||||
VTKM_CONT_EXPORT
|
||||
ArrayPortalToIterators(const PortalType &portal)
|
||||
: BIterator(portal.GetIteratorBegin()),
|
||||
EIterator(portal.GetIteratorEnd())
|
||||
{ }
|
||||
|
||||
VTKM_CONT_EXPORT
|
||||
IteratorType GetBegin() const { return this->BIterator; }
|
||||
|
||||
VTKM_CONT_EXPORT
|
||||
IteratorType GetEnd() const { return this->EIterator; }
|
||||
|
||||
private:
|
||||
IteratorType BIterator;
|
||||
IteratorType EIterator;
|
||||
vtkm::Id NumberOfValues;
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace vtkm::cont
|
||||
|
||||
|
||||
#endif //vtk_m_exec_cuda_internal_ArrayPortalFromThrust_h
|
||||
|
Loading…
Reference in New Issue
Block a user