diy: pass operator instance to mpi_op<>::get()

This makes it possible to add custom MPI reduction operations without
having to add too complex logic to allocating and freeing MPI operation.
This commit is contained in:
Utkarsh Ayachit 2017-12-12 15:36:21 -05:00
parent 42d5be319b
commit c63f3635d5
2 changed files with 16 additions and 16 deletions

@ -152,13 +152,13 @@ namespace mpi
}
}
static void reduce(const communicator& comm, const T& in, T& out, int root, const Op&)
static void reduce(const communicator& comm, const T& in, T& out, int root, const Op& op)
{
MPI_Reduce(Datatype::address(const_cast<T&>(in)),
Datatype::address(out),
Datatype::count(in),
Datatype::datatype(),
detail::mpi_op<Op>::get(),
detail::mpi_op<Op>::get(op),
root, comm);
}
@ -168,38 +168,38 @@ namespace mpi
Datatype::address(const_cast<T&>(in)),
Datatype::count(in),
Datatype::datatype(),
detail::mpi_op<Op>::get(),
detail::mpi_op<Op>::get(op),
root, comm);
}
static void all_reduce(const communicator& comm, const T& in, T& out, const Op&)
static void all_reduce(const communicator& comm, const T& in, T& out, const Op& op)
{
MPI_Allreduce(Datatype::address(const_cast<T&>(in)),
Datatype::address(out),
Datatype::count(in),
Datatype::datatype(),
detail::mpi_op<Op>::get(),
detail::mpi_op<Op>::get(op),
comm);
}
static void all_reduce(const communicator& comm, const std::vector<T>& in, std::vector<T>& out, const Op&)
static void all_reduce(const communicator& comm, const std::vector<T>& in, std::vector<T>& out, const Op& op)
{
out.resize(in.size());
MPI_Allreduce(Datatype::address(const_cast<T&>(in[0])),
Datatype::address(out[0]),
in.size(),
Datatype::datatype(),
detail::mpi_op<Op>::get(),
detail::mpi_op<Op>::get(op),
comm);
}
static void scan(const communicator& comm, const T& in, T& out, const Op&)
static void scan(const communicator& comm, const T& in, T& out, const Op& op)
{
MPI_Scan(Datatype::address(const_cast<T&>(in)),
Datatype::address(out),
Datatype::count(in),
Datatype::datatype(),
detail::mpi_op<Op>::get(),
detail::mpi_op<Op>::get(op),
comm);
}

@ -14,13 +14,13 @@ namespace mpi
namespace detail
{
template<class T> struct mpi_op { static MPI_Op get(); };
template<class U> struct mpi_op< maximum<U> > { static MPI_Op get() { return MPI_MAX; } };
template<class U> struct mpi_op< minimum<U> > { static MPI_Op get() { return MPI_MIN; } };
template<class U> struct mpi_op< std::plus<U> > { static MPI_Op get() { return MPI_SUM; } };
template<class U> struct mpi_op< std::multiplies<U> > { static MPI_Op get() { return MPI_PROD; } };
template<class U> struct mpi_op< std::logical_and<U> > { static MPI_Op get() { return MPI_LAND; } };
template<class U> struct mpi_op< std::logical_or<U> > { static MPI_Op get() { return MPI_LOR; } };
template<class T> struct mpi_op { static MPI_Op get(const T&); };
template<class U> struct mpi_op< maximum<U> > { static MPI_Op get(const maximum<U>&) { return MPI_MAX; } };
template<class U> struct mpi_op< minimum<U> > { static MPI_Op get(const minimum<U>&) { return MPI_MIN; } };
template<class U> struct mpi_op< std::plus<U> > { static MPI_Op get(const std::plus<U>&) { return MPI_SUM; } };
template<class U> struct mpi_op< std::multiplies<U> > { static MPI_Op get(const std::multiplies<U>&) { return MPI_PROD; } };
template<class U> struct mpi_op< std::logical_and<U> > { static MPI_Op get(const std::logical_and<U>&) { return MPI_LAND; } };
template<class U> struct mpi_op< std::logical_or<U> > { static MPI_Op get(const std::logical_or<U>&) { return MPI_LOR; } };
}
}
}