Do not assume CUDA reduce operator is unary

The `Reduce` algorithm is sometimes used to convert an input type to a
different output type. For example, you can compute the min and max at
the same time by making the output of the binary functor a pair of the
input type. However, for this to work with the CUDA algorithm, you have
to be able to also convert the input type to the output type. This was
previously done by treating the binary operator as also a unary
operator. That's fine for custom operators, but if you are using
something like `thrust::plus`, it has no unary operation. (Why would
it?)

So, detect whether the operator has a unary operation. If it does, use
it to convert values from the input portal to the output type. If it does
not, just use `static_cast`. Thus, the operator only has to provide the
unary operation if `static_cast` does not work.
This commit is contained in:
Kenneth Moreland 2021-03-02 16:55:42 -07:00
parent f3a6931f6b
commit a7100c845a

@ -178,8 +178,25 @@ __global__ void SumExclusiveScan(T a, T b, T result, BinaryOperationType binary_
#pragma GCC diagnostic pop
#endif
/// Detects whether an object of type `FunctorType` can be invoked with a single
/// argument of type `ArgType`. The nested `type` is `std::true_type` when the
/// call expression is well formed and `std::false_type` otherwise.
template <typename FunctorType, typename ArgType>
struct FunctorSupportsUnaryImpl
{
  // Preferred candidate: `0` matches the `int` parameter exactly. The defaulted
  // template argument makes this overload drop out (SFINAE) when the functor
  // cannot be called with one `A`.
  template <typename F,
            typename A,
            typename UnaryResult = decltype(std::declval<F>()(std::declval<A>()))>
  static auto Probe(int) -> std::true_type;

  // Fallback candidate: an ellipsis is the worst possible match, so this is
  // chosen only when the overload above is not viable.
  template <typename F, typename A>
  static auto Probe(...) -> std::false_type;

  using type = decltype(Probe<FunctorType, ArgType>(0));
};
// Convenience alias: resolves to `std::true_type` when `FunctorType` can be
// invoked with a single `ArgType` argument and to `std::false_type` otherwise.
template <typename FunctorType, typename ArgType>
using FunctorSupportsUnary = typename FunctorSupportsUnaryImpl<FunctorType, ArgType>::type;
// Primary declaration of the portal wrapper. The defaulted third template
// argument is `FunctorSupportsUnary` evaluated on the functor and the portal's
// `ValueType`, so it selects between the `std::true_type` and `std::false_type`
// partial specializations. Only declared here, never defined.
template <typename PortalType,
typename BinaryAndUnaryFunctor,
typename = FunctorSupportsUnary<BinaryAndUnaryFunctor, typename PortalType::ValueType>>
struct CastPortal;
template <typename PortalType, typename BinaryAndUnaryFunctor>
struct CastPortal
struct CastPortal<PortalType, BinaryAndUnaryFunctor, std::true_type>
{
using InputType = typename PortalType::ValueType;
using ValueType = decltype(std::declval<BinaryAndUnaryFunctor>()(std::declval<InputType>()));
@ -201,6 +218,28 @@ struct CastPortal
ValueType Get(vtkm::Id index) const { return this->Functor(this->Portal.Get(index)); }
};
/// `CastPortal` specialization used when the third template argument is
/// `std::false_type`, i.e. the functor cannot be called with a single input
/// value. Values read from the wrapped portal are converted with `static_cast`
/// rather than by calling the functor.
template <typename PortalType, typename BinaryFunctor>
struct CastPortal<PortalType, BinaryFunctor, std::false_type>
{
  /// Type of the values held by the wrapped portal.
  using InputType = typename PortalType::ValueType;

  /// Type produced by applying the binary functor to two input values.
  using ValueType =
    decltype(std::declval<BinaryFunctor>()(std::declval<InputType>(), std::declval<InputType>()));

  PortalType Portal;

  /// Wraps `portal`. The functor argument is accepted but not stored; only its
  /// type is used by this specialization.
  VTKM_CONT
  CastPortal(const PortalType& portal, const BinaryFunctor&)
    : Portal(portal)
  {
  }

  /// Forwards the number of values reported by the wrapped portal.
  VTKM_EXEC
  vtkm::Id GetNumberOfValues() const
  {
    return Portal.GetNumberOfValues();
  }

  /// Reads the value at `index` from the wrapped portal and converts it to
  /// `ValueType` with `static_cast`.
  VTKM_EXEC
  ValueType Get(vtkm::Id index) const
  {
    return static_cast<ValueType>(Portal.Get(index));
  }
};
struct CudaFreeFunctor
{
void operator()(void* ptr) const { VTKM_CUDA_CALL(cudaFree(ptr)); }