mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-10-08 11:29:02 +00:00
Merge branch 'simplify_cuda_iterator_detection' into 'master'
Simplify cuda iterator detection See merge request !83
This commit is contained in:
commit
74377ecc92
@ -40,9 +40,6 @@ class ArrayPortalCounting
|
|||||||
public:
|
public:
|
||||||
typedef CountingValueType ValueType;
|
typedef CountingValueType ValueType;
|
||||||
|
|
||||||
typedef vtkm::cont::internal::IteratorFromArrayPortal<
|
|
||||||
ArrayPortalCounting < ValueType> > IteratorType;
|
|
||||||
|
|
||||||
VTKM_EXEC_CONT_EXPORT
|
VTKM_EXEC_CONT_EXPORT
|
||||||
ArrayPortalCounting() :
|
ArrayPortalCounting() :
|
||||||
StartingValue(),
|
StartingValue(),
|
||||||
|
@ -44,7 +44,6 @@ class ArrayPortalImplicit
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
typedef ValueType_ ValueType;
|
typedef ValueType_ ValueType;
|
||||||
typedef ValueType_ IteratorType;
|
|
||||||
typedef FunctorType_ FunctorType;
|
typedef FunctorType_ FunctorType;
|
||||||
|
|
||||||
VTKM_EXEC_CONT_EXPORT
|
VTKM_EXEC_CONT_EXPORT
|
||||||
|
@ -36,7 +36,6 @@ class ArrayPortalPermutationExec
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
typedef typename ValuePortalType::ValueType ValueType;
|
typedef typename ValuePortalType::ValueType ValueType;
|
||||||
typedef ValueType IteratorType;
|
|
||||||
|
|
||||||
VTKM_EXEC_CONT_EXPORT
|
VTKM_EXEC_CONT_EXPORT
|
||||||
ArrayPortalPermutationExec( )
|
ArrayPortalPermutationExec( )
|
||||||
@ -102,7 +101,6 @@ class ArrayPortalPermutationCont
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
typedef typename ValuePortalType::ValueType ValueType;
|
typedef typename ValuePortalType::ValueType ValueType;
|
||||||
typedef ValueType IteratorType;
|
|
||||||
|
|
||||||
VTKM_CONT_EXPORT
|
VTKM_CONT_EXPORT
|
||||||
ArrayPortalPermutationCont( )
|
ArrayPortalPermutationCont( )
|
||||||
|
@ -23,6 +23,7 @@
|
|||||||
#include <vtkm/Types.h>
|
#include <vtkm/Types.h>
|
||||||
#include <vtkm/Pair.h>
|
#include <vtkm/Pair.h>
|
||||||
#include <vtkm/internal/ExportMacros.h>
|
#include <vtkm/internal/ExportMacros.h>
|
||||||
|
#include <vtkm/cont/ArrayPortalToIterators.h>
|
||||||
|
|
||||||
#include <vtkm/exec/cuda/internal/ArrayPortalFromThrust.h>
|
#include <vtkm/exec/cuda/internal/ArrayPortalFromThrust.h>
|
||||||
#include <vtkm/exec/cuda/internal/WrappedOperators.h>
|
#include <vtkm/exec/cuda/internal/WrappedOperators.h>
|
||||||
@ -51,89 +52,70 @@ namespace internal {
|
|||||||
namespace detail {
|
namespace detail {
|
||||||
|
|
||||||
// Tags to specify what type of thrust iterator to use.
|
// Tags to specify what type of thrust iterator to use.
|
||||||
struct ThrustIteratorTransformTag { };
|
struct ThrustIteratorFromArrayPortalTag { };
|
||||||
struct ThrustIteratorDevicePtrTag { };
|
struct ThrustIteratorDevicePtrTag { };
|
||||||
|
|
||||||
// Traits to help classify what thrust iterators will be used.
|
// Traits to help classify what thrust iterators will be used.
|
||||||
template<class PortalType, class IteratorType>
|
template<typename IteratorType>
|
||||||
struct ThrustIteratorTag {
|
struct ThrustIteratorTag {
|
||||||
typedef ThrustIteratorTransformTag Type;
|
typedef ThrustIteratorFromArrayPortalTag Type;
|
||||||
};
|
};
|
||||||
template<typename PortalType, typename T>
|
template<typename T>
|
||||||
struct ThrustIteratorTag<PortalType, T *> {
|
struct ThrustIteratorTag< thrust::system::cuda::pointer<T> > {
|
||||||
typedef ThrustIteratorDevicePtrTag Type;
|
typedef ThrustIteratorDevicePtrTag Type;
|
||||||
};
|
};
|
||||||
template<typename PortalType, typename T>
|
template<typename T>
|
||||||
struct ThrustIteratorTag<PortalType, const T*> {
|
struct ThrustIteratorTag< thrust::system::cuda::pointer<const T> > {
|
||||||
typedef ThrustIteratorDevicePtrTag Type;
|
typedef ThrustIteratorDevicePtrTag Type;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename T> struct ThrustStripPointer;
|
template<typename PortalType, typename Tag> struct IteratorChooser;
|
||||||
template<typename T> struct ThrustStripPointer<T *> {
|
template<typename PortalType>
|
||||||
typedef T Type;
|
struct IteratorChooser<PortalType, detail::ThrustIteratorFromArrayPortalTag> {
|
||||||
};
|
|
||||||
template<typename T> struct ThrustStripPointer<const T *> {
|
|
||||||
typedef const T Type;
|
|
||||||
};
|
|
||||||
|
|
||||||
template<class PortalType, class Tag> struct IteratorChooser;
|
|
||||||
template<class PortalType>
|
|
||||||
struct IteratorChooser<PortalType, detail::ThrustIteratorTransformTag> {
|
|
||||||
typedef vtkm::exec::cuda::internal::IteratorFromArrayPortal<PortalType> Type;
|
typedef vtkm::exec::cuda::internal::IteratorFromArrayPortal<PortalType> Type;
|
||||||
};
|
};
|
||||||
template<class PortalType>
|
template<typename PortalType>
|
||||||
struct IteratorChooser<PortalType, detail::ThrustIteratorDevicePtrTag> {
|
struct IteratorChooser<PortalType, detail::ThrustIteratorDevicePtrTag> {
|
||||||
typedef ::thrust::cuda::pointer<
|
typedef vtkm::cont::ArrayPortalToIterators<PortalType> PortalToIteratorType;
|
||||||
typename detail::ThrustStripPointer<
|
|
||||||
typename PortalType::IteratorType>::Type> Type;
|
typedef typename PortalToIteratorType::IteratorType Type;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template<class PortalType>
|
template<typename PortalType>
|
||||||
struct IteratorTraits
|
struct IteratorTraits
|
||||||
{
|
{
|
||||||
|
typedef vtkm::cont::ArrayPortalToIterators<PortalType> PortalToIteratorType;
|
||||||
typedef typename detail::ThrustIteratorTag<
|
typedef typename detail::ThrustIteratorTag<
|
||||||
PortalType,
|
typename PortalToIteratorType::IteratorType>::Type Tag;
|
||||||
typename PortalType::IteratorType>::Type Tag;
|
typedef typename IteratorChooser<
|
||||||
typedef typename IteratorChooser<PortalType, Tag>::Type IteratorType;
|
PortalType,
|
||||||
|
Tag
|
||||||
|
>::Type IteratorType;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename PortalType>
|
||||||
template<typename T>
|
|
||||||
VTKM_CONT_EXPORT
|
|
||||||
::thrust::cuda::pointer<T>
|
|
||||||
MakeDevicePtr(T *iter)
|
|
||||||
{
|
|
||||||
return::thrust::cuda::pointer<T>(iter);
|
|
||||||
}
|
|
||||||
template<typename T>
|
|
||||||
VTKM_CONT_EXPORT
|
|
||||||
::thrust::cuda::pointer<const T>
|
|
||||||
MakeDevicePtr(const T *iter)
|
|
||||||
{
|
|
||||||
return ::thrust::cuda::pointer<const T>(iter);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<class PortalType>
|
|
||||||
VTKM_CONT_EXPORT
|
VTKM_CONT_EXPORT
|
||||||
typename IteratorTraits<PortalType>::IteratorType
|
typename IteratorTraits<PortalType>::IteratorType
|
||||||
MakeIteratorBegin(PortalType portal, detail::ThrustIteratorTransformTag)
|
MakeIteratorBegin(PortalType portal, detail::ThrustIteratorFromArrayPortalTag)
|
||||||
{
|
{
|
||||||
return vtkm::exec::cuda::internal::IteratorFromArrayPortal<PortalType>(portal,0);
|
return vtkm::exec::cuda::internal::IteratorFromArrayPortal<PortalType>(portal);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<class PortalType>
|
template<typename PortalType>
|
||||||
VTKM_CONT_EXPORT
|
VTKM_CONT_EXPORT
|
||||||
typename IteratorTraits<PortalType>::IteratorType
|
typename IteratorTraits<PortalType>::IteratorType
|
||||||
MakeIteratorBegin(PortalType portal, detail::ThrustIteratorDevicePtrTag)
|
MakeIteratorBegin(PortalType portal, detail::ThrustIteratorDevicePtrTag)
|
||||||
{
|
{
|
||||||
return MakeDevicePtr(portal.GetIteratorBegin());
|
vtkm::cont::ArrayPortalToIterators<PortalType> iterators(portal);
|
||||||
|
return iterators.GetBegin();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template<class PortalType>
|
template<typename PortalType>
|
||||||
VTKM_CONT_EXPORT
|
VTKM_CONT_EXPORT
|
||||||
typename detail::IteratorTraits<PortalType>::IteratorType
|
typename detail::IteratorTraits<PortalType>::IteratorType
|
||||||
IteratorBegin(PortalType portal)
|
IteratorBegin(PortalType portal)
|
||||||
@ -142,7 +124,7 @@ IteratorBegin(PortalType portal)
|
|||||||
return detail::MakeIteratorBegin(portal, IteratorTag());
|
return detail::MakeIteratorBegin(portal, IteratorTag());
|
||||||
}
|
}
|
||||||
|
|
||||||
template<class PortalType>
|
template<typename PortalType>
|
||||||
VTKM_CONT_EXPORT
|
VTKM_CONT_EXPORT
|
||||||
typename detail::IteratorTraits<PortalType>::IteratorType
|
typename detail::IteratorTraits<PortalType>::IteratorType
|
||||||
IteratorEnd(PortalType portal)
|
IteratorEnd(PortalType portal)
|
||||||
|
@ -21,6 +21,7 @@
|
|||||||
#define vtk_m_exec_cuda_internal_ArrayPortalFromThrust_h
|
#define vtk_m_exec_cuda_internal_ArrayPortalFromThrust_h
|
||||||
|
|
||||||
#include <vtkm/Types.h>
|
#include <vtkm/Types.h>
|
||||||
|
#include <vtkm/cont/ArrayPortalToIterators.h>
|
||||||
|
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <boost/type_traits/remove_const.hpp>
|
#include <boost/type_traits/remove_const.hpp>
|
||||||
@ -239,13 +240,13 @@ class ArrayPortalFromThrust : public ArrayPortalFromThrustBase
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
typedef T ValueType;
|
typedef T ValueType;
|
||||||
typedef typename thrust::system::cuda::pointer< T > PointerType;
|
typedef thrust::system::cuda::pointer< T > IteratorType;
|
||||||
typedef T* IteratorType;
|
|
||||||
|
|
||||||
VTKM_EXEC_CONT_EXPORT ArrayPortalFromThrust() { }
|
VTKM_EXEC_CONT_EXPORT ArrayPortalFromThrust() { }
|
||||||
|
|
||||||
VTKM_CONT_EXPORT
|
VTKM_CONT_EXPORT
|
||||||
ArrayPortalFromThrust(PointerType begin, PointerType end)
|
ArrayPortalFromThrust(thrust::system::cuda::pointer< T > begin,
|
||||||
|
thrust::system::cuda::pointer< T > end)
|
||||||
: BeginIterator( begin ),
|
: BeginIterator( begin ),
|
||||||
EndIterator( end )
|
EndIterator( end )
|
||||||
{ }
|
{ }
|
||||||
@ -269,29 +270,23 @@ public:
|
|||||||
|
|
||||||
VTKM_EXEC_EXPORT
|
VTKM_EXEC_EXPORT
|
||||||
ValueType Get(vtkm::Id index) const {
|
ValueType Get(vtkm::Id index) const {
|
||||||
return *this->IteratorAt(index);
|
return *(this->BeginIterator + index);
|
||||||
}
|
}
|
||||||
|
|
||||||
VTKM_EXEC_EXPORT
|
VTKM_EXEC_EXPORT
|
||||||
void Set(vtkm::Id index, ValueType value) const {
|
void Set(vtkm::Id index, ValueType value) const {
|
||||||
*this->IteratorAt(index) = value;
|
*(this->BeginIterator + index) = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
VTKM_EXEC_CONT_EXPORT
|
VTKM_EXEC_CONT_EXPORT
|
||||||
IteratorType GetIteratorBegin() const { return this->BeginIterator.get(); }
|
IteratorType GetIteratorBegin() const { return this->BeginIterator; }
|
||||||
|
|
||||||
VTKM_EXEC_CONT_EXPORT
|
VTKM_EXEC_CONT_EXPORT
|
||||||
IteratorType GetIteratorEnd() const { return this->EndIterator.get(); }
|
IteratorType GetIteratorEnd() const { return this->EndIterator; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
PointerType BeginIterator;
|
IteratorType BeginIterator;
|
||||||
PointerType EndIterator;
|
IteratorType EndIterator;
|
||||||
|
|
||||||
VTKM_EXEC_EXPORT
|
|
||||||
PointerType IteratorAt(vtkm::Id index) const {
|
|
||||||
// Not using std::advance because on CUDA it cannot be used on a device.
|
|
||||||
return (this->BeginIterator + index);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
@ -300,13 +295,13 @@ class ConstArrayPortalFromThrust : public ArrayPortalFromThrustBase
|
|||||||
public:
|
public:
|
||||||
|
|
||||||
typedef T ValueType;
|
typedef T ValueType;
|
||||||
typedef typename thrust::system::cuda::pointer< T > PointerType;
|
typedef thrust::system::cuda::pointer< const T > IteratorType;
|
||||||
typedef const T* IteratorType;
|
|
||||||
|
|
||||||
VTKM_EXEC_CONT_EXPORT ConstArrayPortalFromThrust() { }
|
VTKM_EXEC_CONT_EXPORT ConstArrayPortalFromThrust() { }
|
||||||
|
|
||||||
VTKM_CONT_EXPORT
|
VTKM_CONT_EXPORT
|
||||||
ConstArrayPortalFromThrust(const PointerType begin, const PointerType end)
|
ConstArrayPortalFromThrust(const thrust::system::cuda::pointer< T > begin,
|
||||||
|
const thrust::system::cuda::pointer< T > end)
|
||||||
: BeginIterator( begin ),
|
: BeginIterator( begin ),
|
||||||
EndIterator( end )
|
EndIterator( end )
|
||||||
{
|
{
|
||||||
@ -333,29 +328,23 @@ public:
|
|||||||
|
|
||||||
VTKM_EXEC_EXPORT
|
VTKM_EXEC_EXPORT
|
||||||
ValueType Get(vtkm::Id index) const {
|
ValueType Get(vtkm::Id index) const {
|
||||||
return vtkm::exec::cuda::internal::load_through_texture<ValueType>::get( this->IteratorAt(index) );
|
return vtkm::exec::cuda::internal::load_through_texture<ValueType>::get( this->BeginIterator + index );
|
||||||
}
|
}
|
||||||
|
|
||||||
VTKM_EXEC_EXPORT
|
VTKM_EXEC_EXPORT
|
||||||
void Set(vtkm::Id index, ValueType value) const {
|
void Set(vtkm::Id index, ValueType value) const {
|
||||||
*this->IteratorAt(index) = value;
|
*(this->BeginIterator + index) = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
VTKM_EXEC_CONT_EXPORT
|
VTKM_EXEC_CONT_EXPORT
|
||||||
IteratorType GetIteratorBegin() const { return this->BeginIterator.get(); }
|
IteratorType GetIteratorBegin() const { return this->BeginIterator; }
|
||||||
|
|
||||||
VTKM_EXEC_CONT_EXPORT
|
VTKM_EXEC_CONT_EXPORT
|
||||||
IteratorType GetIteratorEnd() const { return this->EndIterator.get(); }
|
IteratorType GetIteratorEnd() const { return this->EndIterator; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
PointerType BeginIterator;
|
IteratorType BeginIterator;
|
||||||
PointerType EndIterator;
|
IteratorType EndIterator;
|
||||||
|
|
||||||
VTKM_EXEC_EXPORT
|
|
||||||
PointerType IteratorAt(vtkm::Id index) const {
|
|
||||||
// Not using std::advance because on CUDA it cannot be used on a device.
|
|
||||||
return (this->BeginIterator + index);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -364,4 +353,75 @@ private:
|
|||||||
} // namespace vtkm::exec::cuda::internal
|
} // namespace vtkm::exec::cuda::internal
|
||||||
|
|
||||||
|
|
||||||
|
namespace vtkm {
|
||||||
|
namespace cont {
|
||||||
|
|
||||||
|
/// Partial specialization of \c ArrayPortalToIterators for \c
|
||||||
|
/// ArrayPortalFromThrust. Returns the original array rather than
|
||||||
|
/// the portal wrapped in an \c IteratorFromArrayPortal.
|
||||||
|
///
|
||||||
|
template<typename T>
|
||||||
|
class ArrayPortalToIterators<
|
||||||
|
vtkm::exec::cuda::internal::ArrayPortalFromThrust<T> >
|
||||||
|
{
|
||||||
|
typedef vtkm::exec::cuda::internal::ArrayPortalFromThrust<T>
|
||||||
|
PortalType;
|
||||||
|
public:
|
||||||
|
|
||||||
|
typedef typename PortalType::IteratorType IteratorType;
|
||||||
|
|
||||||
|
VTKM_CONT_EXPORT
|
||||||
|
ArrayPortalToIterators(const PortalType &portal)
|
||||||
|
: BIterator(portal.GetIteratorBegin()),
|
||||||
|
EIterator(portal.GetIteratorEnd())
|
||||||
|
{ }
|
||||||
|
|
||||||
|
VTKM_CONT_EXPORT
|
||||||
|
IteratorType GetBegin() const { return this->BIterator; }
|
||||||
|
|
||||||
|
VTKM_CONT_EXPORT
|
||||||
|
IteratorType GetEnd() const { return this->EIterator; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
IteratorType BIterator;
|
||||||
|
IteratorType EIterator;
|
||||||
|
vtkm::Id NumberOfValues;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Partial specialization of \c ArrayPortalToIterators for \c
|
||||||
|
/// ConstArrayPortalFromThrust. Returns the original array rather than
|
||||||
|
/// the portal wrapped in an \c IteratorFromArrayPortal.
|
||||||
|
///
|
||||||
|
template<typename T>
|
||||||
|
class ArrayPortalToIterators<
|
||||||
|
vtkm::exec::cuda::internal::ConstArrayPortalFromThrust<T> >
|
||||||
|
{
|
||||||
|
typedef vtkm::exec::cuda::internal::ConstArrayPortalFromThrust<T>
|
||||||
|
PortalType;
|
||||||
|
public:
|
||||||
|
|
||||||
|
typedef typename PortalType::IteratorType IteratorType;
|
||||||
|
|
||||||
|
VTKM_CONT_EXPORT
|
||||||
|
ArrayPortalToIterators(const PortalType &portal)
|
||||||
|
: BIterator(portal.GetIteratorBegin()),
|
||||||
|
EIterator(portal.GetIteratorEnd())
|
||||||
|
{ }
|
||||||
|
|
||||||
|
VTKM_CONT_EXPORT
|
||||||
|
IteratorType GetBegin() const { return this->BIterator; }
|
||||||
|
|
||||||
|
VTKM_CONT_EXPORT
|
||||||
|
IteratorType GetEnd() const { return this->EIterator; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
IteratorType BIterator;
|
||||||
|
IteratorType EIterator;
|
||||||
|
vtkm::Id NumberOfValues;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
} // namespace vtkm::cont
|
||||||
|
|
||||||
|
|
||||||
#endif //vtk_m_exec_cuda_internal_ArrayPortalFromThrust_h
|
#endif //vtk_m_exec_cuda_internal_ArrayPortalFromThrust_h
|
||||||
|
Loading…
Reference in New Issue
Block a user