//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//============================================================================
#ifndef vtk_m_filter_Histogram_hxx
#define vtk_m_filter_Histogram_hxx

// NOTE(review): in this copy every #include below has lost its header name, and
// template argument lists appear to be missing throughout the file (e.g. bare
// `vtkm::cont::ArrayHandle`, `std::vector incoming`, `static_cast(...)`,
// `master.block(...)`, and `template` with no parameter list). The file cannot
// compile as-is; restore the header names and template arguments from the
// upstream source before building.
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace vtkm
{
namespace filter
{
namespace detail
{
/// Combines the per-partition histogram bin arrays produced by
/// `Histogram::DoExecute` into a single histogram that is made available on
/// every rank.
///
/// One array is stored per local partition (`LocalBlocks`). `ReduceAll()`
/// sums them element-wise across all partitions on all ranks with a DIY
/// merge-reduction onto block 0, then broadcasts that sum to every rank.
class DistributedHistogram
{
  /// DIY reduction callback, used both for the merge in `ReduceAll()` and for
  /// the broadcast in `Broadcast()`.
  ///
  /// 1. For every incoming neighbor other than this block itself, dequeue its
  ///    array and either adopt it (when the local array is still empty) or add
  ///    it element-wise into the local array.
  /// 2. Send the (possibly updated) local array to every outgoing neighbor
  ///    other than this block itself.
  ///
  /// NOTE(review): the element-wise add assumes all arrays have the same number
  /// of bins -- presumably guaranteed because every partition is binned with the
  /// same `NumberOfBins`; confirm.
  class Reducer
  {
  public:
    void operator()(vtkm::cont::ArrayHandle* result,
                    const vtkmdiy::ReduceProxy& srp,
                    const vtkmdiy::RegularMergePartners&) const
    {
      const auto selfid = srp.gid();
      // 1. dequeue.
      std::vector incoming;
      srp.incoming(incoming);
      for (const int gid : incoming)
      {
        // Skip self-links: this block's own data is already in `*result`.
        if (gid != selfid)
        {
          vtkm::cont::ArrayHandle in;
          srp.dequeue(gid, in);
          if (result->GetNumberOfValues() == 0)
          {
            // Nothing accumulated yet: adopt the incoming array as-is.
            *result = in;
          }
          else
          {
            // Accumulate in place: result = result + in, element-wise.
            vtkm::cont::Algorithm::Transform(*result, in, *result, vtkm::Add());
          }
        }
      }
      // 2. enqueue
      for (int cc = 0; cc < srp.out_link().size(); ++cc)
      {
        auto target = srp.out_link().target(cc);
        if (target.gid != selfid)
        {
          srp.enqueue(target, *result);
        }
      }
    }
  };

  // One histogram array per local partition, indexed by local partition number.
  std::vector> LocalBlocks;

public:
  /// Reserves one (initially default-constructed) histogram slot per local
  /// partition.
  DistributedHistogram(vtkm::Id numLocalBlocks)
    : LocalBlocks(static_cast(numLocalBlocks))
  {
  }

  /// Stores `bins` as the histogram of local partition `index`.
  /// `index` must be in [0, numLocalBlocks); it is not bounds-checked.
  void SetLocalHistogram(vtkm::Id index, const vtkm::cont::ArrayHandle& bins)
  {
    this->LocalBlocks[static_cast(index)] = bins;
  }

  /// Convenience overload: casts the data of `field` to the histogram array
  /// type and stores it as the histogram of local partition `index`.
  void SetLocalHistogram(vtkm::Id index, const vtkm::cont::Field& field)
  {
    this->SetLocalHistogram(index, field.GetData().Cast>());
  }

  /// Returns the element-wise sum of every local histogram on every rank.
  ///
  /// With a single rank and at most one local partition no communication is
  /// needed: the lone histogram (or an empty array when there are no
  /// partitions) is returned directly. Otherwise the local histograms are
  /// merge-reduced onto DIY block 0 and the result is then broadcast, so every
  /// rank returns the same array.
  ///
  /// NOTE(review): because this runs a DIY reduce/broadcast over the global
  /// communicator, it presumably must be called collectively by all ranks;
  /// confirm.
  vtkm::cont::ArrayHandle ReduceAll() const
  {
    using ArrayType = vtkm::cont::ArrayHandle;

    const vtkm::Id numLocalBlocks = static_cast(this->LocalBlocks.size());
    auto comm = vtkm::cont::EnvironmentTracker::GetCommunicator();
    if (comm.size() == 1 && numLocalBlocks <= 1)
    {
      // no reduction necessary.
      return numLocalBlocks == 0 ? ArrayType() : this->LocalBlocks[0];
    }

    // The create/destroy callbacks make each DIY block a heap-allocated
    // histogram array whose lifetime is managed by `master`.
    vtkmdiy::Master master(
      comm,
      /*threads*/ 1,
      /*limit*/ -1,
      []() -> void* { return new vtkm::cont::ArrayHandle(); },
      [](void* ptr) { delete static_cast*>(ptr); });

    // Give each local partition a global block id and lay all blocks out on a
    // 1-D regular decomposition so that merge partners can be derived from it.
    vtkm::cont::AssignerPartitionedDataSet assigner(numLocalBlocks);
    vtkmdiy::RegularDecomposer decomposer(
      /*dims*/ 1, vtkmdiy::interval(0, assigner.nblocks() - 1), assigner.nblocks());
    decomposer.decompose(comm.rank(), assigner, master);

    // Expect exactly one local DIY block per local partition; copy each local
    // histogram into its block.
    assert(static_cast(master.size()) == numLocalBlocks);
    for (vtkm::Id cc = 0; cc < numLocalBlocks; ++cc)
    {
      *master.block(static_cast(cc)) = this->LocalBlocks[static_cast(cc)];
    }

    vtkmdiy::RegularMergePartners partners(decomposer, /*k=*/2);
    // reduce to block-0.
    vtkmdiy::reduce(master, assigner, partners, Reducer());

    // Only the rank that owns block 0 holds the total; every other rank passes
    // an empty array into the broadcast.
    ArrayType result;
    if (master.local(0))
    {
      result = *master.block(master.lid(0));
    }

    this->Broadcast(result);
    return result;
  }

private:
  /// Distributes `data` from the broadcast root to every rank, overwriting
  /// `data` in place. Does nothing when running on a single rank.
  ///
  /// Builds one DIY block per rank (asserted below) and broadcasts along a k=2
  /// tree using the same `Reducer` as the merge.
  ///
  /// NOTE(review): `Reducer` *adds* incoming data into a non-empty local array,
  /// so this behaves as a true broadcast only if every non-root rank passes an
  /// empty `data`. `ReduceAll()` fills `data` only on the rank owning merge
  /// block 0. If that rank is not the broadcast root (presumably rank 0 -- for
  /// example when rank 0 has no partitions), the total may not reach every
  /// rank. Confirm against `AssignerPartitionedDataSet`'s block placement.
  void Broadcast(vtkm::cont::ArrayHandle& data) const
  {
    // broadcast to all ranks (and not blocks).
    auto comm = vtkm::cont::EnvironmentTracker::GetCommunicator();
    if (comm.size() > 1)
    {
      using ArrayType = vtkm::cont::ArrayHandle;
      vtkmdiy::Master master(
        comm,
        /*threads*/ 1,
        /*limit*/ -1,
        []() -> void* { return new vtkm::cont::ArrayHandle(); },
        [](void* ptr) { delete static_cast*>(ptr); });

      // comm.size() blocks assigned contiguously over comm.size() ranks.
      vtkmdiy::ContiguousAssigner assigner(comm.size(), comm.size());
      vtkmdiy::RegularDecomposer decomposer(
        1, vtkmdiy::interval(0, comm.size() - 1), comm.size());
      decomposer.decompose(comm.rank(), assigner, master);
      assert(master.size() == 1); // number of local blocks should be 1 per rank.
      *master.block(0) = data;
      vtkmdiy::RegularBroadcastPartners partners(decomposer, /*k=*/2);
      vtkmdiy::reduce(master, assigner, partners, Reducer());
      data = *master.block(0);
    }
  }
};
} // namespace detail

//-----------------------------------------------------------------------------
/// Defaults: 10 bins, `BinDelta` of 0 until the filter runs, default-constructed
/// user (`Range`) and computed (`ComputedRange`) ranges, and an output field
/// named "histogram".
inline VTKM_CONT Histogram::Histogram()
  : NumberOfBins(10)
  , BinDelta(0)
  , ComputedRange()
  , Range()
{
  this->SetOutputFieldName("histogram");
}

//-----------------------------------------------------------------------------
/// Computes the histogram of one partition's `field`.
///
/// When `ComputedRange` is non-empty (`PreExecute` sets it before partitioned
/// execution), the bins span [ComputedRange.Min, ComputedRange.Max] cast to
/// `T`. Otherwise the `worklet.Run` overload that takes `ComputedRange` itself
/// is used -- presumably the worklet derives the range from the data; confirm.
/// The bin width is stored in `BinDelta`. The returned DataSet holds only the
/// bin array, as a WHOLE_MESH field named `GetOutputFieldName()`.
template
inline VTKM_CONT vtkm::cont::DataSet Histogram::DoExecute(
  const vtkm::cont::DataSet&,
  const vtkm::cont::ArrayHandle& field,
  const vtkm::filter::FieldMetadata&,
  vtkm::filter::PolicyBase)
{
  vtkm::cont::ArrayHandle binArray;
  // NOTE(review): left uninitialized; relies on `worklet.Run` writing it in
  // both branches below.
  T delta;

  vtkm::worklet::FieldHistogram worklet;
  if (this->ComputedRange.IsNonEmpty())
  {
    worklet.Run(field,
                this->NumberOfBins,
                static_cast(this->ComputedRange.Min),
                static_cast(this->ComputedRange.Max),
                delta,
                binArray);
  }
  else
  {
    worklet.Run(field, this->NumberOfBins, this->ComputedRange, delta, binArray);
  }

  this->BinDelta = static_cast(delta);
  vtkm::cont::DataSet output;
  vtkm::cont::Field rfield(
    this->GetOutputFieldName(), vtkm::cont::Field::Association::WHOLE_MESH, binArray);
  output.AddField(rfield);
  return output;
}

//-----------------------------------------------------------------------------
/// Fixes the bin range before the per-partition `DoExecute` calls, so that
/// every partition is binned against the same `ComputedRange`.
///
/// Uses the user-supplied `Range` when it is non-empty. Otherwise it computes
/// the range of the active field over `input` with `FieldRangeGlobalCompute`
/// (presumably across all ranks, as the name suggests; confirm). It throws
/// `ErrorFilterExecution` unless exactly one range is returned, i.e. the field
/// is presumably scalar.
template
inline VTKM_CONT void Histogram::PreExecute(const vtkm::cont::PartitionedDataSet& input,
                                            const vtkm::filter::PolicyBase&)
{
  // Policies are on their way out, but until they are we want to respect them. In the mean
  // time, respect the policy if it is defined.
  using TypeList = typename std::conditional<
    std::is_same::value,
    VTKM_DEFAULT_TYPE_LIST,
    typename DerivedPolicy::FieldTypeList>::type;

  if (this->Range.IsNonEmpty())
  {
    this->ComputedRange = this->Range;
  }
  else
  {
    auto handle = vtkm::cont::FieldRangeGlobalCompute(
      input, this->GetActiveFieldName(), this->GetActiveFieldAssociation(), TypeList());
    if (handle.GetNumberOfValues() != 1)
    {
      throw vtkm::cont::ErrorFilterExecution("expecting scalar field.");
    }
    this->ComputedRange = handle.ReadPortal().Get(0);
  }
}

//-----------------------------------------------------------------------------
/// Combines the per-partition histograms into one.
///
/// Reads the output field that `DoExecute` wrote into every partition of
/// `result`, and sums the fields across all partitions (and ranks) through
/// `detail::DistributedHistogram::ReduceAll`. It then replaces `result` with a
/// PartitionedDataSet built from a single DataSet. That DataSet carries the
/// combined histogram as a WHOLE_MESH field named `GetOutputFieldName()`.
template
inline VTKM_CONT void Histogram::PostExecute(const vtkm::cont::PartitionedDataSet&,
                                             vtkm::cont::PartitionedDataSet& result,
                                             const vtkm::filter::PolicyBase&)
{
  // Gather the histogram that DoExecute computed for each local block.
  detail::DistributedHistogram helper(result.GetNumberOfPartitions());
  for (vtkm::Id cc = 0; cc < result.GetNumberOfPartitions(); ++cc)
  {
    auto& ablock = result.GetPartition(cc);
    helper.SetLocalHistogram(cc, ablock.GetField(this->GetOutputFieldName()));
  }

  vtkm::cont::DataSet output;
  vtkm::cont::Field rfield(
    this->GetOutputFieldName(), vtkm::cont::Field::Association::WHOLE_MESH, helper.ReduceAll());
  output.AddField(rfield);

  result = vtkm::cont::PartitionedDataSet(output);
}
}
} // namespace vtkm::filter
#endif